Commit 7e398aa7 authored by Ryan Huber's avatar Ryan Huber Committed by Jonathan Rudenberg

Replace panics with errors (#24)

* remove panics per PanicAndRecover guidance from go authors

* revert constructors to panic()

* fix vectorgen

* fix Write call too
parent f9b5bb75
......@@ -26,7 +26,7 @@ type DHKey struct {
type DHFunc interface {
// GenerateKeypair generates a new keypair using random as a source of
// entropy.
GenerateKeypair(random io.Reader) DHKey
GenerateKeypair(random io.Reader) (DHKey, error)
// DH performs a Diffie-Hellman calculation between the provided private and
// public keys and returns the result.
......@@ -104,16 +104,16 @@ var DH25519 DHFunc = dh25519{}
type dh25519 struct{}
func (dh25519) GenerateKeypair(rng io.Reader) DHKey {
func (dh25519) GenerateKeypair(rng io.Reader) (DHKey, error) {
var pubkey, privkey [32]byte
if rng == nil {
rng = rand.Reader
}
if _, err := io.ReadFull(rng, privkey[:]); err != nil {
panic(err)
return DHKey{}, err
}
curve25519.ScalarBaseMult(&pubkey, &privkey)
return DHKey{Private: privkey[:], Public: pubkey[:]}
return DHKey{Private: privkey[:], Public: pubkey[:]}, nil
}
func (dh25519) DH(privkey, pubkey []byte) []byte {
......
This diff is collapsed.
......@@ -262,7 +262,7 @@ type Config struct {
}
// NewHandshakeState starts a new handshake using the provided configuration.
func NewHandshakeState(c Config) *HandshakeState {
func NewHandshakeState(c Config) (*HandshakeState, error) {
hs := &HandshakeState{
s: c.StaticKeypair,
e: c.EphemeralKeypair,
......@@ -284,7 +284,7 @@ func NewHandshakeState(c Config) *HandshakeState {
pskModifier := ""
if len(hs.psk) > 0 {
if len(hs.psk) != 32 {
panic("noise: specification mandates 256-bit preshared keys")
return nil, errors.New("noise: specification mandates 256-bit preshared keys")
}
pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
......@@ -320,7 +320,7 @@ func NewHandshakeState(c Config) *HandshakeState {
hs.ss.MixHash(hs.re)
}
}
return hs
return hs, nil
}
// WriteMessage appends a handshake message to out. The message will include the
......@@ -329,21 +329,25 @@ func NewHandshakeState(c Config) *HandshakeState {
// remote peer, the other is used for decryption of messages from the remote
// peer. It is an error to call this method out of sync with the handshake
// pattern.
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState) {
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite {
panic("noise: unexpected call to WriteMessage should be ReadMessage")
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
panic("noise: no handshake messages left")
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
panic("noise: message is too long")
return nil, nil, nil, errors.New("noise: message is too long")
}
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE:
s.e = s.ss.cs.GenerateKeypair(s.rng)
e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
......@@ -351,7 +355,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
}
case MessagePatternS:
if len(s.s.Public) == 0 {
panic("noise: invalid state, s.Public is nil")
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
}
out = s.ss.EncryptAndHash(out, s.s.Public)
case MessagePatternDHEE:
......@@ -380,10 +384,10 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2
return out, cs1, cs2, nil
}
return out, nil, nil
return out, nil, nil, nil
}
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
......@@ -396,10 +400,10 @@ var ErrShortMessage = errors.New("noise: message is too short")
// error to call this method out of sync with the handshake pattern.
func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) {
if s.shouldWrite {
panic("noise: unexpected call to ReadMessage should be WriteMessage")
return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
panic("noise: no handshake messages left")
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
s.ss.Checkpoint()
......@@ -428,7 +432,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
}
case MessagePatternS:
if len(s.rs) > 0 {
panic("noise: invalid state, rs is not nil")
return nil, nil, nil, errors.New("noise: invalid state, rs is not nil")
}
s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
}
......
......@@ -127,11 +127,11 @@ func (NoiseSuite) TestVectors(c *C) {
switch string(splitLine[0]) {
case "init_static":
staticI = DH25519.GenerateKeypair(hexReader(splitLine[1]))
staticI, _ = DH25519.GenerateKeypair(hexReader(splitLine[1]))
case "resp_static":
staticR = DH25519.GenerateKeypair(hexReader(splitLine[1]))
staticR, _ = DH25519.GenerateKeypair(hexReader(splitLine[1]))
case "resp_ephemeral":
ephR = DH25519.GenerateKeypair(hexReader(splitLine[1]))
ephR, _ = DH25519.GenerateKeypair(hexReader(splitLine[1]))
case "handshake":
name = string(splitLine[1])
c.Log(name)
......@@ -188,7 +188,8 @@ func (NoiseSuite) TestVectors(c *C) {
configI.PresharedKey = psk
configR.PresharedKey = psk
}
hsI, hsR = NewHandshakeState(configI), NewHandshakeState(configR)
hsI, _ = NewHandshakeState(configI)
hsR, _ = NewHandshakeState(configR)
}
i, _ := strconv.Atoi(string(splitLine[0][4:5]))
......@@ -213,7 +214,7 @@ func (NoiseSuite) TestVectors(c *C) {
}
var msg, res []byte
msg, csW0, csW1 = writer.WriteMessage(nil, payload)
msg, csW0, csW1, _ = writer.WriteMessage(nil, payload)
c.Assert(fmt.Sprintf("%x", msg), Equals, string(splitLine[1]))
res, csR0, csR1, err = reader.ReadMessage(nil, msg)
c.Assert(err, IsNil)
......
......@@ -73,9 +73,9 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem
psk = []byte("!verysecretverysecretverysecret!")
}
staticI := cs.GenerateKeypair(hexReader(key0))
staticR := cs.GenerateKeypair(hexReader(key1))
ephR := cs.GenerateKeypair(hexReader(key2))
staticI, _ := cs.GenerateKeypair(hexReader(key0))
staticR, _ := cs.GenerateKeypair(hexReader(key1))
ephR, _ := cs.GenerateKeypair(hexReader(key2))
configI := Config{
CipherSuite: cs,
......@@ -151,8 +151,8 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem
fmt.Fprintf(out, "preshared_key=%x\n", psk)
}
hsI := NewHandshakeState(configI)
hsR := NewHandshakeState(configR)
hsI, _ := NewHandshakeState(configI)
hsR, _ := NewHandshakeState(configR)
var cs0, cs1 *CipherState
for i := range h.Messages {
......@@ -166,7 +166,7 @@ func writeHandshake(out io.Writer, cs CipherSuite, h HandshakePattern, pskPlacem
payload = fmt.Sprintf("test_msg_%d", i)
}
var msg []byte
msg, cs0, cs1 = writer.WriteMessage(nil, []byte(payload))
msg, cs0, cs1, _ = writer.WriteMessage(nil, []byte(payload))
_, _, _, err := reader.ReadMessage(nil, msg)
if err != nil {
panic(err)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment