Commit 80cc726e authored by David Stainton's avatar David Stainton

GC and fix BlockingSendUnreliableMessage memory leaks

parent 04d0a895
Pipeline #992 passed with stage
in 7 minutes and 7 seconds
......@@ -20,10 +20,14 @@ package client
import (
"bytes"
"io"
"sync"
"testing"
"time"
"github.com/katzenpost/client/config"
"github.com/katzenpost/client/constants"
"github.com/katzenpost/core/crypto/rand"
"github.com/katzenpost/core/utils"
"github.com/stretchr/testify/require"
)
......@@ -192,3 +196,91 @@ func TestDockerClientAsyncSendReceiveWithDecoyTraffic(t *testing.T) {
client.Shutdown()
client.Wait()
}
// TestDockerClientTestGarbageCollection checks that Session.garbageCollect
// removes a stale, non-blocking entry from the SURB ID map.
//
// The test plants a Message whose SentAt is one day in the past and whose
// ReplyETA is ten seconds, so it is well beyond SentAt + ReplyETA +
// RoundTripTimeSlop and must be collected on the first pass.
func TestDockerClientTestGarbageCollection(t *testing.T) {
	require := require.New(t)

	cfg, err := config.LoadFile("testdata/client.toml")
	require.NoError(err)

	cfg, linkKey := AutoRegisterRandomClient(cfg)
	client, err := New(cfg)
	require.NoError(err)

	clientSession, err := client.NewSession(linkKey)
	require.NoError(err)

	// A random identifier doubles as the map key for the planted entry.
	// NOTE(review): garbageCollect asserts that keys are
	// [sConstants.SURBIDLength]byte; this only works if that length equals
	// constants.MessageIDLength — confirm.
	msgID := [constants.MessageIDLength]byte{}
	_, err = io.ReadFull(rand.Reader, msgID[:])
	require.NoError(err)

	msg := &Message{
		ID:         &msgID,
		IsBlocking: false,
		SentAt:     time.Now().AddDate(0, 0, -1),
		ReplyETA:   10 * time.Second,
	}
	clientSession.surbIDMap.Store(msgID, msg)

	clientSession.garbageCollect()

	// The expired entry must be gone after a single collection pass.
	_, ok := clientSession.surbIDMap.Load(msgID)
	require.False(ok)

	client.Shutdown()
	client.Wait()
}
/*
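FIXME: this integration test is disabled because it does not pass.
NOTE(review): it looks up surbIDMap by message ID, whereas doSend stores
entries keyed by SURB ID, so the require.True(ok) check below probably can
never succeed. Confirm before re-enabling.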
func TestDockerClientTestIntegrationGarbageCollection(t *testing.T) {
require := require.New(t)
cfg, err := config.LoadFile("testdata/client.toml")
require.NoError(err)
cfg, linkKey := AutoRegisterRandomClient(cfg)
client, err := New(cfg)
require.NoError(err)
clientSession, err := client.NewSession(linkKey)
require.NoError(err)
desc, err := clientSession.GetService("loop")
require.NoError(err)
// Send a message to a nonexistent service so that we don't get a reply and thus
// retain an entry in the SURB ID Map which we must garbage collect.
msgID, err := clientSession.SendUnreliableMessage("nonexistent", desc.Provider, []byte("hello"))
require.NoError(err)
t.Logf("sent message ID %x", msgID)
var wg sync.WaitGroup
wg.Add(1)
go func() {
for eventRaw := range clientSession.EventSink {
switch event := eventRaw.(type) {
case *MessageSentEvent:
if bytes.Equal(msgID[:], event.MessageID[:]) {
require.NoError(event.Err)
_, ok := clientSession.surbIDMap.Load(msgID)
require.True(ok)
duration := time.Duration(event.ReplyETA + constants.RoundTripTimeSlop + (5 * time.Second))
t.Logf("Sleeping for %s so that the SURB ID Map entry will get garbage collected.", duration)
time.Sleep(duration)
wg.Done()
return
}
default:
continue
}
}
}()
wg.Wait()
clientSession.garbageCollect()
_, ok := clientSession.surbIDMap.Load(msgID)
require.False(ok)
client.Shutdown()
client.Wait()
}
*/
......@@ -16,18 +16,16 @@
package constants
const (
// MessageIDLength is the length of a message ID in bytes.
MessageIDLength = 16
import (
"time"
)
// Client-wide constants: message ID sizing, SURB reply type tags and the
// round trip timing allowance.
const (
	// MessageIDLength is the length of a message ID in bytes.
	MessageIDLength = 16

	// SurbTypeACK is used to denote an ACK in response to a forward message.
	SurbTypeACK = 0

	// SurbTypeKaetzchen is used to denote a mixnet service query response.
	SurbTypeKaetzchen = 1

	// SurbTypeInternal is used to reserve an internal SURB reply type.
	SurbTypeInternal = 2

	// RoundTripTimeSlop is the slop added to the expected packet
	// round trip timeout threshold. Used for GC and for blocking
	// on reply in Session's BlockingSendUnreliableMessage method.
	RoundTripTimeSlop = 8 * time.Second
)
......@@ -19,6 +19,7 @@ package client
import (
"encoding/hex"
"fmt"
"time"
cConstants "github.com/katzenpost/client/constants"
)
......@@ -74,6 +75,12 @@ type MessageSentEvent struct {
// when the message was enqueued.
MessageID *[cConstants.MessageIDLength]byte
// SentAt contains the time the message was sent.
SentAt time.Time
// ReplyETA is the expected round trip time to receive a response.
ReplyETA time.Duration
// Err is the error encountered when sending the message if any.
Err error
}
......
......@@ -44,9 +44,6 @@ type Message struct {
// SentAt contains the time the message was sent.
SentAt time.Time
// Sent is set to true if the message was sent on the network.
Sent bool
// ReplyETA is the expected round trip time to receive a response.
ReplyETA time.Duration
......@@ -63,9 +60,6 @@ type Message struct {
// Reply is the SURB reply
Reply []byte
// SURBType is the SURB type.
SURBType int
// WithSURB specified if a SURB should be bundled with the forward payload.
WithSURB bool
......
......@@ -30,8 +30,6 @@ import (
sConstants "github.com/katzenpost/core/sphinx/constants"
)
const roundTripTimeSlop = time.Duration(88 * time.Second)
// ReplyTimeoutError is the error returned by BlockingSendUnreliableMessage
// when no reply is received before the round trip timeout elapses.
var ReplyTimeoutError = errors.New("Failure waiting for reply, timeout reached")
func (s *Session) sendNext() {
......@@ -80,7 +78,6 @@ func (s *Session) doSend(msg *Message) {
s.log.Debugf("doSend setting ReplyETA to %v", eta)
msg.Key = key
msg.SentAt = time.Now()
msg.Sent = true
msg.ReplyETA = eta
s.surbIDMap.Store(surbID, msg)
}
......@@ -98,6 +95,8 @@ func (s *Session) doSend(msg *Message) {
s.eventCh.In() <- &MessageSentEvent{
MessageID: msg.ID,
Err: err,
SentAt: msg.SentAt,
ReplyETA: msg.ReplyETA,
}
}
}
......@@ -152,7 +151,6 @@ func (s *Session) composeMessage(recipient, provider string, message []byte, isB
Provider: provider,
Payload: payload[:],
WithSURB: true,
SURBType: cConstants.SurbTypeKaetzchen,
IsBlocking: isBlocking,
}
return &msg, nil
......@@ -178,23 +176,25 @@ func (s *Session) BlockingSendUnreliableMessage(recipient, provider string, mess
}
sentWaitChan := make(chan *Message)
s.sentWaitChanMap.Store(*msg.ID, sentWaitChan)
defer s.sentWaitChanMap.Delete(*msg.ID)
replyWaitChan := make(chan []byte)
s.replyWaitChanMap.Store(*msg.ID, replyWaitChan)
defer s.replyWaitChanMap.Delete(*msg.ID)
err = s.egressQueue.Push(msg)
if err != nil {
return nil, err
}
// wait until sent so that we know the ReplyETA for the waiting below
sentMessage := <-sentWaitChan
s.sentWaitChanMap.Delete(*msg.ID)
// wait for reply or round trip timeout
select {
case reply := <-replyWaitChan:
s.replyWaitChanMap.Delete(*msg.ID)
return reply, nil
case <-time.After(sentMessage.ReplyETA + roundTripTimeSlop):
case <-time.After(sentMessage.ReplyETA + cConstants.RoundTripTimeSlop):
return nil, ReplyTimeoutError
}
// unreachable
......
......@@ -127,6 +127,7 @@ func NewSession(ctx context.Context, fatalErrCh chan error, logBackend *log.Back
}
s.Go(s.eventSinkWorker)
s.Go(s.garbageCollectionWorker)
s.minclient, err = minclient.New(clientCfg)
if err != nil {
......@@ -161,6 +162,43 @@ func (s *Session) eventSinkWorker() {
}
}
// garbageCollectionWorker runs garbageCollect every ten minutes until the
// Session's halt channel fires, at which point it logs and returns.
func (s *Session) garbageCollectionWorker() {
	const interval = 10 * time.Minute

	gcTimer := time.NewTimer(interval)
	defer gcTimer.Stop()

	for {
		select {
		case <-gcTimer.C:
			// The timer is re-armed only after the pass completes, so the
			// interval is measured from the end of the previous pass.
			s.garbageCollect()
			gcTimer.Reset(interval)
		case <-s.HaltCh():
			s.log.Debugf("Garbage collection worker terminating gracefully.")
			return
		}
	}
}
// garbageCollect makes one pass over the SURB ID map and deletes every
// non-blocking entry whose reply deadline has passed.
//
// An entry is considered expired once the current time is after
// SentAt + ReplyETA + cConstants.RoundTripTimeSlop. Entries belonging to
// blocking sends are left untouched.
func (s *Session) garbageCollect() {
	s.log.Debug("Running garbage collection process.")

	// Take a single timestamp so every entry in this pass is judged
	// against the same instant.
	now := time.Now()

	// Map layout: [sConstants.SURBIDLength]byte -> *Message
	s.surbIDMap.Range(func(rawSurbID, rawMessage interface{}) bool {
		surbID := rawSurbID.([sConstants.SURBIDLength]byte)
		message := rawMessage.(*Message)

		if message.IsBlocking {
			// Blocking sends don't need this garbage collection mechanism
			// because the BlockingSendUnreliableMessage method will clean up
			// after itself.
			// NOTE(review): confirm that method also removes the surbIDMap
			// entry on timeout, otherwise such entries are never reclaimed.
			return true
		}

		deadline := message.SentAt.Add(message.ReplyETA).Add(cConstants.RoundTripTimeSlop)
		if now.After(deadline) {
			// Debugf, not Debug: the message carries a format verb.
			s.log.Debugf("Garbage collecting SURB ID Map entry for Message ID %x", message.ID)
			s.surbIDMap.Delete(surbID)
		}

		// Always keep iterating; returning false would stop the Range early.
		return true
	})
}
func (s *Session) awaitFirstPKIDoc(ctx context.Context) (*pki.Document, error) {
for {
var qo workerOp
......@@ -260,26 +298,22 @@ func (s *Session) onACK(surbID *[sConstants.SURBIDLength]byte, ciphertext []byte
s.decrementDecoyLoopTally()
return nil
}
switch msg.SURBType {
case cConstants.SurbTypeKaetzchen, cConstants.SurbTypeInternal:
if msg.IsBlocking {
replyWaitChanRaw, ok := s.replyWaitChanMap.Load(*msg.ID)
replyWaitChan := replyWaitChanRaw.(chan []byte)
if !ok {
err := fmt.Errorf("Impossible failure to acquire replyWaitChan for message ID %x", msg.ID)
s.fatalErrCh <- err
return err
}
replyWaitChan <- plaintext[2:]
} else {
s.eventCh.In() <- &MessageReplyEvent{
MessageID: msg.ID,
Payload: plaintext[2:],
Err: nil,
}
if msg.IsBlocking {
replyWaitChanRaw, ok := s.replyWaitChanMap.Load(*msg.ID)
if !ok {
err := fmt.Errorf("BUG, failure to acquire replyWaitChan for message ID %x", msg.ID)
s.fatalErrCh <- err
return err
}
replyWaitChan := replyWaitChanRaw.(chan []byte)
replyWaitChan <- plaintext[2:]
} else {
s.eventCh.In() <- &MessageReplyEvent{
MessageID: msg.ID,
Payload: plaintext[2:],
Err: nil,
}
default:
s.log.Warningf("Discarding SURB %v: Unknown type: 0x%02x", idStr, msg.SURBType)
}
return nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment