Commit 7e398aa7 authored by Ryan Huber's avatar Ryan Huber Committed by Jonathan Rudenberg

Replace panics with errors (#24)

* remove panics per PanicAndRecover guidance from go authors

* revert constructors to panic()

* fix vectorgen

* fix Write call too
parent f9b5bb75
...@@ -26,7 +26,7 @@ type DHKey struct { ...@@ -26,7 +26,7 @@ type DHKey struct {
type DHFunc interface { type DHFunc interface {
// GenerateKeypair generates a new keypair using random as a source of // GenerateKeypair generates a new keypair using random as a source of
// entropy. // entropy.
GenerateKeypair(random io.Reader) DHKey GenerateKeypair(random io.Reader) (DHKey, error)
// DH performs a Diffie-Hellman calculation between the provided private and // DH performs a Diffie-Hellman calculation between the provided private and
// public keys and returns the result. // public keys and returns the result.
...@@ -104,16 +104,16 @@ var DH25519 DHFunc = dh25519{} ...@@ -104,16 +104,16 @@ var DH25519 DHFunc = dh25519{}
type dh25519 struct{} type dh25519 struct{}
func (dh25519) GenerateKeypair(rng io.Reader) DHKey { func (dh25519) GenerateKeypair(rng io.Reader) (DHKey, error) {
var pubkey, privkey [32]byte var pubkey, privkey [32]byte
if rng == nil { if rng == nil {
rng = rand.Reader rng = rand.Reader
} }
if _, err := io.ReadFull(rng, privkey[:]); err != nil { if _, err := io.ReadFull(rng, privkey[:]); err != nil {
panic(err) return DHKey{}, err
} }
curve25519.ScalarBaseMult(&pubkey, &privkey) curve25519.ScalarBaseMult(&pubkey, &privkey)
return DHKey{Private: privkey[:], Public: pubkey[:]} return DHKey{Private: privkey[:], Public: pubkey[:]}, nil
} }
func (dh25519) DH(privkey, pubkey []byte) []byte { func (dh25519) DH(privkey, pubkey []byte) []byte {
......
This diff is collapsed.
...@@ -262,7 +262,7 @@ type Config struct { ...@@ -262,7 +262,7 @@ type Config struct {
} }
// NewHandshakeState starts a new handshake using the provided configuration. // NewHandshakeState starts a new handshake using the provided configuration.
func NewHandshakeState(c Config) *HandshakeState { func NewHandshakeState(c Config) (*HandshakeState, error) {
hs := &HandshakeState{ hs := &HandshakeState{
s: c.StaticKeypair, s: c.StaticKeypair,
e: c.EphemeralKeypair, e: c.EphemeralKeypair,
...@@ -284,7 +284,7 @@ func NewHandshakeState(c Config) *HandshakeState { ...@@ -284,7 +284,7 @@ func NewHandshakeState(c Config) *HandshakeState {
pskModifier := "" pskModifier := ""
if len(hs.psk) > 0 { if len(hs.psk) > 0 {
if len(hs.psk) != 32 { if len(hs.psk) != 32 {
panic("noise: specification mandates 256-bit preshared keys") return nil, errors.New("noise: specification mandates 256-bit preshared keys")
} }
pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement) pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...) hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
...@@ -320,7 +320,7 @@ func NewHandshakeState(c Config) *HandshakeState { ...@@ -320,7 +320,7 @@ func NewHandshakeState(c Config) *HandshakeState {
hs.ss.MixHash(hs.re) hs.ss.MixHash(hs.re)
} }
} }
return hs return hs, nil
} }
// WriteMessage appends a handshake message to out. The message will include the // WriteMessage appends a handshake message to out. The message will include the
...@@ -329,21 +329,25 @@ func NewHandshakeState(c Config) *HandshakeState { ...@@ -329,21 +329,25 @@ func NewHandshakeState(c Config) *HandshakeState {
// remote peer, the other is used for decryption of messages from the remote // remote peer, the other is used for decryption of messages from the remote
// peer. It is an error to call this method out of sync with the handshake // peer. It is an error to call this method out of sync with the handshake
// pattern. // pattern.
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState) { func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite { if !s.shouldWrite {
panic("noise: unexpected call to WriteMessage should be ReadMessage") return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
} }
if s.msgIdx > len(s.messagePatterns)-1 { if s.msgIdx > len(s.messagePatterns)-1 {
panic("noise: no handshake messages left") return nil, nil, nil, errors.New("noise: no handshake messages left")
} }
if len(payload) > MaxMsgLen { if len(payload) > MaxMsgLen {
panic("noise: message is too long") return nil, nil, nil, errors.New("noise: message is too long")
} }
for _, msg := range s.messagePatterns[s.msgIdx] { for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg { switch msg {
case MessagePatternE: case MessagePatternE:
s.e = s.ss.cs.GenerateKeypair(s.rng) e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...) out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public) s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 { if len(s.psk) > 0 {
...@@ -351,7 +355,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState ...@@ -351,7 +355,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
} }
case MessagePatternS: case MessagePatternS:
if len(s.s.Public) == 0 { if len(s.s.Public) == 0 {
panic("noise: invalid state, s.Public is nil") return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
} }
out = s.ss.EncryptAndHash(out, s.s.Public) out = s.ss.EncryptAndHash(out, s.s.Public)
case MessagePatternDHEE: case MessagePatternDHEE:
...@@ -380,10 +384,10 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState ...@@ -380,10 +384,10 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
if s.msgIdx >= len(s.messagePatterns) { if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split() cs1, cs2 := s.ss.Split()
return out, cs1, cs2 return out, cs1, cs2, nil
} }
return out, nil, nil return out, nil, nil, nil
} }
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be. // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
...@@ -396,10 +400,10 @@ var ErrShortMessage = errors.New("noise: message is too short") ...@@ -396,10 +400,10 @@ var ErrShortMessage = errors.New("noise: message is too short")
// error to call this method out of sync with the handshake pattern. // error to call this method out of sync with the handshake pattern.
func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) {
if s.shouldWrite { if s.shouldWrite {
panic("noise: unexpected call to ReadMessage should be WriteMessage") return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
} }
if s.msgIdx > len(s.messagePatterns)-1 { if s.msgIdx > len(s.messagePatterns)-1 {
panic("noise: no handshake messages left") return nil, nil, nil, errors.New("noise: no handshake messages left")
} }
s.ss.Checkpoint() s.ss.Checkpoint()
...@@ -428,7 +432,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, ...@@ -428,7 +432,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
} }
case MessagePatternS: case MessagePatternS:
if len(s.rs) > 0 { if len(s.rs) > 0 {
panic("noise: invalid state, rs is not nil") return nil, nil, nil, errors.New("noise: invalid state, rs is not nil")
} }
s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected]) s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
} }
......
...@@ -127,11 +127,11 @@ func (NoiseSuite) TestVectors(c *C) { ...@@ -127,11 +127,11 @@ func (NoiseSuite) TestVectors(c *C) {
switch string(splitLine[0]) { switch string(splitLine[0]) {
case "init_static": case "init_static":
staticI = DH25519.GenerateKeypair(hexReader(splitLine[1])) staticI, _ = DH25519.GenerateKeypair(hexReader(splitLine[1]))
case "resp_static": case "resp_static":
staticR = DH25519.GenerateKeypair(hexReader(splitLine[1])) staticR, _ = DH25519.GenerateKeypair(hexReader(splitLine[1]))
case "resp_ephemeral": case "resp_ephemeral":
ephR = DH25519.GenerateKeypair(hexReader(splitLine[1])) ephR, _ = DH25519.GenerateKeypair(hexReader(splitLine[1]))
case "handshake": case "handshake":
name = string(splitLine[1]) name = string(splitLine[1])
c.Log(name) c.Log(name)
...@@ -188,7 +188,8 @@ func (NoiseSuite) TestVectors(c *C) { ...@@ -188,7 +188,8 @@ func (NoiseSuite) TestVectors(c *C) {
configI.PresharedKey = psk configI.PresharedKey = psk
configR.PresharedKey = psk configR.PresharedKey = psk
} }
hsI, hsR = NewHandshakeState(configI), NewHandshakeState(configR) hsI, _ = NewHandshakeState(configI)
hsR, _ = NewHandshakeState(configR)
} }
i, _ := strconv.Atoi(string(splitLine[0][4:5])) i, _ := strconv.Atoi(string(splitLine[0][4:5]))
...@@ -213,7 +214,7 @@ func (NoiseSuite) TestVectors(c *C) { ...@@ -213,7 +214,7 @@ func (NoiseSuite) TestVectors(c *C) {
} }
var msg, res []byte var msg, res []byte
msg, csW0, csW1 = writer.WriteMessage(nil, payload) msg, csW0, csW1, _ = writer.WriteMessage(nil, payload)
c.Assert(fmt.Sprintf("%x", msg), Equals, string(splitLine[1])) c.Assert(fmt.Sprintf("%x", msg), Equals, string(splitLine[1]))
res, csR0, csR1, err = reader.ReadMessage(nil, msg) res, csR0, csR1, err = reader.ReadMessage(nil, msg)
c.Assert(err, IsNil) c.Assert(err, IsNil)
......
...@@ -73,9 +73,9 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem ...@@ -73,9 +73,9 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem
psk = []byte("!verysecretverysecretverysecret!") psk = []byte("!verysecretverysecretverysecret!")
} }
staticI := cs.GenerateKeypair(hexReader(key0)) staticI, _ := cs.GenerateKeypair(hexReader(key0))
staticR := cs.GenerateKeypair(hexReader(key1)) staticR, _ := cs.GenerateKeypair(hexReader(key1))
ephR := cs.GenerateKeypair(hexReader(key2)) ephR, _ := cs.GenerateKeypair(hexReader(key2))
configI := Config{ configI := Config{
CipherSuite: cs, CipherSuite: cs,
...@@ -151,8 +151,8 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem ...@@ -151,8 +151,8 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem
fmt.Fprintf(out, "preshared_key=%x\n", psk) fmt.Fprintf(out, "preshared_key=%x\n", psk)
} }
hsI := NewHandshakeState(configI) hsI, _ := NewHandshakeState(configI)
hsR := NewHandshakeState(configR) hsR, _ := NewHandshakeState(configR)
var cs0, cs1 *CipherState var cs0, cs1 *CipherState
for i := range h.Messages { for i := range h.Messages {
...@@ -166,7 +166,7 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem ...@@ -166,7 +166,7 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem
payload = fmt.Sprintf("test_msg_%d", i) payload = fmt.Sprintf("test_msg_%d", i)
} }
var msg []byte var msg []byte
msg, cs0, cs1 = writer.WriteMessage(nil, []byte(payload)) msg, cs0, cs1, _ = writer.WriteMessage(nil, []byte(payload))
_, _, _, err := reader.ReadMessage(nil, msg) _, _, _, err := reader.ReadMessage(nil, msg)
if err != nil { if err != nil {
panic(err) panic(err)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment