Commit 0e9c47ad authored by Jonathan Rudenberg's avatar Jonathan Rudenberg

Implement NoisePSK revision 2

parent f7b9b283
......@@ -27,7 +27,7 @@ func (NoiseSuite) TestN(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rng := new(RandomInc)
staticR := cs.GenerateKeypair(rng)
hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, staticR.Public, nil)
hs := NewHandshakeState(cs, rng, HandshakeN, true, nil, nil, nil, nil, staticR.Public, nil)
hello, _, _ := hs.WriteMessage(nil, nil)
expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662548331a3d1e93b490263abc7a4633867f4")
......@@ -39,7 +39,7 @@ func (NoiseSuite) TestX(c *C) {
rng := new(RandomInc)
staticI := cs.GenerateKeypair(rng)
staticR := cs.GenerateKeypair(rng)
hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, &staticI, nil, staticR.Public, nil)
hs := NewHandshakeState(cs, rng, HandshakeX, true, nil, nil, &staticI, nil, staticR.Public, nil)
hello, _, _ := hs.WriteMessage(nil, nil)
expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad203cd28d81cf65a2da637f557a05728b3ae4abdc3a42d1cda5f719d6cf41d7f2cf1b1c5af10e38a09a9bb7e3b1d589a99492cc50293eaa1f3f391b59bb6990d")
......@@ -52,8 +52,8 @@ func (NoiseSuite) TestNN(c *C) {
rngR := new(RandomInc)
*rngR = 1
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil)
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, nil, nil, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeNN, false, nil, nil, nil, nil, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 35)
......@@ -80,8 +80,8 @@ func (NoiseSuite) TestXX(c *C) {
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil)
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 35)
......@@ -114,8 +114,8 @@ func (NoiseSuite) TestIK(c *C) {
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), &staticI, nil, staticR.Public, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), &staticR, nil, nil, nil)
hsI := NewHandshakeState(cs, rngI, HandshakeIK, true, []byte("ABC"), nil, &staticI, nil, staticR.Public, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeIK, false, []byte("ABC"), nil, &staticR, nil, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 99)
......@@ -143,8 +143,8 @@ func (NoiseSuite) TestXE(c *C) {
staticR := cs.GenerateKeypair(rngR)
ephR := cs.GenerateKeypair(rngR)
hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, &staticI, nil, staticR.Public, ephR.Public)
hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, &staticR, &ephR, nil, nil)
hsI := NewHandshakeState(cs, rngI, HandshakeXE, true, nil, nil, &staticI, nil, staticR.Public, ephR.Public)
hsR := NewHandshakeState(cs, rngR, HandshakeXE, false, nil, nil, &staticR, &ephR, nil, nil)
msg, _, _ := hsI.WriteMessage(nil, []byte("abc"))
c.Assert(msg, HasLen, 51)
......@@ -177,8 +177,8 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, &staticI, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, &staticR, nil, nil, nil)
hsI := NewHandshakeState(cs, rngI, HandshakeXX, true, nil, nil, &staticI, nil, nil, nil)
hsR := NewHandshakeState(cs, rngR, HandshakeXX, false, nil, nil, &staticR, nil, nil, nil)
// -> e
msg, _, _ := hsI.WriteMessage(nil, []byte("abcdef"))
......@@ -220,3 +220,39 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "worri")
}
func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) {
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
hsI := NewHandshakeState(cs, rngI, HandshakeNN, true, nil, []byte("supersecret"), nil, nil, nil, nil)
hsR := NewHandshakeState(cs, rngI, HandshakeNN, false, nil, []byte("supersecret"), nil, nil, nil, nil)
// -> e
msg, _, _ := hsI.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 48)
res, _, _, err := hsR.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)
// <- e, dhee
msg, csR0, csR1 := hsR.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 48)
res, csI0, csI1, err := hsI.ReadMessage(nil, msg)
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)
// transport I -> R
msg = csI0.Encrypt(nil, nil, []byte("foo"))
res, err = csR0.Decrypt(nil, nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "foo")
// transport R -> I
msg = csR1.Encrypt(nil, nil, []byte("bar"))
res, err = csI1.Decrypt(nil, nil, msg)
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "bar")
}
......@@ -84,7 +84,7 @@ var HandshakeXN = HandshakePattern{
var HandshakeIN = HandshakePattern{
Name: "IN",
Messages: [][]MessagePattern{
{MessagePatternS, MessagePatternE},
{MessagePatternE, MessagePatternS},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES},
},
}
......@@ -139,7 +139,7 @@ var HandshakeXX = HandshakePattern{
var HandshakeIX = HandshakePattern{
Name: "IX",
Messages: [][]MessagePattern{
{MessagePatternS, MessagePatternE},
{MessagePatternE, MessagePatternS},
{MessagePatternE, MessagePatternDHEE, MessagePatternDHES, MessagePatternS, MessagePatternDHSE},
},
}
......
......@@ -119,16 +119,18 @@ type HandshakeState struct {
e DHKey // local ephemeral keypair
rs []byte // remote party's static public key
re []byte // remote party's ephemeral public key
psk bool
messagePatterns [][]MessagePattern
shouldWrite bool
msgIdx int
rng io.Reader
}
func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState {
func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern HandshakePattern, initiator bool, prologue, presharedKey []byte, newS, newE *DHKey, newRS, newRE []byte) *HandshakeState {
hs := &HandshakeState{
rs: newRS,
re: newRE,
psk: len(presharedKey) > 0,
messagePatterns: newHandshakePattern.Messages,
shouldWrite: initiator,
rng: rng,
......@@ -140,8 +142,15 @@ func NewHandshakeState(cs CipherSuite, rng io.Reader, newHandshakePattern Handsh
if newS != nil {
hs.s = *newS
}
hs.InitializeSymmetric([]byte("Noise_" + newHandshakePattern.Name + "_" + string(cs.Name())))
namePrefix := "Noise_"
if hs.psk {
namePrefix = "NoisePSK_"
}
hs.InitializeSymmetric([]byte(namePrefix + newHandshakePattern.Name + "_" + string(cs.Name())))
hs.MixHash(prologue)
if hs.psk {
hs.MixHash(presharedKey)
}
for _, m := range newHandshakePattern.InitiatorPreMessages {
switch {
case initiator && m == MessagePatternS:
......@@ -184,7 +193,11 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
switch msg {
case MessagePatternE:
s.e = s.cs.GenerateKeypair(s.rng)
out = s.EncryptAndHash(out, s.e.Public)
out = append(out, s.e.Public...)
s.MixHash(s.e.Public)
if s.psk {
s.MixKey(s.e.Public)
}
case MessagePatternS:
if len(s.s.Public) == 0 {
panic("noise: invalid state, s.Public is nil")
......@@ -227,7 +240,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
switch msg {
case MessagePatternE, MessagePatternS:
expected := s.cs.DHLen()
if s.hasKey {
if msg == MessagePatternS && s.hasKey {
expected += 16
}
if len(message) < expected {
......@@ -235,7 +248,15 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
}
switch msg {
case MessagePatternE:
s.re, err = s.DecryptAndHash(s.re[:0], message[:expected])
if cap(s.re) < s.cs.DHLen() {
s.re = make([]byte, s.cs.DHLen())
}
s.re = s.re[:s.cs.DHLen()]
copy(s.re, message)
s.MixHash(s.re)
if s.psk {
s.MixKey(s.re)
}
case MessagePatternS:
if len(s.rs) > 0 {
panic("noise: invalid state, rs is not nil")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment