Commit 7e06e156 authored by Jonathan Rudenberg's avatar Jonathan Rudenberg

Ensure that the handshake state doesn’t get lost if decryption fails

This allows decoding potentially malicious messages under certain
circumstances.
Signed-off-by: default avatarJonathan Rudenberg <jonathan@titanous.com>
parent bc8ae75e
......@@ -312,3 +312,32 @@ func (NoiseSuite) TestPSK_XX(c *C) {
expected, _ := hex.DecodeString("2b9c628158a517e3984dc619245d4b9cd73561944f266181b183812ca73499881e30f6e7eeb576c258acc713c2c62874fd1beb76b122f6303f974109aefd7e2a")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestHandshakeRollback(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
hsI := NewHandshakeState(Config{CipherSuite: cs, Random: rngI, Pattern: HandshakeNN, Initiator: true})
hsR := NewHandshakeState(Config{CipherSuite: cs, Random: rngR, Pattern: HandshakeNN, Initiator: false})
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 35)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "abc")
msg, _, _ = hsR.WriteMessage(nil, []byte("defg"))
c.Assert(msg, HasLen, 52)
prev := msg[1]
msg[1] = msg[1] + 1
_, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(err, Not(IsNil))
msg[1] = prev
res, _, _, err = hsI.ReadMessage(nil, msg)
c.Assert(string(res), Equals, "defg")
expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c5e4dc9545d41b3280f4586a5481829e1e24ec5a0")
c.Assert(msg, DeepEquals, expected)
}
......@@ -66,6 +66,9 @@ type symmetricState struct {
hasPSK bool
ck []byte
h []byte
prevCK []byte
prevH []byte
}
func (s *symmetricState) InitializeSymmetric(handshakeName []byte) {
......@@ -137,6 +140,27 @@ func (s *symmetricState) Split() (*CipherState, *CipherState) {
return s1, s2
}
func (s *symmetricState) Checkpoint() {
if len(s.ck) > cap(s.prevCK) {
s.prevCK = make([]byte, len(s.ck))
}
s.prevCK = s.prevCK[:len(s.ck)]
copy(s.prevCK, s.ck)
if len(s.h) > cap(s.prevH) {
s.prevH = make([]byte, len(s.h))
}
s.prevH = s.prevH[:len(s.h)]
copy(s.prevH, s.h)
}
func (s *symmetricState) Rollback() {
s.ck = s.ck[:len(s.prevCK)]
copy(s.ck, s.prevCK)
s.h = s.h[:len(s.prevH)]
copy(s.h, s.prevH)
}
// A MessagePattern is a single message or operation used in a Noise handshake.
type MessagePattern int
......@@ -340,6 +364,8 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
panic("noise: no handshake messages left")
}
s.ss.Checkpoint()
var err error
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
......@@ -369,6 +395,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
}
if err != nil {
s.ss.Rollback()
return nil, nil, nil, err
}
message = message[expected:]
......@@ -382,12 +409,13 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs))
}
}
s.shouldWrite = true
s.msgIdx++
out, err = s.ss.DecryptAndHash(out, message)
if err != nil {
s.ss.Rollback()
return nil, nil, nil, err
}
s.shouldWrite = true
s.msgIdx++
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment