diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go index a76722814..b908e06eb 100644 --- a/mixing/mixclient/client.go +++ b/mixing/mixclient/client.go @@ -394,8 +394,9 @@ type Client struct { logger slog.Logger - testTickC chan struct{} - testHooks map[hook]hookFunc + testWaiting chan struct{} + testTickC chan time.Time + testHooks map[hook]hookFunc } // NewClient creates a wallet's mixing client manager. @@ -668,26 +669,28 @@ func (c *Client) waitForEpoch(ctx context.Context) (time.Time, error) { now := time.Now().UTC() epoch := now.Truncate(c.epoch).Add(c.epoch) duration := epoch.Sub(now) - timer := time.NewTimer(duration) - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - return epoch, ctx.Err() - case <-c.stopping: - if !timer.Stop() { - <-timer.C - } - c.runWG.Wait() - return epoch, ErrStopping - case <-c.testTickC: - if !timer.Stop() { - <-timer.C + + testWaiting := c.testWaiting + var timerC, testTickC <-chan time.Time + if testWaiting == nil { + timerC = time.After(duration) + } + + for { + select { + case <-ctx.Done(): + return epoch, ctx.Err() + case <-c.stopping: + c.runWG.Wait() + return epoch, ErrStopping + case <-timerC: + return epoch, nil + case testWaiting <- struct{}{}: + testWaiting = nil + testTickC = c.testTickC + case testEpoch := <-testTickC: + return testEpoch, nil } - return time.Now(), nil - case <-timer.C: - return epoch, nil } } @@ -700,6 +703,11 @@ func (p *peer) msgJitter() time.Duration { // small amount of jitter is added to help avoid timing deanonymization // attacks. func (c *Client) prDelay(ctx context.Context, p *peer) error { + // No delay in tests. + if c.testTickC != nil { + return nil + } + now := time.Now().UTC() epoch := now.Truncate(c.epoch).Add(c.epoch) sendBefore := epoch.Add(-timeoutDuration - maxJitter) @@ -717,18 +725,13 @@ func (c *Client) prDelay(ctx context.Context, p *peer) error { <-timer.C } return ctx.Err() - case <-c.testTickC: - if !timer.Stop() { - <-timer.C - } - return nil case <-timer.C: return nil } } -func (c *Client) testTick() { - c.testTickC <- struct{}{} +func (c *Client) testTick(t time.Time) { + c.testTickC <- t } func (c *Client) testHook(stage hook, ps *pairedSessions, s *sessionRun, p *peer) { @@ -858,12 +861,23 @@ func (c *Client) epochTicker(ctx context.Context) error { if err != nil { return err } + timerC := time.After(timeoutDuration + 2*time.Second) + testWaiting := c.testWaiting + var testTickC <-chan time.Time + for { + select { + case <-ctx.Done(): + return err - select { - case <-time.After(timeoutDuration + 2*time.Second): - case <-c.testTickC: - case <-ctx.Done(): - return err + case <-timerC: + + case testWaiting <- struct{}{}: + testWaiting = nil + testTickC = c.testTickC + continue + case <-testTickC: + } + break } c.mu.Lock() c.removeUnresponsiveDuringEpoch(prevPRs, uint64(firstEpoch.Unix())) diff --git a/mixing/mixclient/client_test.go b/mixing/mixclient/client_test.go index c03932bf0..012dece7a 100644 --- a/mixing/mixclient/client_test.go +++ b/mixing/mixclient/client_test.go @@ -45,7 +45,8 @@ const ( func newTestClient(w *testWallet, logger slog.Logger) *Client { c := NewClient(w) - c.testTickC = make(chan struct{}) + c.testWaiting = make(chan struct{}) + c.testTickC = make(chan time.Time) c.SetLogger(logger) return c } @@ -223,7 +224,8 @@ func TestHonest(t *testing.T) { <-doneRun }() - c.testTick() + <-c.testWaiting + c.testTick(time.Now().Truncate(time.Second)) var g errgroup.Group for i := range peers { @@ -235,7 +237,8 @@ func TestHonest(t *testing.T) { go func() { for { - c.testTick() + <-c.testWaiting + c.testTick(time.Now().Truncate(time.Second)) select { case <-ctx.Done(): return @@ -329,8 +332,11 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) { }() testTick := func() { - c.testTick() - c2.testTick() + <-c.testWaiting + <-c2.testWaiting + t := time.Now().Truncate(time.Second) + c.testTick(t) + c2.testTick(t) } testTick() diff --git a/mixing/mixclient/log_test.go b/mixing/mixclient/log_test.go index a674e8fd9..20bcd660c 100644 --- a/mixing/mixclient/log_test.go +++ b/mixing/mixclient/log_test.go @@ -32,6 +32,7 @@ func useTestLogger(t *testing.T) (slog.Logger, func()) { l.SetLevel(slog.LevelTrace) mixpool.UseLogger(l) return l, func() { + l.SetLevel(slog.LevelOff) mixpool.UseLogger(slog.Disabled) } }