Commit c90ff622 authored by Jonathan Rudenberg's avatar Jonathan Rudenberg Committed by GitHub

Merge pull request #11 from zx2c4/psk-rework

psk mode: update for redesign
parents 69027979 bac779d5
......@@ -5,13 +5,19 @@ import (
"hash"
)
func hkdf(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte) ([]byte, []byte) {
func hkdf(h func() hash.Hash, outputs int, out1, out2, out3, chainingKey, inputKeyMaterial []byte) ([]byte, []byte, []byte) {
if len(out1) > 0 {
panic("len(out1) > 0")
}
if len(out2) > 0 {
panic("len(out2) > 0")
}
if len(out3) > 0 {
panic("len(out3) > 0")
}
if outputs > 3 {
panic("outputs > 3")
}
tempMAC := hmac.New(h, chainingKey)
tempMAC.Write(inputKeyMaterial)
......@@ -21,10 +27,23 @@ func hkdf(h func() hash.Hash, out1, out2, chainingKey, inputKeyMaterial []byte)
out1MAC.Write([]byte{0x01})
out1 = out1MAC.Sum(out1)
if outputs == 1 {
return out1, nil, nil
}
out2MAC := hmac.New(h, tempKey)
out2MAC.Write(out1)
out2MAC.Write([]byte{0x02})
out2 = out2MAC.Sum(tempKey[:0])
out2 = out2MAC.Sum(out2)
if outputs == 2 {
return out1, out2, nil
}
out3MAC := hmac.New(h, tempKey)
out3MAC.Write(out2)
out3MAC.Write([]byte{0x03})
out3 = out3MAC.Sum(out3)
return out1, out2
return out1, out2, out3
}
......@@ -245,7 +245,7 @@ func (NoiseSuite) TestXXRoundtrip(c *C) {
c.Assert(string(res), Equals, "worri")
}
func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) {
func (NoiseSuite) Test_NNpsk0_Roundtrip(c *C) {
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashBLAKE2b)
rngI := new(RandomInc)
rngR := new(RandomInc)
......@@ -256,13 +256,13 @@ func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) {
Random: rngI,
Pattern: HandshakeNN,
Initiator: true,
PresharedKey: []byte("supersecret"),
PresharedKey: []byte("supersecretsupersecretsupersecre"),
})
hsR := NewHandshakeState(Config{
CipherSuite: cs,
Random: rngR,
Pattern: HandshakeNN,
PresharedKey: []byte("supersecret"),
PresharedKey: []byte("supersecretsupersecretsupersecre"),
})
// -> e
......@@ -292,7 +292,7 @@ func (NoiseSuite) TestPSK_NN_Roundtrip(c *C) {
c.Assert(string(res), Equals, "bar")
}
func (NoiseSuite) TestPSK_N(c *C) {
func (NoiseSuite) Test_Npsk0(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rng := new(RandomInc)
staticR := cs.GenerateKeypair(rng)
......@@ -302,18 +302,18 @@ func (NoiseSuite) TestPSK_N(c *C) {
Random: rng,
Pattern: HandshakeN,
Initiator: true,
PresharedKey: []byte{0x01, 0x02, 0x03},
PresharedKey: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20},
PeerStatic: staticR.Public,
})
msg, _, _ := hsI.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 48)
expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd16625475344a60649da3ec23ce8e3ed779e766")
expected, _ := hex.DecodeString("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd1662542044ae563929068930dcf04674526cb9")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestPSK_X(c *C) {
func (NoiseSuite) Test_Xpsk0(c *C) {
cs := NewCipherSuite(DH25519, CipherChaChaPoly, HashSHA256)
rng := new(RandomInc)
staticI := cs.GenerateKeypair(rng)
......@@ -324,24 +324,24 @@ func (NoiseSuite) TestPSK_X(c *C) {
Random: rng,
Pattern: HandshakeX,
Initiator: true,
PresharedKey: []byte{0x01, 0x02, 0x03},
PresharedKey: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20},
StaticKeypair: staticI,
PeerStatic: staticR.Public,
})
msg, _, _ := hs.WriteMessage(nil, nil)
c.Assert(msg, HasLen, 96)
expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51a12d5cf01bc576e8f0124b14db3ed7a00d20f16186e8f1e2c861fb3d4113f39b290f0048404b8d21e2098958b6bdf50f41dfb1143700310482cfb52c9002261bd")
expected, _ := hex.DecodeString("79a631eede1bf9c98f12032cdeadd0e7a079398fc786b88cc846ec89af85a51ad51eef529db0dd9127d4aa59a9183e118337d75a4e55e7e00f85c3d20ede536dd0112eec8c3b2a514018a90ab685b027dd24aa0c70b0c0f00524cc23785028b9")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestPSK_NN(c *C) {
func (NoiseSuite) Test_NNpsk0(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA512)
rngI := new(RandomInc)
rngR := new(RandomInc)
*rngR = 1
prologue := []byte{0x01, 0x02, 0x03}
psk := []byte{0x04, 0x05, 0x06}
psk := []byte{0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23}
hsI := NewHandshakeState(Config{
CipherSuite: cs,
......@@ -371,11 +371,11 @@ func (NoiseSuite) TestPSK_NN(c *C) {
c.Assert(err, IsNil)
c.Assert(string(res), Equals, "defg")
expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7cfda657b21e8eac78df67b6bd453c0b11372364a6")
expected, _ := hex.DecodeString("07a37cbc142093c8b755dc1b10e86cb426374ad16aa853ed0bdfc0b2b86d1c7c3e42e140cfffbcdf5d9d2a1c24ce4cdbdf1eaf37")
c.Assert(msg, DeepEquals, expected)
}
func (NoiseSuite) TestPSK_XX(c *C) {
func (NoiseSuite) Test_XXpsk0(c *C) {
cs := NewCipherSuite(DH25519, CipherAESGCM, HashSHA256)
rngI := new(RandomInc)
rngR := new(RandomInc)
......@@ -384,7 +384,7 @@ func (NoiseSuite) TestPSK_XX(c *C) {
staticI := cs.GenerateKeypair(rngI)
staticR := cs.GenerateKeypair(rngR)
prologue := []byte{0x01, 0x02, 0x03}
psk := []byte{0x04, 0x05, 0x06}
psk := []byte{0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23}
hsI := NewHandshakeState(Config{
CipherSuite: cs,
......@@ -422,7 +422,7 @@ func (NoiseSuite) TestPSK_XX(c *C) {
c.Assert(err, IsNil)
c.Assert(res, HasLen, 0)
expected, _ := hex.DecodeString("2b9c628158a517e3984dc619245d4b9cd73561944f266181b183812ca73499881e30f6e7eeb576c258acc713c2c62874fd1beb76b122f6303f974109aefd7e2a")
expected, _ := hex.DecodeString("1b6d7cc3b13bd02217f9cdb98c50870db96281193dca4df570bf6230a603b686fd90d2914c7e797d9276ef8fb34b0c9d87faa048ce4bc7e7af21b6a450352275")
c.Assert(msg, DeepEquals, expected)
}
......
......@@ -9,6 +9,7 @@ package noise
import (
"crypto/rand"
"errors"
"fmt"
"io"
)
......@@ -63,7 +64,6 @@ func (s *CipherState) Cipher() Cipher {
type symmetricState struct {
CipherState
hasK bool
hasPSK bool
ck []byte
h []byte
......@@ -88,7 +88,7 @@ func (s *symmetricState) MixKey(dhOutput []byte) {
s.n = 0
s.hasK = true
var hk []byte
s.ck, hk = hkdf(s.cs.Hash, s.ck[:0], s.k[:0], s.ck, dhOutput)
s.ck, hk, _ = hkdf(s.cs.Hash, 2, s.ck[:0], s.k[:0], nil, s.ck, dhOutput)
copy(s.k[:], hk)
s.c = s.cs.Cipher(s.k)
}
......@@ -100,11 +100,15 @@ func (s *symmetricState) MixHash(data []byte) {
s.h = h.Sum(s.h[:0])
}
func (s *symmetricState) MixPresharedKey(presharedKey []byte) {
func (s *symmetricState) MixKeyAndHash(data []byte) {
var hk []byte
var temp []byte
s.ck, temp = hkdf(s.cs.Hash, s.ck[:0], nil, s.ck, presharedKey)
s.ck, temp, hk = hkdf(s.cs.Hash, 3, s.ck[:0], temp, s.k[:0], s.ck, data)
s.MixHash(temp)
s.hasPSK = true
copy(s.k[:], hk)
s.c = s.cs.Cipher(s.k)
s.n = 0
s.hasK = true
}
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) []byte {
......@@ -132,7 +136,7 @@ func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
func (s *symmetricState) Split() (*CipherState, *CipherState) {
s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs}
hk1, hk2 := hkdf(s.cs.Hash, s1.k[:0], s2.k[:0], s.ck, nil)
hk1, hk2, _ := hkdf(s.cs.Hash, 2, s1.k[:0], s2.k[:0], nil, s.ck, nil)
copy(s1.k[:], hk1)
copy(s2.k[:], hk2)
s1.c = s.cs.Cipher(s1.k)
......@@ -180,6 +184,7 @@ const (
MessagePatternDHES
MessagePatternDHSE
MessagePatternDHSS
MessagePatternPSK
)
// MaxMsgLen is the maximum number of bytes that can be sent in a single Noise
......@@ -194,6 +199,7 @@ type HandshakeState struct {
e DHKey // local ephemeral keypair
rs []byte // remote party's static public key
re []byte // remote party's ephemeral public key
psk []byte // preshared key, maybe zero length
messagePatterns [][]MessagePattern
shouldWrite bool
msgIdx int
......@@ -221,9 +227,13 @@ type Config struct {
// be identical on both sides for the handshake to succeed.
Prologue []byte
// PresharedKey is the optional pre-shared key for the handshake.
// PresharedKey is the optional preshared key for the handshake.
PresharedKey []byte
// PresharedKeyPlacement specifies the placement position of the PSK token
// when PresharedKey is specified
PresharedKeyPlacement int
// StaticKeypair is this peer's static keypair, required if part of the
// handshake.
StaticKeypair DHKey
......@@ -247,6 +257,7 @@ func NewHandshakeState(c Config) *HandshakeState {
s: c.StaticKeypair,
e: c.EphemeralKeypair,
rs: c.PeerStatic,
psk: c.PresharedKey,
messagePatterns: c.Pattern.Messages,
shouldWrite: c.Initiator,
rng: c.Random,
......@@ -259,15 +270,21 @@ func NewHandshakeState(c Config) *HandshakeState {
copy(hs.re, c.PeerEphemeral)
}
hs.ss.cs = c.CipherSuite
namePrefix := "Noise_"
if len(c.PresharedKey) > 0 {
namePrefix = "NoisePSK_"
pskModifier := ""
if len(hs.psk) > 0 {
if len(hs.psk) != 32 {
panic("noise: specification mandates 256-bit preshared keys")
}
pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
if (c.PresharedKeyPlacement == 0) {
hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...)
} else {
hs.messagePatterns[c.PresharedKeyPlacement - 1] = append(hs.messagePatterns[c.PresharedKeyPlacement - 1], MessagePatternPSK)
}
}
hs.ss.InitializeSymmetric([]byte(namePrefix + c.Pattern.Name + "_" + string(hs.ss.cs.Name())))
hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))
hs.ss.MixHash(c.Prologue)
if len(c.PresharedKey) > 0 {
hs.ss.MixPresharedKey(c.PresharedKey)
}
for _, m := range c.Pattern.InitiatorPreMessages {
switch {
case c.Initiator && m == MessagePatternS:
......@@ -318,7 +335,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
s.e = s.ss.cs.GenerateKeypair(s.rng)
out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public)
if s.ss.hasPSK {
if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
......@@ -334,6 +351,8 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re))
case MessagePatternDHSS:
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs))
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
}
}
s.shouldWrite = false
......@@ -385,7 +404,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.re = s.re[:s.ss.cs.DHLen()]
copy(s.re, message)
s.ss.MixHash(s.re)
if s.ss.hasPSK {
if len(s.psk) > 0 {
s.ss.MixKey(s.re)
}
case MessagePatternS:
......@@ -407,6 +426,8 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs))
case MessagePatternDHSS:
s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs))
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
}
}
out, err = s.ss.DecryptAndHash(out, message)
......
......@@ -139,11 +139,16 @@ func (NoiseSuite) TestVectors(c *C) {
configI, configR = Config{Initiator: true}, Config{}
hsI, hsR = nil, nil
components := strings.SplitN(name, "_", 5)
keyInfo = patternKeys[components[1]]
configI.Pattern = patterns[components[1]]
handshakeComponents := strings.Split(components[1], "psk")
if len(handshakeComponents) == 2 {
configI.PresharedKeyPlacement, _ = strconv.Atoi(handshakeComponents[1])
}
keyInfo = patternKeys[handshakeComponents[0]]
configI.Pattern = patterns[handshakeComponents[0]]
configI.CipherSuite = NewCipherSuite(DH25519, ciphers[components[3]], hashes[components[4]])
configR.Pattern = configI.Pattern
configR.CipherSuite = configI.CipherSuite
configR.PresharedKeyPlacement = configI.PresharedKeyPlacement
case "gen_init_ephemeral":
configI.Random = hexReader(splitLine[1])
case "gen_resp_ephemeral":
......@@ -180,7 +185,7 @@ func (NoiseSuite) TestVectors(c *C) {
configR.EphemeralKeypair = ephR
configI.PeerEphemeral = ephR.Public
}
if strings.HasPrefix(name, "NoisePSK_") {
if strings.Index(name, "psk") != -1 {
configI.PresharedKey = psk
configR.PresharedKey = psk
}
......
......@@ -11,9 +11,9 @@ import (
)
func main() {
for ci, cipher := range []CipherFunc{CipherAESGCM, CipherChaChaPoly} {
for _, cipher := range []CipherFunc{CipherAESGCM, CipherChaChaPoly} {
for _, hash := range []HashFunc{HashSHA256, HashSHA512, HashBLAKE2b, HashBLAKE2s} {
for hi, handshake := range []HandshakePattern{
for _, handshake := range []HandshakePattern{
HandshakeNN,
HandshakeKN,
HandshakeNK,
......@@ -31,15 +31,18 @@ func main() {
HandshakeX,
HandshakeXR,
} {
for _, psk := range []bool{false, true} {
payloads := (psk && hi%2 == 0) || (!psk && hi%2 != 0)
prologue := ci == 0
writeHandshake(
os.Stdout,
NewCipherSuite(DH25519, cipher, hash),
handshake, psk, prologue, payloads,
)
fmt.Fprintln(os.Stdout)
for _, prologue := range []bool{false, true} {
for _, payloads := range []bool{false, true} {
for pskPlacement := -1; pskPlacement <= len(handshake.Messages); pskPlacement++ {
writeHandshake(
os.Stdout,
NewCipherSuite(DH25519, cipher, hash),
handshake, pskPlacement,
pskPlacement >= 0, prologue, payloads,
)
fmt.Fprintln(os.Stdout)
}
}
}
}
}
......@@ -62,13 +65,13 @@ const (
key4 = "4142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60"
)
func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, hasPSK, hasPrologue, payloads bool) {
func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacement int, hasPSK, hasPrologue, payloads bool) {
var prologue, psk []byte
if hasPrologue {
prologue = []byte("notsecret")
}
if hasPSK {
psk = []byte("verysecret")
psk = []byte("!verysecretverysecretverysecret!")
}
staticI := cs.GenerateKeypair(hexReader(key0))
......@@ -82,6 +85,7 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, hasPSK, h
Initiator: true,
Prologue: prologue,
PresharedKey: psk,
PresharedKeyPlacement: pskPlacement,
}
configR := configI
configR.Random = hexReader(key4)
......@@ -89,10 +93,10 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, hasPSK, h
var pskName string
if hasPSK {
pskName = "PSK"
pskName = fmt.Sprintf("psk%d", pskPlacement)
}
fmt.Fprintf(out, "handshake=Noise%s_%s_%s\n", pskName, h.Name, cs.Name())
fmt.Fprintf(out, "handshake=Noise_%s%s_%s\n", h.Name, pskName, cs.Name())
if len(h.Name) == 1 {
switch h.Name {
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment