diff --git a/go.mod b/go.mod index 70aba7976f..1840436077 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/decred/dcrd/dcrutil/v4 v4.0.3 github.com/decred/dcrd/gcs/v4 v4.1.1 github.com/decred/dcrd/math/uint256 v1.0.2 - github.com/decred/dcrd/mixing v0.6.1 + github.com/decred/dcrd/mixing v0.7.0 github.com/decred/dcrd/peer/v3 v3.2.0 github.com/decred/dcrd/rpc/jsonrpc/types/v4 v4.4.0 github.com/decred/dcrd/rpcclient/v8 v8.1.0 diff --git a/go.sum b/go.sum index 65ae510e31..a46816844b 100644 --- a/go.sum +++ b/go.sum @@ -56,8 +56,8 @@ github.com/decred/dcrd/hdkeychain/v3 v3.1.3 h1:Kn2wfj5cOR6pQO/WrYOMT1KK12IgWFEeQ github.com/decred/dcrd/hdkeychain/v3 v3.1.3/go.mod h1:mDAuGaH6InRD+hKVeVJsjLD/ih1mD3aCKURNHS8Tq2s= github.com/decred/dcrd/math/uint256 v1.0.2 h1:o8peafL5QmuXGTergI3YDpDU0eq5Z0pQi88B8ym4PRA= github.com/decred/dcrd/math/uint256 v1.0.2/go.mod h1:7M/y9wJJvlyNG/f/X6mxxhxo9dgloZHFiOfbiscl75A= -github.com/decred/dcrd/mixing v0.6.1 h1:nysHraShVbgKIrqpa04364YTchmVtqgaphTNWqs7nHc= -github.com/decred/dcrd/mixing v0.6.1/go.mod h1:M+Ao9h49usdSEsC0N6kMafijwgYh7E57RTKh8cCw02o= +github.com/decred/dcrd/mixing v0.7.0 h1:rKs5/6E8szo561kKp54iXdcC7GlQBIE9IP3utgKxjAM= +github.com/decred/dcrd/mixing v0.7.0/go.mod h1:UE7+nPvXA6C7GdIFXJmz0ksmxykwn88zzbQFwAopahE= github.com/decred/dcrd/peer/v3 v3.2.0 h1:6PxoGcsyjV3me/WLjIZdmWs4Pma5jYGi40CmzNCzeNY= github.com/decred/dcrd/peer/v3 v3.2.0/go.mod h1:OUFvMboEQvnq4V8CRw2RUn/fFls0rpVuDN0QT0zq8RA= github.com/decred/dcrd/rpc/jsonrpc/types/v4 v4.4.0 h1:BBVaYemabsFsaqNVlCmacoZlSsLDqwHYIs8ty6fg59Q= diff --git a/internal/netsync/manager.go b/internal/netsync/manager.go index 72cd38934c..3160a65dac 100644 --- a/internal/netsync/manager.go +++ b/internal/netsync/manager.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2025 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -1056,7 +1056,8 @@ func (m *SyncManager) OnMixMsg(peer *Peer, msg mixing.Message) ([]mixing.Message return nil, nil } - accepted, err := m.cfg.MixPool.AcceptMessage(msg) + source := mixpool.Uint64Source(peer.ID()) + accepted, err := m.cfg.MixPool.AcceptMessage(msg, source) // Remove message from request maps. Either the mixpool already knows // about it and as such we shouldn't have any more instances of trying diff --git a/internal/rpcserver/interface.go b/internal/rpcserver/interface.go index 0316d1f5fe..e230331496 100644 --- a/internal/rpcserver/interface.go +++ b/internal/rpcserver/interface.go @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2025 The Decred developers +// Copyright (c) 2019-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -20,6 +20,7 @@ import ( "github.com/decred/dcrd/internal/mining" "github.com/decred/dcrd/math/uint256" "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/mixpool" "github.com/decred/dcrd/peer/v3" "github.com/decred/dcrd/rpc/jsonrpc/types/v4" "github.com/decred/dcrd/txscript/v4/stdaddr" @@ -200,9 +201,9 @@ type SyncManager interface { // This method may report a false positive, but never a false negative. RecentlyConfirmedTxn(hash *chainhash.Hash) bool - // SubmitMixMessage submits the mixing message to the network after - // processing it locally. - SubmitMixMessage(msg mixing.Message) error + // AcceptMixMessage attempts to accept a mixing message to the local mixing + // pool. + AcceptMixMessage(msg mixing.Message, src mixpool.Source) error } // UtxoEntry represents a utxo entry for use with the RPC server. diff --git a/internal/rpcserver/rpcserver.go b/internal/rpcserver/rpcserver.go index 5245369113..1a609ff0c5 100644 --- a/internal/rpcserver/rpcserver.go +++ b/internal/rpcserver/rpcserver.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2025 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -50,6 +50,7 @@ import ( "github.com/decred/dcrd/internal/mining/cpuminer" "github.com/decred/dcrd/internal/version" "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/mixpool" "github.com/decred/dcrd/rpc/jsonrpc/types/v4" "github.com/decred/dcrd/txscript/v4" "github.com/decred/dcrd/txscript/v4/stdaddr" @@ -4348,7 +4349,8 @@ func handleSendRawMixMessage(_ context.Context, s *Server, cmd interface{}) (int msg.WriteHash(s.blake256Hasher) s.blake256HaserMu.Unlock() - err = s.cfg.SyncMgr.SubmitMixMessage(msg) + // Use 0 for the source to represent the local node. + err = s.cfg.SyncMgr.AcceptMixMessage(msg, mixpool.ZeroSource) if err != nil { // XXX: consider a better error code/function str := fmt.Sprintf("Rejected mix message: %s", err) diff --git a/internal/rpcserver/rpcserverhandlers_test.go b/internal/rpcserver/rpcserverhandlers_test.go index bacfa0c7fb..e642d448a6 100644 --- a/internal/rpcserver/rpcserverhandlers_test.go +++ b/internal/rpcserver/rpcserverhandlers_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2025 The Decred developers +// Copyright (c) 2020-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -44,6 +44,7 @@ import ( "github.com/decred/dcrd/internal/version" "github.com/decred/dcrd/math/uint256" "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/mixpool" "github.com/decred/dcrd/peer/v3" "github.com/decred/dcrd/rpc/jsonrpc/types/v4" "github.com/decred/dcrd/txscript/v4" @@ -572,7 +573,7 @@ func (c *testAddrManager) LocalAddresses() []addrmgr.LocalAddr { type testSyncManager struct { isCurrent bool submitBlockErr error - submitMixErr error + acceptMixErr error syncPeerID int32 syncHeight int64 processTransaction []*dcrutil.Tx @@ -592,10 +593,6 @@ func (s *testSyncManager) SubmitBlock(block *dcrutil.Block) error { return s.submitBlockErr } -func (s *testSyncManager) SubmitMixMessage(msg mixing.Message) error { - return s.submitMixErr -} - // SyncPeer returns a mocked id of the current peer being synced with. func (s *testSyncManager) SyncPeerID() int32 { return s.syncPeerID @@ -619,6 +616,12 @@ func (s *testSyncManager) RecentlyConfirmedTxn(hash *chainhash.Hash) bool { return s.recentlyConfirmedTxn } +// AcceptMixMessage provides a mock implementation for attempting to accept a +// mixing message to the local mixing pool. +func (s *testSyncManager) AcceptMixMessage(msg mixing.Message, src mixpool.Source) error { + return s.acceptMixErr +} + // testExistsAddresser provides a mock exists addresser by implementing the // ExistsAddresser interface. type testExistsAddresser struct { diff --git a/mixing/dcnet.go b/mixing/dcnet.go index 3091d711d5..f8541c8713 100644 --- a/mixing/dcnet.go +++ b/mixing/dcnet.go @@ -15,9 +15,8 @@ import ( // SRMixPads creates a vector of exponential DC-net pads from a vector of // shared secrets with each participating peer in the DC-net. func SRMixPads(kp [][]byte, my uint32) []*big.Int { - h := blake256.New() + h := blake256.NewHasher256() scratch := make([]byte, 8) - digest := make([]byte, blake256.Size) pads := make([]*big.Int, len(kp)) partialPad := new(big.Int) for j := uint32(0); j < uint32(len(kp)); j++ { @@ -30,8 +29,8 @@ func SRMixPads(kp [][]byte, my uint32) []*big.Int { h.Reset() h.Write(kp[i]) h.Write(scratch) - digest = h.Sum(digest[:0]) - partialPad.SetBytes(digest) + digest := h.Sum256() + partialPad.SetBytes(digest[:]) if my > i { pads[j].Add(pads[j], partialPad) } else { diff --git a/mixing/keyagreement.go b/mixing/keyagreement.go index 820d0c61d9..d2d02efd40 100644 --- a/mixing/keyagreement.go +++ b/mixing/keyagreement.go @@ -172,7 +172,7 @@ func (kx *KX) SharedSecrets(k *RevealedKeys, sid []byte, run uint32, mcounts []u // XOR ECDH and both sntrup4591761 keys into a single // shared key. If sntrup4591761 is discovered to be // broken in the future, the security only reduces to - // that of x25519. + // that of ECDH. // If the message belongs to our own peer, only XOR // the sntrup4591761 key once. The decapsulated and // cleartext keys are equal in this case, and would diff --git a/mixing/limits.go b/mixing/limits.go new file mode 100644 index 0000000000..01d298b45c --- /dev/null +++ b/mixing/limits.go @@ -0,0 +1,42 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package mixing + +// The wire package's limits on mixing messages are excessively large. In +// practice, much lower limits are defined below and used by mixpool and +// mixclient. These limits are enforced by the mixing module to avoid a +// wire protocol bump. +// +// These limits are a tradeoff between mix quality, theoretical maximum +// messages sizes, and the size of the resulting mix transaction. In the +// largest mix defined by these limits, 64 total peers may each spend five +// P2PKH outputs each while contributing 17 P2PKH outputs (16 mixed outputs +// plus one change output) with the estimated size of the resulting transaction +// equaling 92309 bytes, just under the 100k standard threshold. +// +// These constants may change at any point without a major API bump to the +// mixing module. +const ( + // MaxPeers is the maximum number of peers allowed together in a + // single mix session. This restricts the maximum dimensions of the + // slot reservation and XOR DC-net matrices and the maximum number of + // previous messages that may be referenced by mix messages. + MaxPeers = 64 + + // MaxMcount is the maximum number of mixed messages that any single + // peer can contribute to the mix. + MaxMcount = 16 + + // MaxMtot is the maximum number of mixed messages in total that can + // be created during a session (the sum of all mcounts). + MaxMtot = 1024 + + // MaxMixTxSerializeSize is the maximum size of a mix transaction. + // This value should not exceed mempool standardness limits. + MaxMixTxSerializeSize = 100000 + + // MaxMixAmount is the maximum value of a mixed output. + MaxMixAmount = 21000000_00000000 / MinPeers +) diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go index 16d4644a20..b908e06eb9 100644 --- a/mixing/mixclient/client.go +++ b/mixing/mixclient/client.go @@ -394,8 +394,9 @@ type Client struct { logger slog.Logger - testTickC chan struct{} - testHooks map[hook]hookFunc + testWaiting chan struct{} + testTickC chan time.Time + testHooks map[hook]hookFunc } // NewClient creates a wallet's mixing client manager. @@ -417,7 +418,7 @@ func NewClient(w Wallet) *Client { height: height, warming: make(chan struct{}), workQueue: make(chan *queueWork, runtime.NumCPU()), - blake256Hasher: blake256.New(), + blake256Hasher: blake256.NewHasher256(), epoch: w.Mixpool().Epoch(), stopping: make(chan struct{}), } @@ -668,26 +669,28 @@ func (c *Client) waitForEpoch(ctx context.Context) (time.Time, error) { now := time.Now().UTC() epoch := now.Truncate(c.epoch).Add(c.epoch) duration := epoch.Sub(now) - timer := time.NewTimer(duration) - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - return epoch, ctx.Err() - case <-c.stopping: - if !timer.Stop() { - <-timer.C - } - c.runWG.Wait() - return epoch, ErrStopping - case <-c.testTickC: - if !timer.Stop() { - <-timer.C + + testWaiting := c.testWaiting + var timerC, testTickC <-chan time.Time + if testWaiting == nil { + timerC = time.After(duration) + } + + for { + select { + case <-ctx.Done(): + return epoch, ctx.Err() + case <-c.stopping: + c.runWG.Wait() + return epoch, ErrStopping + case <-timerC: + return epoch, nil + case testWaiting <- struct{}{}: + testWaiting = nil + testTickC = c.testTickC + case testEpoch := <-testTickC: + return testEpoch, nil } - return time.Now(), nil - case <-timer.C: - return epoch, nil } } @@ -700,6 +703,11 @@ func (p *peer) msgJitter() time.Duration { // small amount of jitter is added to help avoid timing deanonymization // attacks. func (c *Client) prDelay(ctx context.Context, p *peer) error { + // No delay in tests. + if c.testTickC != nil { + return nil + } + now := time.Now().UTC() epoch := now.Truncate(c.epoch).Add(c.epoch) sendBefore := epoch.Add(-timeoutDuration - maxJitter) @@ -717,18 +725,13 @@ func (c *Client) prDelay(ctx context.Context, p *peer) error { <-timer.C } return ctx.Err() - case <-c.testTickC: - if !timer.Stop() { - <-timer.C - } - return nil case <-timer.C: return nil } } -func (c *Client) testTick() { - c.testTickC <- struct{}{} +func (c *Client) testTick(t time.Time) { + c.testTickC <- t } func (c *Client) testHook(stage hook, ps *pairedSessions, s *sessionRun, p *peer) { @@ -858,12 +861,23 @@ func (c *Client) epochTicker(ctx context.Context) error { if err != nil { return err } + timerC := time.After(timeoutDuration + 2*time.Second) + testWaiting := c.testWaiting + var testTickC <-chan time.Time + for { + select { + case <-ctx.Done(): + return err - select { - case <-time.After(timeoutDuration + 2*time.Second): - case <-c.testTickC: - case <-ctx.Done(): - return err + case <-timerC: + + case testWaiting <- struct{}{}: + testWaiting = nil + testTickC = c.testTickC + continue + case <-testTickC: + } + break } c.mu.Lock() c.removeUnresponsiveDuringEpoch(prevPRs, uint64(firstEpoch.Unix())) diff --git a/mixing/mixclient/client_test.go b/mixing/mixclient/client_test.go index 1f4d88d0e0..012dece7ab 100644 --- a/mixing/mixclient/client_test.go +++ b/mixing/mixclient/client_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 The Decred developers +// Copyright (c) 2024-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -45,7 +45,8 @@ const ( func newTestClient(w *testWallet, logger slog.Logger) *Client { c := NewClient(w) - c.testTickC = make(chan struct{}) + c.testWaiting = make(chan struct{}) + c.testTickC = make(chan time.Time) c.SetLogger(logger) return c } @@ -114,7 +115,7 @@ func (w *testWallet) SignInput(tx *wire.MsgTx, index int, prevScript []byte) err } func (w *testWallet) SubmitMixMessage(ctx context.Context, msg mixing.Message) error { - _, err := w.mixpool.AcceptMessage(msg) + _, err := w.mixpool.AcceptMessage(msg, mixpool.ZeroSource) return err } @@ -223,7 +224,8 @@ func TestHonest(t *testing.T) { <-doneRun }() - c.testTick() + <-c.testWaiting + c.testTick(time.Now().Truncate(time.Second)) var g errgroup.Group for i := range peers { @@ -235,7 +237,8 @@ func TestHonest(t *testing.T) { go func() { for { - c.testTick() + <-c.testWaiting + c.testTick(time.Now().Truncate(time.Second)) select { case <-ctx.Done(): return @@ -329,8 +332,11 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) { }() testTick := func() { - c.testTick() - c2.testTick() + <-c.testWaiting + <-c2.testWaiting + t := time.Now().Truncate(time.Second) + c.testTick(t) + c2.testTick(t) } testTick() diff --git a/mixing/mixclient/limits.go b/mixing/mixclient/limits.go index c996eebcde..9c30986edd 100644 --- a/mixing/mixclient/limits.go +++ b/mixing/mixclient/limits.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023-2024 The Decred developers +// Copyright (c) 2023-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -7,6 +7,7 @@ package mixclient import ( "errors" + "github.com/decred/dcrd/mixing" "github.com/decred/dcrd/wire" ) @@ -70,8 +71,6 @@ type coinjoinSize struct { // total transaction size to be too large, even if they have not acted // maliciously in the mixing protocol. func (c *coinjoinSize) join(contributedInputs, mcount int, change *wire.TxOut) error { - const maxStandardSize = 100000 - totalInputs := c.currentInputs + contributedInputs l := len(c.currentOutputScriptSizes) @@ -85,7 +84,7 @@ func (c *coinjoinSize) join(contributedInputs, mcount int, change *wire.TxOut) e c.currentOutputScriptSizes = c.currentOutputScriptSizes[:l] estimated := estimateP2PKHv0SerializeSize(totalInputs, totalOutputScriptSizes) - if estimated >= maxStandardSize { + if estimated >= mixing.MaxMixTxSerializeSize { return errExceedsStandardSize } diff --git a/mixing/mixclient/log_test.go b/mixing/mixclient/log_test.go index a674e8fd92..20bcd660c4 100644 --- a/mixing/mixclient/log_test.go +++ b/mixing/mixclient/log_test.go @@ -32,6 +32,7 @@ func useTestLogger(t *testing.T) (slog.Logger, func()) { l.SetLevel(slog.LevelTrace) mixpool.UseLogger(l) return l, func() { + l.SetLevel(slog.LevelOff) mixpool.UseLogger(slog.Disabled) } } diff --git a/mixing/mixpool/errors.go b/mixing/mixpool/errors.go index 6e342f909a..caf788798a 100644 --- a/mixing/mixpool/errors.go +++ b/mixing/mixpool/errors.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 The Decred developers +// Copyright (c) 2024-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -53,11 +53,20 @@ var ( // change amount is dust. ErrChangeDust = newBannableError("change output is dust", 0) + // ErrMixDust is returned by AcceptMessage if a pair request's mix + // amount value is dust. + ErrMixDust = newBannableError("mix output is dust", 0) + // ErrLowInput is returned by AcceptMessage when not enough input value // is provided by a pair request to cover the mixed output, any change // output, and the minimum required fee. ErrLowInput = newBannableError("not enough input value, or too low fee", 0) + // ErrHighFee is returned by AcceptMessage when a pair request + // indicates contributing too much fee to the mix transaction in a way + // that would be caught by mempool's default high fee policy. + ErrHighFee = newBannableError("too high fee", 0) + // ErrInvalidMessageCount is returned by AcceptMessage if a // pair request contains an invalid message count. ErrInvalidMessageCount = newBannableError("message count must be positive", 0) diff --git a/mixing/mixpool/limits.go b/mixing/mixpool/limits.go new file mode 100644 index 0000000000..7a9befae08 --- /dev/null +++ b/mixing/mixpool/limits.go @@ -0,0 +1,128 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package mixpool + +import ( + "fmt" + + "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/wire" +) + +func checkPRLimits(pr *wire.MsgMixPairReq) error { + if pr.MixAmount > mixing.MaxMixAmount { + return ruleError(fmt.Errorf("mixed output value %d exceeds max %d", + pr.MixAmount, mixing.MaxMixAmount)) + } + if pr.MessageCount > mixing.MaxMcount { + return ruleError(fmt.Errorf("message count %d exceeds max %d", + pr.MessageCount, mixing.MaxMcount)) + } + + return nil +} + +func checkKELimits(ke *wire.MsgMixKeyExchange) error { + if len(ke.SeenPRs) > mixing.MaxPeers { + return ruleError(fmt.Errorf("%d referenced PRs exceeds max peers %d", + len(ke.SeenPRs), mixing.MaxPeers)) + } + + return nil +} + +func checkCTLimits(ct *wire.MsgMixCiphertexts) error { + if len(ct.Ciphertexts) > mixing.MaxPeers { + return ruleError(fmt.Errorf("%d ciphertexts exceeds max peers %d", + len(ct.Ciphertexts), mixing.MaxPeers)) + } + if len(ct.SeenKeyExchanges) > mixing.MaxPeers { + return ruleError(fmt.Errorf("%d referenced KEs exceeds max peers %d", + len(ct.SeenKeyExchanges), len(ct.Ciphertexts))) + } + + return nil +} + +func checkSRLimits(sr *wire.MsgMixSlotReserve) error { + if len(sr.DCMix) > mixing.MaxMcount { + return ruleError(fmt.Errorf("outer DC-mix dimension size %d exceeds max message count %v", + len(sr.DCMix), mixing.MaxMcount)) + } + for i := range sr.DCMix { + if len(sr.DCMix[i]) > mixing.MaxMtot { + return ruleError(fmt.Errorf("inner DC-mix dimension size %d exceeds max session message total %v", + len(sr.DCMix[i]), mixing.MaxMtot)) + } + } + if len(sr.SeenCiphertexts) > mixing.MaxPeers { + return ruleError(fmt.Errorf("%d referenced CTs exceeds max peers %d", + len(sr.SeenCiphertexts), mixing.MaxPeers)) + } + + return nil +} + +func checkDCLimits(dc *wire.MsgMixDCNet) error { + if len(dc.DCNet) > mixing.MaxMcount { + return ruleError(fmt.Errorf("outer DC-net dimension size %d exceeds max message count %v", + len(dc.DCNet), mixing.MaxMcount)) + } + for i := range dc.DCNet { + if len(dc.DCNet[i]) > mixing.MaxMtot { + return ruleError(fmt.Errorf("inner DC-net dimension size %d exceeds max session message total %d", + len(dc.DCNet[i]), mixing.MaxMtot)) + } + } + if len(dc.SeenSlotReserves) > mixing.MaxPeers { + return ruleError(fmt.Errorf("%d referenced SRs exceeds max peers %d", + len(dc.SeenSlotReserves), mixing.MaxPeers)) + } + + return nil +} + +func checkCMLimits(cm *wire.MsgMixConfirm) error { + if sz := cm.Mix.SerializeSize(); sz > mixing.MaxMixTxSerializeSize { + return ruleError(fmt.Errorf("mix transaction serialize size %d exceeds maximum size %d", + sz, mixing.MaxMixTxSerializeSize)) + } + if len(cm.SeenDCNets) > mixing.MaxPeers { + return ruleError(fmt.Errorf("%d referenced DCs exceeds max peers %d", + len(cm.SeenDCNets), mixing.MaxPeers)) + } + + return nil +} + +func checkFPLimits(fp *wire.MsgMixFactoredPoly) error { + if len(fp.Roots) > mixing.MaxMtot { + return ruleError(fmt.Errorf("%d solved roots exceeds max session message total %d", + len(fp.Roots), mixing.MaxMtot)) + } + if len(fp.SeenSlotReserves) > mixing.MaxPeers { + return ruleError(fmt.Errorf("%d referenced SRs exceeds max peers %d", + len(fp.SeenSlotReserves), mixing.MaxPeers)) + } + + return nil +} + +func checkRSLimits(rs *wire.MsgMixSecrets) error { + if len(rs.SlotReserveMsgs) > mixing.MaxMcount { + return ruleError(fmt.Errorf("%d unpadded SR messages exceeds max message count %d", + len(rs.SlotReserveMsgs), mixing.MaxMcount)) + } + if len(rs.DCNetMsgs) > mixing.MaxMcount { + return ruleError(fmt.Errorf("%d unpadded DC messages exceeds max message count %d", + len(rs.DCNetMsgs), mixing.MaxMcount)) + } + if len(rs.SeenSecrets) > mixing.MaxPeers { + return ruleError(fmt.Errorf("%d referenced RSs exceeds max peers %d", + len(rs.SeenSecrets), mixing.MaxPeers)) + } + + return nil +} diff --git a/mixing/mixpool/mixpool.go b/mixing/mixpool/mixpool.go index 98888d0997..694835d5fa 100644 --- a/mixing/mixpool/mixpool.go +++ b/mixing/mixpool/mixpool.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023-2025 The Decred developers +// Copyright (c) 2023-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -10,6 +10,7 @@ import ( "bytes" "context" "fmt" + "slices" "sort" "sync" "time" @@ -28,8 +29,20 @@ import ( const minconf = 1 const feeRate = 0.0001e8 +const maxRelayFeeMultiplier = 1e4 const earlyKEDuration = 5 * time.Second +const ( + // maxOrphans specifies the maximum number of orphans allowed in the orphan + // pool at one time. + maxOrphans = 250 + + // maxPostEvictionOrphans is the maximum number of orphans to keep in the + // pool after a forced eviction occurs due to exceeding the overall max + // limit. It is set to 75% of the overall max limit. + maxPostEvictionOrphans = maxOrphans * 3 / 4 +) + type idPubKey = [33]byte type msgtype int @@ -85,6 +98,27 @@ func (m msgtype) String() string { } } +// Source represents a source of mixing messages. This is typically the peer +// that first relayed them, but the caller may choose any scheme it desires. +type Source interface { + // ID returns an opaque identifier that uniquely identifies the source. + ID() uint64 +} + +// Uint64Source implements the [Source] interface by returning the associated +// uint64 as the ID. This is primarily useful as a convenience for callers that +// do not require an additional object associated with the source. +type Uint64Source uint64 + +// ID returns the underlying uint64 associated with the source. +func (s Uint64Source) ID() uint64 { return uint64(s) } + +// Ensure [Uint64Source] implements the [Source] interface. +var _ Source = (*Uint64Source)(nil) + +// ZeroSource implements the [Source] interface by returning 0 for the ID. +const ZeroSource = Uint64Source(0) + // entry describes non-PR messages accepted to the pool. type entry struct { hash chainhash.Hash @@ -94,8 +128,9 @@ type entry struct { msgtype msgtype } -type orphan struct { +type orphanMsg struct { message mixing.Message + src Source accepted time.Time } @@ -145,8 +180,8 @@ type Pool struct { prs map[chainhash.Hash]*wire.MsgMixPairReq outPoints map[wire.OutPoint]chainhash.Hash pool map[chainhash.Hash]entry - orphans map[chainhash.Hash]*orphan - orphansByID map[idPubKey]map[chainhash.Hash]mixing.Message + orphans map[chainhash.Hash]*orphanMsg + orphansByID map[idPubKey]map[chainhash.Hash]*orphanMsg messagesByIdentity map[idPubKey][]chainhash.Hash latestKE map[idPubKey]*wire.MsgMixKeyExchange sessions map[[32]byte]*session @@ -228,8 +263,8 @@ func NewPool(blockchain BlockChain) *Pool { prs: make(map[chainhash.Hash]*wire.MsgMixPairReq), outPoints: make(map[wire.OutPoint]chainhash.Hash), pool: make(map[chainhash.Hash]entry), - orphans: make(map[chainhash.Hash]*orphan), - orphansByID: make(map[idPubKey]map[chainhash.Hash]mixing.Message), + orphans: make(map[chainhash.Hash]*orphanMsg), + orphansByID: make(map[idPubKey]map[chainhash.Hash]*orphanMsg), messagesByIdentity: make(map[idPubKey][]chainhash.Hash), latestKE: make(map[idPubKey]*wire.MsgMixKeyExchange), sessions: make(map[[32]byte]*session), @@ -474,6 +509,46 @@ func (p *Pool) removeMessage(hash chainhash.Hash) { p.maybeLogRecentMixMsgsNumEvicted() } +// removeOrphan removes the message associated with the passed hash from the +// orphan pool and orphans by ID index. +// +// This function MUST be called with the mixpool lock held (for writes). +func (p *Pool) removeOrphan(hash *chainhash.Hash, id *idPubKey) { + // Remove the message from the orphan pool and the reference from the + // orphans by ID index. + delete(p.orphans, *hash) + orphansByID := p.orphansByID[*id] + delete(orphansByID, *hash) + + // Remove the map entry altogether if there are no longer any orphans which + // depend on it. + if len(orphansByID) == 0 { + delete(p.orphansByID, *id) + } + + log.Tracef("Removed orphan %v (pool size %v)", hash, len(p.orphans)) +} + +// removeOrphansBySourceID removes up to the maximum specified number of orphan +// messages associated with the provided source ID and returns the number of +// entries removed. + +// This function MUST be called with the mixpool lock held (for writes). +func (p *Pool) removeOrphansBySourceID(srcID uint64, maxToEvict uint64) uint64 { + var numEvicted uint64 + for hash, orphan := range p.orphans { + if numEvicted >= maxToEvict { + break + } + if orphan.src.ID() == srcID { + id := (*idPubKey)(orphan.message.Pub()) + p.removeOrphan(&hash, id) + numEvicted++ + } + } + return numEvicted +} + // ExpireMessages immediately expires all pair requests and sessions built // from them that indicate expiry at or after a block height. func (p *Pool) ExpireMessages(height uint32) { @@ -517,8 +592,7 @@ func (p *Pool) expireMessagesNow(height uint32) { } } if expire { - delete(p.orphans, hash) - delete(p.orphansByID, *(*idPubKey)(o.message.Pub())) + p.removeOrphan(&hash, (*idPubKey)(o.message.Pub())) } } } @@ -969,6 +1043,88 @@ Loop: var zeroHash chainhash.Hash +// limitNumOrphans limits the number of orphan mixing messages by evicting a +// subset of the existing orphans when adding a new one would cause it to +// overflow the max allowed. +// +// This function MUST be called with the mixpool lock held (for writes). +func (p *Pool) limitNumOrphans() { + // Nothing to do if adding another orphan will not cause the pool to exceed + // the limit. + if len(p.orphans)+1 <= maxOrphans { + return + } + + // Determine which sources have the most orphans associated with them and + // then remove all of the orphans associated with each source in descending + // order until the orphan pool has reached the target maximum number of post + // eviction orphans allowed. + // + // This approach is fairly efficient since it naturally limits the frequency + // of eviction algorithm execution. Further, in practice, orphan messages + // are quite rare after initial startup where ongoing mixing sessions are + // discovered, so any peer sending a lot of orphans is likely experiencing + // severe connectivity issues or otherwise misbehaving. This approach also + // has the added benefit of handling a variety of orphan flooding + // misbehavior well. + srcCounters := make(map[uint64]int) + for _, orphan := range p.orphans { + srcCounters[orphan.src.ID()]++ + } + type srcWithCount struct { + srcID uint64 + count int + } + srcCounts := make([]srcWithCount, 0, len(srcCounters)) + for srcID, count := range srcCounters { + srcCounts = append(srcCounts, srcWithCount{srcID, count}) + } + slices.SortFunc(srcCounts, func(a, b srcWithCount) int { + return b.count - a.count + }) + numOrphans := uint64(len(p.orphans)) + for numOrphans > maxPostEvictionOrphans && len(srcCounts) > 0 { + srcID := srcCounts[0].srcID + maxToEvict := numOrphans - maxPostEvictionOrphans + numEvicted := p.removeOrphansBySourceID(srcID, maxToEvict) + log.Tracef("Removed %d orphans with source ID %d", numEvicted, srcID) + numOrphans -= numEvicted + srcCounts = srcCounts[1:] + } +} + +// addOrphan adds the passed message to the orphan pool when it is not already +// present. +// +// It also potentially removes orphans to make room when necessary. +// +// This function MUST be called with the mixpool lock held (for writes). +func (p *Pool) addOrphan(msg mixing.Message, hash *chainhash.Hash, id *idPubKey, src Source) { + orphansByID := p.orphansByID[*id] + if _, ok := orphansByID[*hash]; ok { + // Already an orphan. + return + } + + // Limit the number of orphan mixing messages to prevent memory exhaustion. + p.limitNumOrphans() + + orphan := &orphanMsg{ + message: msg, + src: src, + accepted: time.Now(), + } + p.orphans[*hash] = orphan + if orphansByID == nil { + orphansByID = make(map[chainhash.Hash]*orphanMsg) + p.orphansByID[*id] = orphansByID + } + orphansByID[*hash] = orphan + + log.Debugf("Stored orphan message %T %v (pool size: %d)", msg, hash, + len(p.orphans)) +} + // AcceptMessage accepts a mixing message to the pool. // // Messages must contain the mixing participant's identity and contain a valid @@ -980,7 +1136,7 @@ var zeroHash chainhash.Hash // // All newly accepted messages, including any orphan key exchange messages // that were processed after processing missing pair requests, are returned. -func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err error) { +func (p *Pool) AcceptMessage(msg mixing.Message, src Source) (accepted []mixing.Message, err error) { defer func() { if err == nil && len(accepted) == 0 { // Don't log duplicate messages or non-KE orphans. @@ -1075,7 +1231,7 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err p.mtx.Lock() defer p.mtx.Unlock() - accepted, err := p.acceptKE(msg, &hash, id) + accepted, err := p.acceptKE(msg, &hash, id, src) if err != nil { return nil, err } @@ -1088,16 +1244,34 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err return allAccepted, nil case *wire.MsgMixCiphertexts: + if err := checkCTLimits(msg); err != nil { + return nil, err + } msgtype = msgtypeCT case *wire.MsgMixSlotReserve: + if err := checkSRLimits(msg); err != nil { + return nil, err + } msgtype = msgtypeSR case *wire.MsgMixDCNet: + if err := checkDCLimits(msg); err != nil { + return nil, err + } msgtype = msgtypeDC case *wire.MsgMixConfirm: + if err := checkCMLimits(msg); err != nil { + return nil, err + } msgtype = msgtypeCM case *wire.MsgMixFactoredPoly: + if err := checkFPLimits(msg); err != nil { + return nil, err + } msgtype = msgtypeFP case *wire.MsgMixSecrets: + if err := checkRSLimits(msg); err != nil { + return nil, err + } msgtype = msgtypeRS default: return nil, fmt.Errorf("unknown mix message type %T", msg) @@ -1133,20 +1307,7 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err } // Save as an orphan if their KE is not (yet) accepted. if !haveKE { - orphansByID := p.orphansByID[*id] - if _, ok := orphansByID[hash]; ok { - // Already an orphan. - return nil, nil - } - if orphansByID == nil { - orphansByID = make(map[chainhash.Hash]mixing.Message) - p.orphansByID[*id] = orphansByID - } - p.orphans[hash] = &orphan{ - message: msg, - accepted: time.Now(), - } - orphansByID[hash] = msg + p.addOrphan(msg, &hash, id, src) // TODO: Consider return an error containing the unknown // messages, so they can be getdata'd. @@ -1197,30 +1358,46 @@ func (p *Pool) removePR(pr *wire.MsgMixPairReq, reason string) { delete(p.messagesByIdentity, pr.Identity) delete(p.latestKE, pr.Identity) for orphanHash := range p.orphansByID[pr.Identity] { - delete(p.orphans, orphanHash) + p.removeOrphan(&orphanHash, &pr.Identity) } - delete(p.orphansByID, pr.Identity) for i := range pr.UTXOs { delete(p.outPoints, pr.UTXOs[i].OutPoint) } } func (p *Pool) checkAcceptPR(pr *wire.MsgMixPairReq) error { + if err := checkPRLimits(pr); err != nil { + return err + } + + inputValue := pr.InputValue + if pr.Change != nil { + if pr.Change.Value < 0 || isDustAmount(pr.Change.Value, p2pkhv0PkScriptSize, feeRate) { + return ruleError(ErrChangeDust) + } + + if pr.Change.Value > inputValue { + return ruleError(ErrLowInput) + } + inputValue -= pr.Change.Value + + if pr.Change.Version != 0 { + return ruleError(fmt.Errorf("unrecognized script version")) + } + if pr.Change.Version == 0 && (!stdscript.IsPubKeyHashScriptV0(pr.Change.PkScript) && + !stdscript.IsScriptHashScriptV0(pr.Change.PkScript)) { + return ruleError(ErrInvalidScript) + } + } switch { case len(pr.UTXOs) == 0: // Require at least one utxo. return ruleError(ErrMissingUTXOs) case pr.MessageCount == 0: // Require at least one mixed message. return ruleError(ErrInvalidMessageCount) - case pr.InputValue < int64(pr.MessageCount)*pr.MixAmount: + case isDustAmount(pr.MixAmount, p2pkhv0PkScriptSize, feeRate): + return ruleError(ErrMixDust) + case inputValue < int64(pr.MessageCount)*pr.MixAmount: return ruleError(ErrInvalidTotalMixAmount) - case pr.Change != nil: - if isDustAmount(pr.Change.Value, p2pkhv0PkScriptSize, feeRate) { - return ruleError(ErrChangeDust) - } - if !stdscript.IsPubKeyHashScriptV0(pr.Change.PkScript) && - !stdscript.IsScriptHashScriptV0(pr.Change.PkScript) { - return ruleError(ErrInvalidScript) - } } // Check that expiry has not been reached, nor that it is too far @@ -1247,13 +1424,68 @@ func (p *Pool) checkAcceptPR(pr *wire.MsgMixPairReq) error { return err } - // If able, sanity check UTXOs. - if p.utxoFetcher != nil { - err := p.checkUTXOs(pr, curHeight) - if err != nil { - return err + // Check that UTXOs exist, have confirmations, sum of UTXO values matches the + // input value, and proof of ownership is valid. + var totalValue int64 + outpoints := make(map[wire.OutPoint]struct{}) + for i := range pr.UTXOs { + utxo := &pr.UTXOs[i] + + if _, ok := outpoints[utxo.OutPoint]; ok { + return ruleError(ErrInvalidUTXOProof) + } + outpoints[utxo.OutPoint] = struct{}{} + + if utxo.Script != nil { + return ruleError(fmt.Errorf("P2SH inputs are unsupported")) + } + + if p.utxoFetcher != nil { + entry, err := p.utxoFetcher.FetchUtxoEntry(utxo.OutPoint) + if err != nil { + return err + } + if entry == nil || entry.IsSpent() { + return ruleError(fmt.Errorf("output %v is not unspent", + &utxo.OutPoint)) + } + height := entry.BlockHeight() + if !confirmed(minconf, height, curHeight) { + return ruleError(fmt.Errorf("output %v is unconfirmed", + &utxo.OutPoint)) + } + if entry.ScriptVersion() != 0 { + return ruleError(fmt.Errorf("output %v does not use script version 0", + &utxo.OutPoint)) + } + + // Check proof of key ownership and ability to sign coinjoin + // inputs. + var extractPubKeyHash160 func([]byte) []byte + switch utxo.Opcode { + case 0: + extractPubKeyHash160 = stdscript.ExtractPubKeyHashV0 + case txscript.OP_SSGEN: + extractPubKeyHash160 = stdscript.ExtractStakeGenPubKeyHashV0 + case txscript.OP_SSRTX: + extractPubKeyHash160 = stdscript.ExtractStakeRevocationPubKeyHashV0 + case txscript.OP_TGEN: + extractPubKeyHash160 = stdscript.ExtractTreasuryGenPubKeyHashV0 + default: + return ruleError(fmt.Errorf("unsupported output script for UTXO %s", &utxo.OutPoint)) + } + valid := validateOwnerProofP2PKHv0(extractPubKeyHash160, + entry.PkScript(), utxo.PubKey, utxo.Signature, pr.Expires()) + if !valid { + return ruleError(ErrInvalidUTXOProof) + } + + totalValue += entry.Amount() } } + if totalValue != 0 && totalValue != pr.InputValue { + return ruleError(ErrInvalidUTXOProof) + } return nil } @@ -1317,42 +1549,35 @@ func (p *Pool) reconsiderOrphans(accepted mixing.Message, id *idPubKey) []mixing // If the accepted message was a PR, there may be KE orphans that can // be accepted now. if pr, ok := accepted.(*wire.MsgMixPairReq); ok { - var orphanKEs []*wire.MsgMixKeyExchange + var orphanKEs []*orphanMsg for _, orphan := range p.orphansByID[*id] { - orphanKE, ok := orphan.(*wire.MsgMixKeyExchange) + orphanKE, ok := orphan.message.(*wire.MsgMixKeyExchange) if !ok { continue } - refsAcceptedPR := false - for _, prHash := range orphanKE.SeenPRs { - if pr.Hash() == prHash { - refsAcceptedPR = true - break - } - } - if !refsAcceptedPR { + if !slices.Contains(orphanKE.SeenPRs, pr.Hash()) { continue } - orphanKEs = append(orphanKEs, orphanKE) + orphanKEs = append(orphanKEs, orphan) } - for _, orphanKE := range orphanKEs { + for _, orphan := range orphanKEs { + orphanKE := orphan.message.(*wire.MsgMixKeyExchange) orphanKEHash := orphanKE.Hash() - _, err := p.acceptKE(orphanKE, &orphanKEHash, &orphanKE.Identity) + _, err := p.acceptKE(orphanKE, &orphanKEHash, &orphanKE.Identity, + orphan.src) if err != nil { log.Debugf("orphan KE could not be accepted: %v", err) continue } kes = append(kes, orphanKE) - delete(p.orphansByID[*id], orphanKEHash) - delete(p.orphans, orphanKEHash) + p.removeOrphan(&orphanKEHash, id) acceptedMessages = append(acceptedMessages, orphanKE) } if len(p.orphansByID[*id]) == 0 { - delete(p.orphansByID, *id) return acceptedMessages } } @@ -1368,8 +1593,8 @@ func (p *Pool) reconsiderOrphans(accepted mixing.Message, id *idPubKey) []mixing continue } - var acceptedOrphans []mixing.Message - for orphanHash, orphan := range p.orphansByID[*id] { + for orphanHash, omsg := range p.orphansByID[*id] { + orphan := omsg.message if !bytes.Equal(orphan.Sid(), ke.SessionID[:]) { continue } @@ -1395,16 +1620,10 @@ func (p *Pool) reconsiderOrphans(accepted mixing.Message, id *idPubKey) []mixing p.acceptEntry(orphan, msgtype, &orphanHash, id, ses) - acceptedOrphans = append(acceptedOrphans, orphan) acceptedMessages = append(acceptedMessages, orphan) - } - for _, orphan := range acceptedOrphans { - orphanHash := orphan.Hash() - delete(p.orphansByID[*id], orphanHash) - delete(p.orphans, orphanHash) + p.removeOrphan(&orphanHash, id) } if len(p.orphansByID[*id]) == 0 { - delete(p.orphansByID, *id) return acceptedMessages } } @@ -1412,63 +1631,6 @@ func (p *Pool) reconsiderOrphans(accepted mixing.Message, id *idPubKey) []mixing return acceptedMessages } -// Check that UTXOs exist, have confirmations, sum of UTXO values matches the -// input value, and proof of ownership is valid. -func (p *Pool) checkUTXOs(pr *wire.MsgMixPairReq, curHeight int64) error { - var totalValue int64 - - for i := range pr.UTXOs { - utxo := &pr.UTXOs[i] - entry, err := p.utxoFetcher.FetchUtxoEntry(utxo.OutPoint) - if err != nil { - return err - } - if entry == nil || entry.IsSpent() { - return ruleError(fmt.Errorf("output %v is not unspent", - &utxo.OutPoint)) - } - height := entry.BlockHeight() - if !confirmed(minconf, height, curHeight) { - return ruleError(fmt.Errorf("output %v is unconfirmed", - &utxo.OutPoint)) - } - if entry.ScriptVersion() != 0 { - return ruleError(fmt.Errorf("output %v does not use script version 0", - &utxo.OutPoint)) - } - - // Check proof of key ownership and ability to sign coinjoin - // inputs. - var extractPubKeyHash160 func([]byte) []byte - switch utxo.Opcode { - case 0: - extractPubKeyHash160 = stdscript.ExtractPubKeyHashV0 - case txscript.OP_SSGEN: - extractPubKeyHash160 = stdscript.ExtractStakeGenPubKeyHashV0 - case txscript.OP_SSRTX: - extractPubKeyHash160 = stdscript.ExtractStakeRevocationPubKeyHashV0 - case txscript.OP_TGEN: - extractPubKeyHash160 = stdscript.ExtractTreasuryGenPubKeyHashV0 - default: - return ruleError(fmt.Errorf("unsupported output script for UTXO %s", &utxo.OutPoint)) - } - valid := validateOwnerProofP2PKHv0(extractPubKeyHash160, - entry.PkScript(), utxo.PubKey, utxo.Signature, pr.Expires()) - if !valid { - return ruleError(ErrInvalidUTXOProof) - } - - totalValue += entry.Amount() - } - - if totalValue != pr.InputValue { - return ruleError(fmt.Errorf("input value does not match sum of UTXO " + - "values")) - } - - return nil -} - func validateOwnerProofP2PKHv0(extractFunc func([]byte) []byte, pkscript, pubkey, sig []byte, expires uint32) bool { extractedHash160 := extractFunc(pkscript) pubkeyHash160 := stdaddr.Hash160(pubkey) @@ -1480,6 +1642,10 @@ func validateOwnerProofP2PKHv0(extractFunc func([]byte) []byte, pkscript, pubkey } func (p *Pool) checkAcceptKE(ke *wire.MsgMixKeyExchange) error { + if err := checkKELimits(ke); err != nil { + return err + } + // Validate PR order and session ID. if err := mixing.ValidateSession(ke); err != nil { return ruleError(err) @@ -1499,7 +1665,7 @@ func (p *Pool) checkAcceptKE(ke *wire.MsgMixKeyExchange) error { return nil } -func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixKeyExchange, err error) { +func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *idPubKey, src Source) (accepted *wire.MsgMixKeyExchange, err error) { // Check if already accepted. if _, ok := p.pool[*hash]; ok { return nil, nil @@ -1551,16 +1717,7 @@ func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *id } } if missingOwnPR != nil { - p.orphans[*hash] = &orphan{ - message: ke, - accepted: time.Now(), - } - orphansByID := p.orphansByID[*id] - if orphansByID == nil { - orphansByID = make(map[chainhash.Hash]mixing.Message) - p.orphansByID[*id] = orphansByID - } - orphansByID[*hash] = ke + p.addOrphan(ke, hash, id, src) err := &MissingOwnPRError{ MissingPR: *missingOwnPR, } @@ -1659,6 +1816,10 @@ func checkFee(pr *wire.MsgMixPairReq, feeRate int64) error { if fee < requiredFee { return ruleError(ErrLowInput) } + maxFee := feeForSerializeSize(feeRate, estimatedSize*maxRelayFeeMultiplier) + if fee > maxFee { + return ruleError(ErrHighFee) + } return nil } diff --git a/mixing/mixpool/mixpool_test.go b/mixing/mixpool/mixpool_test.go index bdfb5a814c..b08e84440d 100644 --- a/mixing/mixpool/mixpool_test.go +++ b/mixing/mixpool/mixpool_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023-2024 The Decred developers +// Copyright (c) 2023-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -13,6 +13,9 @@ import ( "io" "math/big" "os" + "strconv" + "strings" + "sync/atomic" "testing" "time" @@ -34,9 +37,18 @@ var params = chaincfg.SimNetParams() var seed [32]byte +var prngNonce atomic.Uint32 + func testPRNG(t *testing.T) *chacha20prng.Reader { - t.Logf("PRNG seed: %x\n", seed) - return chacha20prng.New(seed[:], 0) + t.Helper() + n := prngNonce.Add(1) - 1 + t.Cleanup(func() { + t.Helper() + if t.Failed() { + t.Logf("Reproduce with -seed=%x/%d", seed[:], n) + } + }) + return chacha20prng.New(seed[:], n) } var utxoStore struct { @@ -48,16 +60,23 @@ func TestMain(m *testing.M) { seedFlag := flag.String("seed", "", "use deterministic PRNG seed (32 bytes, hex)") flag.Parse() if *seedFlag != "" { - b, err := hex.DecodeString(*seedFlag) + slash := strings.IndexByte(*seedFlag, '/') + if slash != 64 { + fmt.Fprintln(os.Stderr, "invalid -seed: must be in form <32 byte hex seed>/iteration") + os.Exit(1) + } + b, err := hex.DecodeString((*seedFlag)[:slash]) if err != nil { fmt.Fprintln(os.Stderr, "invalid -seed:", err) os.Exit(1) } - if len(b) != 32 { - fmt.Fprintln(os.Stderr, "invalid -seed: must be 32 bytes") + copy(seed[:], b) + nonce, err := strconv.ParseUint((*seedFlag)[slash+1:], 10, 32) + if err != nil { + fmt.Fprintln(os.Stderr, "invalid -seed nonce:", err) os.Exit(1) } - copy(seed[:], b) + prngNonce.Store(uint32(nonce)) } else { cryptorand.Read(seed[:]) } @@ -200,7 +219,7 @@ func TestAccept(t *testing.T) { var ( expires uint32 = 1010 - mixAmount int64 = 10e8 + mixAmount int64 = 20e8 - 3000 scriptClass = mixing.ScriptClassP2PKHv0 txVersion uint16 = wire.TxVersion lockTime uint32 = 0 @@ -225,7 +244,7 @@ func TestAccept(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = p.AcceptMessage(pr) + _, err = p.AcceptMessage(pr, ZeroSource) if err != nil { t.Fatal(err) } @@ -271,7 +290,7 @@ func TestAccept(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = p.AcceptMessage(ke) + _, err = p.AcceptMessage(ke, ZeroSource) if err != nil { t.Fatal(err) } @@ -289,7 +308,7 @@ func TestAccept(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = p.AcceptMessage(ct) + _, err = p.AcceptMessage(ct, ZeroSource) if err != nil { t.Fatal(err) } @@ -330,7 +349,7 @@ func TestAccept(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = p.AcceptMessage(sr) + _, err = p.AcceptMessage(sr, ZeroSource) if err != nil { t.Fatal(err) } @@ -365,7 +384,7 @@ func TestAccept(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = p.AcceptMessage(dc) + _, err = p.AcceptMessage(dc, ZeroSource) if err != nil { t.Fatal(err) } @@ -399,7 +418,7 @@ func TestAccept(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = p.AcceptMessage(cm) + _, err = p.AcceptMessage(cm, ZeroSource) if err != nil { t.Fatal(err) } diff --git a/mixing/mixpool/orphans_test.go b/mixing/mixpool/orphans_test.go index 48ffd7917f..d54b34a19d 100644 --- a/mixing/mixpool/orphans_test.go +++ b/mixing/mixpool/orphans_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 The Decred developers +// Copyright (c) 2024-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -55,9 +55,10 @@ func TestOrphans(t *testing.T) { {}, }, MessageCount: 1, + MixAmount: 1 << 18, Expiry: testStartingHeight + 10, ScriptClass: string(mixing.ScriptClassP2PKHv0), - InputValue: 1 << 18, + InputValue: 1<<18 + 3000, } err = mixing.SignMessage(pr, priv) if err != nil { @@ -153,7 +154,7 @@ func TestOrphans(t *testing.T) { mp := NewPool(newTestBlockchain()) for j, a := range accepts { - accepted, err := mp.AcceptMessage(a.message) + accepted, err := mp.AcceptMessage(a.message, ZeroSource) if err != nil != a.errors { t.Errorf("test %d call %d %q: unexpected error: %v", i, j, a.desc, err) } @@ -173,3 +174,117 @@ func TestOrphans(t *testing.T) { } } } + +// TestOrphanEviction ensures that exceeding the maximum number of orphans +// evicts entries to make room for the new ones. +func TestOrphanEviction(t *testing.T) { + pub, priv, err := generateSecp256k1(nil) + if err != nil { + t.Fatal(err) + } + id := *(*[33]byte)(pub.SerializeCompressed()) + + // createOrphanKEAndCT returns a signed mix key exchange and cyphertexts + // message associated with a unique pair request. + h := blake256.NewHasher256() + var nextPRNum uint32 + createOrphanKEAndCT := func() (*wire.MsgMixKeyExchange, *wire.MsgMixCiphertexts) { + h.Reset() + h.WriteUint32LE(nextPRNum) + nextPRNum++ + randPrevOut := wire.OutPoint{Hash: h.Sum256()} + randomUTXO := wire.MixPairReqUTXO{OutPoint: randPrevOut} + pr := &wire.MsgMixPairReq{ + Identity: id, + UTXOs: []wire.MixPairReqUTXO{randomUTXO}, + MessageCount: 1, + Expiry: testStartingHeight + 10, + ScriptClass: string(mixing.ScriptClassP2PKHv0), + InputValue: 1 << 18, + } + err = mixing.SignMessage(pr, priv) + if err != nil { + t.Fatal(err) + } + pr.WriteHash(h) + + prs := []*wire.MsgMixPairReq{pr} + epoch := uint64(time.Now().Unix()) + sid := mixing.SortPRsForSession(prs, epoch) + ke := &wire.MsgMixKeyExchange{ + Identity: id, + SessionID: sid, + Epoch: epoch, + Run: 0, + SeenPRs: []chainhash.Hash{ + pr.Hash(), + }, + } + err = mixing.SignMessage(ke, priv) + if err != nil { + t.Fatal(err) + } + ke.WriteHash(h) + + seenKEs := []chainhash.Hash{ke.Hash()} + ct := &wire.MsgMixCiphertexts{ + Identity: id, + SessionID: sid, + Run: 0, + SeenKeyExchanges: seenKEs, + } + err = mixing.SignMessage(ct, priv) + if err != nil { + t.Fatal(err) + } + ct.WriteHash(h) + + return ke, ct + } + + // Create enough orphan key exchange and cyphertext messages to be able to + // exceed the maximum number of orphans allowed in the pool. + orphans := make([]mixing.Message, 0, maxOrphans+2) + for range maxOrphans/2 + 1 { + ke, ct := createOrphanKEAndCT() + orphans = append(orphans, ke, ct) + } + + // Fill up the orphan pool to the max allowed with orphan messages while + // also pretending as though the orphans came from different sources. + p := NewPool(newTestBlockchain()) + acceptOrphanAndCheck := func(msg mixing.Message, srcID uint64) { + accepted, err := p.AcceptMessage(msg, Uint64Source(srcID)) + if len(accepted) != 0 { + t.Fatalf("accepted orphan message %T to main pool", msg) + } + if _, ok := msg.(*wire.MsgMixKeyExchange); ok { + var ownPRErr *MissingOwnPRError + if !errors.As(err, &ownPRErr) { + t.Fatalf("unexpected error for orphan message: %v", err) + } + } + } + const numSources = 5 + numOrphansPerSource := (maxOrphans + 1 + (numSources - 1)) / numSources + for i := range maxOrphans { + orphan := orphans[i] + srcID := uint64(i/numOrphansPerSource) + 1 + acceptOrphanAndCheck(orphan, srcID) + } + + // Ensure the orphan pool is the max allowed. + assertOrphanPoolSize := func(expected int) { + if len(p.orphans) != expected { + t.Fatalf("unexpected orphan pool size - got %d, want %d", + len(p.orphans), expected) + } + } + assertOrphanPoolSize(maxOrphans) + + // Ensure adding another orphan causes evictions of orphans tagged from + // the sources that submitted the most until the pool only retains the + // target number post eviction plus one for the orphan being added. + acceptOrphanAndCheck(orphans[maxOrphans], numSources) + assertOrphanPoolSize(maxPostEvictionOrphans + 1) +} diff --git a/mixing/sid.go b/mixing/sid.go index 7394bf60cd..1dd384f2d6 100644 --- a/mixing/sid.go +++ b/mixing/sid.go @@ -17,7 +17,7 @@ import ( // deriveSessionID creates the mix session identifier from an initial sorted // slice of PR message hashes. func deriveSessionID(seenPRs []chainhash.Hash, epoch uint64) [32]byte { - h := blake256.New() + h := blake256.NewHasher256() buf := make([]byte, 8) h.Write([]byte("decred-mix-session")) @@ -29,7 +29,7 @@ func deriveSessionID(seenPRs []chainhash.Hash, epoch uint64) [32]byte { h.Write(seenPRs[i][:]) } - return *(*[32]byte)(h.Sum(nil)) + return h.Sum256() } // SortPRsForSession performs an in-place sort of prs, moving each pair diff --git a/mixing/signatures.go b/mixing/signatures.go index 190a06c28d..313435d8e1 100644 --- a/mixing/signatures.go +++ b/mixing/signatures.go @@ -1,18 +1,17 @@ -// Copyright (c) 2023-2024 The Decred developers +// Copyright (c) 2023-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. package mixing import ( - "bytes" - "fmt" + "encoding/hex" "hash" + "strconv" "github.com/decred/dcrd/crypto/blake256" "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4/schnorr" - "github.com/decred/dcrd/wire" ) const tag = "decred-mix-signature" @@ -42,11 +41,9 @@ func SignMessage(m Signed, priv *secp256k1.PrivateKey) error { // VerifySignedMessage verifies that a signed message carries a valid // signature for the represented identity. func VerifySignedMessage(m Signed) bool { - h := blake256.New() + h := blake256.NewHasher256() m.WriteSignedData(h) - sigHash := h.Sum(nil) - - h.Reset() + sigHash := h.Sum256() command := m.Command() sid := m.Sid() @@ -56,7 +53,8 @@ func VerifySignedMessage(m Signed) bool { run = 0 } - return verify(h, m.Pub(), m.Sig(), sigHash, command, sid, run) + h.Reset() + return verify(h, m.Pub(), m.Sig(), sigHash[:], command, sid, run) } // VerifySignature verifies a message signature from its signature hash and @@ -65,18 +63,16 @@ func VerifySignedMessage(m Signed) bool { // the same public key, and demonstrating this can be used to prove malicious // behavior by sending different versions of messages through the network. func VerifySignature(pub, sig, sigHash []byte, command string, sid []byte, run uint32) bool { - h := blake256.New() + h := blake256.NewHasher256() return verify(h, pub, sig, sigHash, command, sid, run) } var zeroSID [32]byte func sign(priv *secp256k1.PrivateKey, m Signed) ([]byte, error) { - h := blake256.New() + h := blake256.NewHasher256() m.WriteSignedData(h) - sigHash := h.Sum(nil) - - h.Reset() + sigHash := h.Sum256() sid := m.Sid() run := m.GetRun() @@ -85,24 +81,16 @@ func sign(priv *secp256k1.PrivateKey, m Signed) ([]byte, error) { run = 0 } - buf := new(bytes.Buffer) - buf.Grow(len(tag) + wire.CommandSize + - 64 + // sid - 4 + // run - 64 + // sigHash - 4, // commas - ) - fmt.Fprintf(buf, tag+",%s,%x,%d,%x", m.Command(), sid, run, sigHash) - h.Write(buf.Bytes()) - - sig, err := schnorr.Sign(priv, h.Sum(nil)) + h.Reset() + hash := schnorrHash(h, m.Command(), sid, run, sigHash[:]) + sig, err := schnorr.Sign(priv, hash[:]) if err != nil { return nil, err } return sig.Serialize(), nil } -func verify(h hash.Hash, pk []byte, sig []byte, sigHash []byte, command string, sid []byte, run uint32) bool { +func verify(h *blake256.Hasher256, pk []byte, sig []byte, sigHash []byte, command string, sid []byte, run uint32) bool { if len(pk) != secp256k1.PubKeyBytesLenCompressed { return false } @@ -115,16 +103,21 @@ func verify(h hash.Hash, pk []byte, sig []byte, sigHash []byte, command string, return false } - h.Reset() + hash := schnorrHash(h, command, sid, run, sigHash) + return sigParsed.Verify(hash[:], pkParsed) +} - buf := new(bytes.Buffer) - buf.Grow(len(tag) + wire.CommandSize + - 64 + // sid - 4 + // run - 64 + // sigHash - 4, // commas - ) - fmt.Fprintf(buf, tag+",%s,%x,%d,%x", command, sid, run, sigHash) - h.Write(buf.Bytes()) - return sigParsed.Verify(h.Sum(nil), pkParsed) +func schnorrHash(h *blake256.Hasher256, command string, sid []byte, run uint32, sigHash []byte) [32]byte { + buf := make([]byte, 64) + + h.WriteBytes([]byte(tag)) + h.WriteByte(',') + h.WriteString(command) + h.WriteByte(',') + h.WriteBytes(hex.AppendEncode(buf[:0], sid)) + h.WriteByte(',') + h.WriteBytes(strconv.AppendUint(buf[:0], uint64(run), 10)) + h.WriteByte(',') + h.WriteBytes(hex.AppendEncode(buf[:0], sigHash)) + return h.Sum256() } diff --git a/mixing/signatures_test.go b/mixing/signatures_test.go new file mode 100644 index 0000000000..4afb518e97 --- /dev/null +++ b/mixing/signatures_test.go @@ -0,0 +1,104 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package mixing + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/decred/dcrd/wire" +) + +var ( + testPrivKey = secp256k1.PrivKeyFromBytes([]byte{31: 1}) + testPubKey = testPrivKey.PubKey() +) + +func hexDecode(tb testing.TB, s string) []byte { + b, err := hex.DecodeString(s) + if err != nil { + tb.Fatalf("Hex decode failed: %v", err) + } + return b +} + +func fakePR(tb testing.TB) *wire.MsgMixPairReq { + var id [33]byte + copy(id[:], testPubKey.SerializeCompressed()) + utxos := []wire.MixPairReqUTXO{{ + OutPoint: wire.OutPoint{ + Hash: [32]byte{31: 1}, + Index: 2, + Tree: 3, + }, + Script: []byte{4, 5, 6}, + PubKey: []byte{31: 7}, + Signature: []byte{31: 8}, + Opcode: 9, + }} + change := &wire.TxOut{ + Value: 10, + Version: 11, + PkScript: []byte{12, 13, 14}, + } + pr, err := wire.NewMsgMixPairReq(id, 15, 16, "script class", 17, 18, 19, 20, utxos, change, 21, 22) + if err != nil { + tb.Fatal(err) + } + return pr +} + +const wantSig = "1ffdddfb26b2955b3cf46c9e98671d11d1fd91619da4e85edb840786f661b6ab6824c83923f596a990613fce97d995d9d50a89ecb23841c9c10a80a47f18b50a" + +func TestSignMessage(t *testing.T) { + pr := fakePR(t) + + if err := SignMessage(pr, testPrivKey); err != nil { + t.Fatalf("Failed to sign message: %v", err) + } + + if !bytes.Equal(pr.Signature[:], hexDecode(t, wantSig)) { + t.Fatalf("Signature does not match expected (got: %x want: %s)", pr.Signature[:], wantSig) + } +} + +func TestVerifySignedMessage(t *testing.T) { + pr := fakePR(t) + + if err := SignMessage(pr, testPrivKey); err != nil { + t.Fatalf("Failed to sign message: %v", err) + } + + if !VerifySignedMessage(pr) { + t.Fatalf("VerifySignedMessage invalid signature %x", pr.Signature[:]) + } +} + +func BenchmarkSignMessage(b *testing.B) { + pr := fakePR(b) + + for b.Loop() { + err := SignMessage(pr, testPrivKey) + if err != nil { + b.Fatalf("Failed to sign message: %v", err) + } + } +} + +func BenchmarkVerifySignedMessage(b *testing.B) { + pr := fakePR(b) + err := SignMessage(pr, testPrivKey) + if err != nil { + b.Fatalf("Failed to sign message: %v", err) + } + + for b.Loop() { + if !VerifySignedMessage(pr) { + b.Fatalf("VerifySignedMessage invalid signature %x", pr.Signature[:]) + } + } +} diff --git a/mixing/utxoproof/utxoproof.go b/mixing/utxoproof/utxoproof.go index e0a33ec38e..ce4d88b239 100644 --- a/mixing/utxoproof/utxoproof.go +++ b/mixing/utxoproof/utxoproof.go @@ -5,8 +5,6 @@ package utxoproof import ( - "encoding/binary" - "github.com/decred/dcrd/crypto/blake256" "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4/schnorr" @@ -23,7 +21,7 @@ const ( secp256k1P2PKH = "P2PKH(EC-Schnorr-DCRv0)" ) -var sep = []byte{','} +const sep = "," // The signature hash is created from the serialization of: // tag , scheme , expiry pubkey @@ -43,17 +41,14 @@ type Secp256k1KeyPair struct { func (k *Secp256k1KeyPair) SignUtxoProof(expires uint32) ([]byte, error) { const scheme = secp256k1P2PKH - h := blake256.New() - h.Write([]byte(tag)) - h.Write(sep) - h.Write([]byte(scheme)) - h.Write(sep) - expiresBytes := binary.BigEndian.AppendUint32(make([]byte, 0, 4), expires) - h.Write(expiresBytes) - h.Write(k.Pub) - hash := h.Sum(nil) + h := blake256.NewHasher256() + const preamble = tag + sep + scheme + sep + h.WriteBytes([]byte(preamble)) + h.WriteUint32BE(expires) + h.WriteBytes(k.Pub) + hash := h.Sum256() - sig, err := schnorr.Sign(k.Priv, hash) + sig, err := schnorr.Sign(k.Priv, hash[:]) if err != nil { return nil, err } @@ -76,15 +71,12 @@ func ValidateSecp256k1P2PKH(pubkey, proof []byte, expires uint32) bool { return false } - h := blake256.New() - h.Write([]byte(tag)) - h.Write(sep) - h.Write([]byte(scheme)) - h.Write(sep) - expiresBytes := binary.BigEndian.AppendUint32(make([]byte, 0, 4), expires) - h.Write(expiresBytes) + h := blake256.NewHasher256() + const preamble = tag + sep + scheme + sep + h.WriteBytes([]byte(preamble)) + h.WriteUint32BE(expires) h.Write(pubkey) - hash := h.Sum(nil) + hash := h.Sum256() - return proofParsed.Verify(hash, pubkeyParsed) + return proofParsed.Verify(hash[:], pubkeyParsed) } diff --git a/mixing/utxoproof/utxoproof_bench_test.go b/mixing/utxoproof/utxoproof_bench_test.go new file mode 100644 index 0000000000..0f56ce03f4 --- /dev/null +++ b/mixing/utxoproof/utxoproof_bench_test.go @@ -0,0 +1,57 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package utxoproof + +import ( + "testing" + + "github.com/decred/dcrd/dcrec/secp256k1/v4" +) + +// genSecp256k1KeyPair generates and returns a new secp256k1 key pair. +func genSecp256k1KeyPair(b *testing.B) *Secp256k1KeyPair { + privKey, err := secp256k1.GeneratePrivateKey() + if err != nil { + b.Fatalf("failed to generate key pair: %v", err) + } + return &Secp256k1KeyPair{ + Pub: privKey.PubKey().SerializeCompressed(), + Priv: privKey, + } +} + +// BenchmarkValidateSecp256k1P2PKH benchmarks how long it takes to sign a utxo +// proof via [SignUtxoProof] along with the number of allocations needed. +func BenchmarkSignUtxoProof(b *testing.B) { + const expires = 10 + keyPair := genSecp256k1KeyPair(b) + + b.ReportAllocs() + for b.Loop() { + _, err := keyPair.SignUtxoProof(expires) + if err != nil { + b.Fatalf("failed to sign utxo proof: %v", err) + } + } +} + +// BenchmarkValidateSecp256k1P2PKH benchmarks how long it takes to validate a +// secp256k1 p2pkh utxo proof via [ValidateSecp256k1P2PKH] along with the number +// of allocations needed. +func BenchmarkValidateSecp256k1P2PKH(b *testing.B) { + const expires = 10 + keyPair := genSecp256k1KeyPair(b) + proof, err := keyPair.SignUtxoProof(expires) + if err != nil { + b.Fatal(err) + } + + b.ReportAllocs() + for b.Loop() { + if !ValidateSecp256k1P2PKH(keyPair.Pub, proof, expires) { + b.Fatal("failed to validate utxo proof") + } + } +} diff --git a/rpcadaptors.go b/rpcadaptors.go index 8f411fe737..66e81ddbe0 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -1,5 +1,5 @@ // Copyright (c) 2017 The btcsuite developers -// Copyright (c) 2015-2025 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -22,6 +22,7 @@ import ( "github.com/decred/dcrd/internal/netsync" "github.com/decred/dcrd/internal/rpcserver" "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/mixpool" "github.com/decred/dcrd/peer/v3" "github.com/decred/dcrd/wire" ) @@ -444,9 +445,10 @@ func (b *rpcSyncMgr) RecentlyConfirmedTxn(hash *chainhash.Hash) bool { return b.server.recentlyConfirmedTxns.Contains(hash[:]) } -// SubmitMixMessage locally processes the mixing message. -func (b *rpcSyncMgr) SubmitMixMessage(msg mixing.Message) error { - _, err := b.server.mixMsgPool.AcceptMessage(msg) +// AcceptMixMessage attempts to accept a mixing message to the local mixing +// pool. +func (b *rpcSyncMgr) AcceptMixMessage(msg mixing.Message, src mixpool.Source) error { + _, err := b.server.mixMsgPool.AcceptMessage(msg, src) return err } diff --git a/server.go b/server.go index f61d638e1d..52c1c8cc8d 100644 --- a/server.go +++ b/server.go @@ -933,11 +933,12 @@ func (sp *serverPeer) Run() { srvr.syncManager.OnPeerDisconnected(sp.syncMgrPeer) if sp.VersionKnown() { - // Evict any remaining orphans that were sent by the peer. + // Evict any remaining mempool orphans that were sent by the peer. numEvicted := srvr.txMemPool.RemoveOrphansByTag(mempool.Tag(sp.ID())) if numEvicted > 0 { - srvrLog.Debugf("Evicted %d %s from peer %v (id %d)", numEvicted, - pickNoun(numEvicted, "orphan", "orphans"), sp, sp.ID()) + srvrLog.Debugf("Evicted %d mempool %s from peer %v (id %d)", + numEvicted, pickNoun(numEvicted, "orphan", "orphans"), sp, + sp.ID()) } }