diff --git a/chain/sync.go b/chain/sync.go index fd335763c..71a10ffa7 100644 --- a/chain/sync.go +++ b/chain/sync.go @@ -24,6 +24,7 @@ import ( "github.com/decred/dcrd/blockchain/stake/v5" "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/crypto/blake256" + "github.com/decred/dcrd/gcs/v4" "github.com/decred/dcrd/mixing/mixpool" "github.com/decred/dcrd/wire" "github.com/jrick/wsrpc/v2" @@ -223,6 +224,11 @@ func (s *Syncer) getHeaders(ctx context.Context) error { return err } + birthday, err := s.wallet.BirthState(ctx) + if err != nil { + return err + } + startedSynced := s.walletSynced.Load() cnet := s.wallet.ChainParams().Net @@ -253,17 +259,25 @@ func (s *Syncer) getHeaders(ctx context.Context) error { var g errgroup.Group for i := range headers { g.Go(func() error { + var err error header := headers[i] hash := header.BlockHash() - filter, proofIndex, proof, err := s.rpc.CFilterV2(ctx, &hash) - if err != nil { - return err - } - - err = validate.CFilterV2HeaderCommitment(cnet, header, - filter, proofIndex, proof) - if err != nil { - return err + var filter *gcs.FilterV2 + if birthday == nil || birthday.AfterBirthday(header) { + var ( + proofIndex uint32 + proof []chainhash.Hash + ) + filter, proofIndex, proof, err = s.rpc.CFilterV2(ctx, &hash) + if err != nil { + return err + } + + err = validate.CFilterV2HeaderCommitment(cnet, header, + filter, proofIndex, proof) + if err != nil { + return err + } } nodes[i] = wallet.NewBlockNode(header, &hash, filter, nil) diff --git a/rpc/documentation/api.md b/rpc/documentation/api.md index 4cb198789..b371cd8ac 100644 --- a/rpc/documentation/api.md +++ b/rpc/documentation/api.md @@ -1109,6 +1109,8 @@ ___ The `ImportPrivateKey` method imports a private key in Wallet Import Format (WIF) encoding to a wallet account. A rescan may optionally be started to search for transactions involving the private key's associated payment address. +If the private key deals with transactions before the wallet birthday, if set, +a rescan must be performed to download missing cfilters. **Request:** `ImportPrivateKeyRequest` @@ -1144,7 +1146,9 @@ ___ The `ImportScript` method imports a script into the wallet. A rescan may optionally be started to search for transactions involving the script, either -as an output or in a P2SH input. +as an output or in a P2SH input. If the script deals with transactions before +the wallet birthday, if set, a rescan must be performed to download missing +cfilters. **Request:** `ImportScriptRequest` @@ -1191,7 +1195,9 @@ seed for a hierarchical deterministic private key that is imported into the wallet with the supplied name and locked with the supplied password. Addresses derived from this account MUST NOT be sent any funds. They are solely for the use of creating stake submission scripts. A rescan may optionally be started to -search for tickets using submission scripts derived from this account. +search for tickets using submission scripts derived from this account. If tickets +would exist before the wallet birthday, if set, a rescan must be performed to +download missing cfilters. **Request:** `ImportVotingAccountFromSeedRequest` @@ -2690,7 +2696,10 @@ or account must be unlocked. #### `BirthBlock` The `BirthBlock` method returns the wallets birthday block if set. Rescans -should generally be started from after this block. +should generally be started from after this block. If a birthday is set cfilters +from before the birthday may not be downloaded. A rescan from height will move +the birthday to the rescan height and download all missing cfilters from that +height. **Request:** `BirthBlockRequest` diff --git a/spv/sync.go b/spv/sync.go index 265dc4e32..4c443d11e 100644 --- a/spv/sync.go +++ b/spv/sync.go @@ -1662,6 +1662,11 @@ func (s *Syncer) initialSyncHeaders(ctx context.Context) error { return res } + birthday, err := s.wallet.BirthState(ctx) + if err != nil { + return err + } + // Stage 1: fetch headers. headersChan := make(chan *headersBatch) g.Go(func() error { @@ -1737,9 +1742,11 @@ func (s *Syncer) initialSyncHeaders(ctx context.Context) error { s.sidechainMu.Lock() var missingCfilter []*wallet.BlockNode for i := range batch.bestChain { - if batch.bestChain[i].FilterV2 == nil { - missingCfilter = batch.bestChain[i:] - break + if birthday == nil || birthday.AfterBirthday(batch.bestChain[i].Header) { + if batch.bestChain[i].FilterV2 == nil { + missingCfilter = batch.bestChain[i:] + break + } } } s.sidechainMu.Unlock() diff --git a/wallet/rescan.go b/wallet/rescan.go index f73ac67df..cb445d76c 100644 --- a/wallet/rescan.go +++ b/wallet/rescan.go @@ -386,8 +386,71 @@ func (w *Wallet) Rescan(ctx context.Context, n NetworkBackend, startHash *chainh func (w *Wallet) RescanFromHeight(ctx context.Context, n NetworkBackend, startHeight int32) error { const op errors.Op = "wallet.RescanFromHeight" + bs, err := w.BirthState(ctx) + if err != nil { + return errors.E(op, err) + } + // Determine if the rescan start height is before the birthday. + // For time-based birthdays that have not been resolved to a + // height, look up the block header and compare timestamps. + beforeBirthday := false + if bs != nil { + if bs.SetFromTime { + var header *wire.BlockHeader + err = walletdb.View(ctx, w.db, func(tx walletdb.ReadTx) error { + ns := tx.ReadBucket(wtxmgrNamespaceKey) + hash, err := w.txStore.GetMainChainBlockHashForHeight(ns, startHeight) + if err != nil { + return err + } + header, err = w.txStore.GetBlockHeader(tx, &hash) + return err + }) + if err != nil { + return errors.E(op, err) + } + beforeBirthday = !bs.AfterBirthday(header) + } else { + beforeBirthday = int32(bs.Height) > startHeight + } + } + if beforeBirthday { + // If our birthday is after the rescan height, we may + // not have the cfilters needed. Set birthday to the rescan + // height and download the filters. This may take some time + // depending on network conditions and amount of filters missing. + newBS := &udb.BirthdayState{ + SetFromHeight: true, + Height: uint32(startHeight), + } + if err := w.SetBirthStateAndScan(ctx, newBS); err != nil { + return errors.E(op, err) + } + fetchMissing := true + if err := walletdb.Update(ctx, w.db, func(dbtx walletdb.ReadWriteTx) error { + if _, err := udb.MissingCFiltersHeight(dbtx, startHeight); err != nil { + // errors.NotExist is returned if no missing filters + // exist from start height. If we have them there is + // no need to fetch them again. + if errors.Is(err, errors.NotExist) { + fetchMissing = false + return nil + } + return err + } + return w.txStore.SetHaveMainChainCFilters(dbtx, false) + }); err != nil { + return errors.E(op, err) + } + if fetchMissing { + if err := w.FetchMissingCFilters(ctx, n); err != nil { + return errors.E(op, err) + } + } + } + var startHash chainhash.Hash - err := walletdb.View(ctx, w.db, func(tx walletdb.ReadTx) error { + err = walletdb.View(ctx, w.db, func(tx walletdb.ReadTx) error { txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey) var err error startHash, err = w.txStore.GetMainChainBlockHashForHeight( diff --git a/wallet/udb/txmined.go b/wallet/udb/txmined.go index de20abcbe..b5d0834be 100644 --- a/wallet/udb/txmined.go +++ b/wallet/udb/txmined.go @@ -198,8 +198,8 @@ func (s *Store) MainChainTip(dbtx walletdb.ReadTx) (chainhash.Hash, int32) { // If the block is already inserted and part of the main chain, an errors.Exist // error is returned. // -// The main chain tip may not be extended unless compact filters have been saved -// for all existing main chain blocks. +// The main chain may be extended without cfilters if this block is before the +// wallet birthday. If the filter is nil it will not be saved to the database. func (s *Store) ExtendMainChain(ns walletdb.ReadWriteBucket, header *wire.BlockHeader, blockHash *chainhash.Hash, f *gcs2.FilterV2) error { height := int32(header.Height) if height < 1 { @@ -266,9 +266,12 @@ func (s *Store) ExtendMainChain(ns walletdb.ReadWriteBucket, header *wire.BlockH return err } - // Save the compact filter. - bcf2Key := blockcf2.Key(&header.MerkleRoot) - return putRawCFilter(ns, blockHash[:], valueRawCFilter2(bcf2Key, f.Bytes())) + // Save the compact filter if we have it. + if f != nil { + bcf2Key := blockcf2.Key(&header.MerkleRoot) + return putRawCFilter(ns, blockHash[:], valueRawCFilter2(bcf2Key, f.Bytes())) + } + return nil } // ProcessedTxsBlockMarker returns the hash of the block which records the last @@ -331,6 +334,17 @@ type BirthdayState struct { SetFromHeight, SetFromTime bool } +// AfterBirthday returns whether the given block header is at or after the +// birthday. If SetFromTime is true, the header's timestamp is compared against +// the birthday time. Otherwise, the header's height is compared against the +// birthday height regardless of the SetFromHeight flag. +func (bs *BirthdayState) AfterBirthday(h *wire.BlockHeader) bool { + if bs.SetFromTime { + return !h.Timestamp.Before(bs.Time) + } + return h.Height >= bs.Height +} + // SetBirthState sets the birthday state in the database. *BirthdayState must // not be nil. // @@ -402,19 +416,37 @@ func (s *Store) IsMissingMainChainCFilters(dbtx walletdb.ReadTx) bool { return len(v) != 1 || v[0] == 0 } +// SetHaveMainChainCFilters sets whether we have all of the main chain +// cfilters. Should be used to set have to false if the wallet birthday is +// moved back in time. +func (s *Store) SetHaveMainChainCFilters(dbtx walletdb.ReadWriteTx, have bool) error { + haveB := []byte{0} + if have { + haveB = []byte{1} + } + err := dbtx.ReadWriteBucket(wtxmgrBucketKey).Put(rootHaveCFilters, haveB) + if err != nil { + return errors.E(errors.IO, err) + } + return nil +} + // MissingCFiltersHeight returns the first main chain block height // with a missing cfilter. Errors with NotExist when all main chain // blocks record cfilters. -func (s *Store) MissingCFiltersHeight(dbtx walletdb.ReadTx) (int32, error) { +func MissingCFiltersHeight(dbtx walletdb.ReadTx, fromHeight int32) (int32, error) { ns := dbtx.ReadBucket(wtxmgrBucketKey) c := ns.NestedReadBucket(bucketBlocks).ReadCursor() defer c.Close() - for k, v := c.First(); k != nil; k, v = c.Next() { + for k, v := c.Seek(keyBlockRecord(fromHeight)); k != nil; k, v = c.Next() { hash := extractRawBlockRecordHash(v) _, _, err := fetchRawCFilter2(ns, hash) - if errors.Is(err, errors.NotExist) { - height := int32(byteOrder.Uint32(k)) - return height, nil + if err != nil { + if errors.Is(err, errors.NotExist) { + height := int32(byteOrder.Uint32(k)) + return height, nil + } + return 0, errors.E(errors.IO, err) } } return 0, errors.E(errors.NotExist) @@ -442,42 +474,40 @@ func (s *Store) InsertMissingCFilters(dbtx walletdb.ReadWriteTx, blockHashes []* } for i, blockHash := range blockHashes { - // Ensure that blockHashes are ordered and that all previous cfilters in the - // main chain are known. + // Ensure that blockHashes are ordered. The first block in + // the batch is not required to have its parent's cfilter + // already present, as pre-birthday blocks intentionally + // have no cfilters. + header := existsBlockHeader(ns, blockHash[:]) + if header == nil { + return errors.E(errors.NotExist, errors.Errorf("missing header for block %v", blockHash)) + } ok := i == 0 && *blockHash == s.chainParams.GenesisHash - var bcf2Key [gcs2.KeySize]byte if !ok { - header := existsBlockHeader(ns, blockHash[:]) - if header == nil { - return errors.E(errors.NotExist, errors.Errorf("missing header for block %v", blockHash)) - } parentHash := extractBlockHeaderParentHash(header) - merkleRoot := extractBlockHeaderMerkleRoot(header) - merkleRootHash, err := chainhash.NewHash(merkleRoot) - if err != nil { - return errors.E(errors.Invalid, errors.Errorf("invalid stored header %v", blockHash)) - } - bcf2Key = blockcf2.Key(merkleRootHash) - if i == 0 { - _, _, err := fetchRawCFilter2(ns, parentHash) - ok = err == nil - } else { - ok = bytes.Equal(parentHash, blockHashes[i-1][:]) + if i != 0 { + if !bytes.Equal(parentHash, blockHashes[i-1][:]) { + return errors.E(errors.Invalid, "block hashes are not ordered") + } } } - if !ok { - return errors.E(errors.Invalid, "block hashes are not ordered or previous cfilters are missing") - } // Record cfilter for this block - err := putRawCFilter(ns, blockHash[:], valueRawCFilter2(bcf2Key, filters[i].Bytes())) + merkleRoot := extractBlockHeaderMerkleRoot(header) + merkleRootHash, err := chainhash.NewHash(merkleRoot) + if err != nil { + return errors.E(errors.Invalid, errors.Errorf("invalid stored header %v", blockHash)) + } + bcf2Key := blockcf2.Key(merkleRootHash) + err = putRawCFilter(ns, blockHash[:], valueRawCFilter2(bcf2Key, filters[i].Bytes())) if err != nil { return err } } // Mark all main chain cfilters as saved if the last block hash is the main - // chain tip. + // chain tip. Even if this is not the head block, all cfilters may be saved + // at this point. The caller may need to check and set rootHaveCFilters. tip, _ := s.MainChainTip(dbtx) if bytes.Equal(tip[:], blockHashes[len(blockHashes)-1][:]) { err := ns.Put(rootHaveCFilters, []byte{1}) diff --git a/wallet/udb/txmined_test.go b/wallet/udb/txmined_test.go index d3815de2b..cd8ee9d44 100644 --- a/wallet/udb/txmined_test.go +++ b/wallet/udb/txmined_test.go @@ -12,6 +12,8 @@ import ( "decred.org/dcrwallet/v5/wallet/walletdb" "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/crypto/rand" + "github.com/decred/dcrd/dcrutil/v4" + "github.com/decred/dcrd/wire" ) func randomBytes(len int) []byte { @@ -78,3 +80,429 @@ func TestSetBirthState(t *testing.T) { }) } } + +func TestMissingCFiltersHeight(t *testing.T) { + ctx := context.Background() + db, _, s, err := cloneDB(ctx, t, "mgr_watching_only.kv") + if err != nil { + t.Fatal(err) + } + + fakeFilter := [16]byte{} + + g := makeBlockGenerator() + b1H := g.generate(dcrutil.BlockValid) + b2H := g.generate(dcrutil.BlockValid) + b3H := g.generate(dcrutil.BlockValid) + b4H := g.generate(dcrutil.BlockValid) + b5H := g.generate(dcrutil.BlockValid) + headerData := makeHeaderDataSlice(b1H, b2H, b3H, b4H, b5H) + filters := emptyFilters(5) + + err = walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + err = insertMainChainHeaders(s, dbtx, headerData, filters) + if err != nil { + return err + } + // Delete filters for block 3 and 4. + ns := dbtx.ReadWriteBucket(wtxmgrBucketKey) + b3Hash := b3H.BlockHash() + if err := ns.NestedReadWriteBucket(bucketCFilters).Delete(b3Hash[:]); err != nil { + return err + } + b4Hash := b4H.BlockHash() + if err := ns.NestedReadWriteBucket(bucketCFilters).Delete(b4Hash[:]); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + missingNo, from int32 + wantErr bool + do func() + }{{ + name: "ok from 0", + missingNo: 3, + }, { + name: "ok from mid", + from: 4, + missingNo: 4, + }, { + name: "ok from 1 after adding", + from: 1, + do: func() { + if err := walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + ns := dbtx.ReadWriteBucket(wtxmgrBucketKey) + b3Hash := b3H.BlockHash() + err := putRawCFilter(ns, b3Hash[:], fakeFilter[:]) + if err != nil { + return err + } + return nil + }); err != nil { + t.Fatal(err) + } + }, + missingNo: 4, + }, { + name: "error once all filters full", + do: func() { + if err := walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + ns := dbtx.ReadWriteBucket(wtxmgrBucketKey) + b4Hash := b4H.BlockHash() + err := putRawCFilter(ns, b4Hash[:], fakeFilter[:]) + if err != nil { + return err + } + return nil + }); err != nil { + t.Fatal(err) + } + }, + wantErr: true, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var missingNo int32 + if test.do != nil { + test.do() + } + err = walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + var err error + missingNo, err = MissingCFiltersHeight(dbtx, test.from) + return err + }) + if test.wantErr { + if err == nil { + t.Fatal("wanted error but got none") + } + return + } + if err != nil { + t.Fatal(err) + } + if missingNo != test.missingNo { + t.Fatalf("wanted missing number %v but got %v", test.missingNo, missingNo) + } + }) + } +} + +func TestAfterBirthday(t *testing.T) { + t.Parallel() + + baseTime := time.Unix(1000000, 0) + baseHeight := uint32(500) + + tests := []struct { + name string + bs *BirthdayState + header *wire.BlockHeader + wantAfter bool + }{{ + name: "height based - header at birthday height", + bs: &BirthdayState{ + Height: baseHeight, + SetFromTime: false, + }, + header: &wire.BlockHeader{ + Height: baseHeight, + }, + wantAfter: true, + }, { + name: "height based - header after birthday height", + bs: &BirthdayState{ + Height: baseHeight, + SetFromTime: false, + }, + header: &wire.BlockHeader{ + Height: baseHeight + 100, + }, + wantAfter: true, + }, { + name: "height based - header before birthday height", + bs: &BirthdayState{ + Height: baseHeight, + SetFromTime: false, + }, + header: &wire.BlockHeader{ + Height: baseHeight - 1, + }, + wantAfter: false, + }, { + name: "time based - header after birthday time", + bs: &BirthdayState{ + Time: baseTime, + SetFromTime: true, + }, + header: &wire.BlockHeader{ + Timestamp: baseTime.Add(time.Hour), + }, + wantAfter: true, + }, { + name: "time based - header at birthday time", + bs: &BirthdayState{ + Time: baseTime, + SetFromTime: true, + }, + header: &wire.BlockHeader{ + Timestamp: baseTime, + }, + wantAfter: true, + }, { + name: "time based - header before birthday time", + bs: &BirthdayState{ + Time: baseTime, + SetFromTime: true, + }, + header: &wire.BlockHeader{ + Timestamp: baseTime.Add(-time.Hour), + }, + wantAfter: false, + }, { + name: "SetFromHeight true uses height comparison", + bs: &BirthdayState{ + Height: baseHeight, + Time: baseTime, + SetFromHeight: true, + SetFromTime: false, + }, + header: &wire.BlockHeader{ + Height: baseHeight + 1, + Timestamp: baseTime.Add(-time.Hour), // before time but after height + }, + wantAfter: true, + }, { + name: "SetFromTime takes precedence", + bs: &BirthdayState{ + Height: baseHeight, + Time: baseTime, + SetFromHeight: true, + SetFromTime: true, + }, + header: &wire.BlockHeader{ + Height: baseHeight + 100, // after height + Timestamp: baseTime.Add(-time.Hour), // but before time + }, + wantAfter: false, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + got := test.bs.AfterBirthday(test.header) + if got != test.wantAfter { + t.Errorf("AfterBirthday() = %v, want %v", got, test.wantAfter) + } + }) + } +} + +func TestExtendMainChainNilFilter(t *testing.T) { + ctx := context.Background() + db, _, s, err := cloneDB(ctx, t, "mgr_watching_only.kv") + if err != nil { + t.Fatal(err) + } + + g := makeBlockGenerator() + b1H := g.generate(dcrutil.BlockValid) + b2H := g.generate(dcrutil.BlockValid) + b3H := g.generate(dcrutil.BlockValid) + headerData := makeHeaderDataSlice(b1H, b2H) + filters := emptyFilters(2) + + // Insert first two blocks with filters. + err = walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + return insertMainChainHeaders(s, dbtx, headerData, filters) + }) + if err != nil { + t.Fatal(err) + } + + // Insert third block with nil filter. + err = walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + ns := dbtx.ReadWriteBucket(wtxmgrBucketKey) + b3Hash := b3H.BlockHash() + return s.ExtendMainChain(ns, b3H, &b3Hash, nil) + }) + if err != nil { + t.Fatal(err) + } + + // Verify header exists but cfilter does not. + err = walletdb.View(ctx, db, func(dbtx walletdb.ReadTx) error { + ns := dbtx.ReadBucket(wtxmgrBucketKey) + b3Hash := b3H.BlockHash() + header := existsBlockHeader(ns, b3Hash[:]) + if header == nil { + t.Fatal("expected header to exist for block 3") + } + _, _, err := fetchRawCFilter2(ns, b3Hash[:]) + if err == nil { + t.Fatal("expected no cfilter for block 3") + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +func TestSetHaveMainChainCFilters(t *testing.T) { + ctx := context.Background() + db, _, s, err := cloneDB(ctx, t, "mgr_watching_only.kv") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + have bool + missing bool + }{{ + name: "set have false", + have: false, + missing: true, + }, { + name: "set have true", + have: true, + missing: false, + }, { + name: "toggle back to false", + have: false, + missing: true, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + if err := s.SetHaveMainChainCFilters(dbtx, test.have); err != nil { + return err + } + got := s.IsMissingMainChainCFilters(dbtx) + if got != test.missing { + t.Fatalf("IsMissingMainChainCFilters = %v, want %v", got, test.missing) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestInsertMissingCFilters(t *testing.T) { + ctx := context.Background() + db, _, s, err := cloneDB(ctx, t, "mgr_watching_only.kv") + if err != nil { + t.Fatal(err) + } + + g := makeBlockGenerator() + b1H := g.generate(dcrutil.BlockValid) + b2H := g.generate(dcrutil.BlockValid) + b3H := g.generate(dcrutil.BlockValid) + b4H := g.generate(dcrutil.BlockValid) + b5H := g.generate(dcrutil.BlockValid) + + // Insert headers without filters. + err = walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + ns := dbtx.ReadWriteBucket(wtxmgrBucketKey) + headers := []*wire.BlockHeader{b1H, b2H, b3H, b4H, b5H} + for _, h := range headers { + hash := h.BlockHash() + if err := s.ExtendMainChain(ns, h, &hash, nil); err != nil { + return err + } + } + return s.SetHaveMainChainCFilters(dbtx, false) + }) + if err != nil { + t.Fatal(err) + } + + b1Hash := b1H.BlockHash() + b2Hash := b2H.BlockHash() + b3Hash := b3H.BlockHash() + b4Hash := b4H.BlockHash() + b5Hash := b5H.BlockHash() + + t.Run("mismatched lengths", func(t *testing.T) { + err := walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + return s.InsertMissingCFilters(dbtx, + []*chainhash.Hash{&b1Hash}, + emptyFilters(2), + ) + }) + if err == nil { + t.Fatal("expected error for mismatched lengths") + } + }) + + t.Run("empty slices", func(t *testing.T) { + err := walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + return s.InsertMissingCFilters(dbtx, nil, nil) + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("out of order", func(t *testing.T) { + err := walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + return s.InsertMissingCFilters(dbtx, + []*chainhash.Hash{&b2Hash, &b1Hash}, + emptyFilters(2), + ) + }) + if err == nil { + t.Fatal("expected error for out of order blocks") + } + }) + + t.Run("partial insert does not auto-mark", func(t *testing.T) { + err := walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + err := s.InsertMissingCFilters(dbtx, + []*chainhash.Hash{&b1Hash, &b2Hash, &b3Hash}, + emptyFilters(3), + ) + if err != nil { + return err + } + if !s.IsMissingMainChainCFilters(dbtx) { + t.Fatal("expected IsMissingMainChainCFilters=true after partial insert") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("insert to tip auto-marks complete", func(t *testing.T) { + err := walletdb.Update(ctx, db, func(dbtx walletdb.ReadWriteTx) error { + err := s.InsertMissingCFilters(dbtx, + []*chainhash.Hash{&b4Hash, &b5Hash}, + emptyFilters(2), + ) + if err != nil { + return err + } + if s.IsMissingMainChainCFilters(dbtx) { + t.Fatal("expected IsMissingMainChainCFilters=false after inserting to tip") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) +} diff --git a/wallet/wallet.go b/wallet/wallet.go index dbbb9096b..10a7a4c65 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -1292,9 +1292,14 @@ func (w *Wallet) fetchMissingCFilters(ctx context.Context, n NetworkBackend, pro err := walletdb.View(ctx, w.db, func(dbtx walletdb.ReadTx) error { var err error + var fromHeight int32 + birthday := udb.BirthState(dbtx) + if birthday != nil && !birthday.SetFromTime { + fromHeight = int32(birthday.Height) + } missing = w.txStore.IsMissingMainChainCFilters(dbtx) if missing { - height, err = w.txStore.MissingCFiltersHeight(dbtx) + height, err = udb.MissingCFiltersHeight(dbtx, fromHeight) } return err }) @@ -1317,7 +1322,7 @@ func (w *Wallet) fetchMissingCFilters(ctx context.Context, n NetworkBackend, pro } var hashes []chainhash.Hash var get []*chainhash.Hash - var cont bool + var alreadyHave, markHave bool err := walletdb.View(ctx, w.db, func(dbtx walletdb.ReadTx) error { ns := dbtx.ReadBucket(wtxmgrNamespaceKey) var err error @@ -1331,8 +1336,18 @@ func (w *Wallet) fetchMissingCFilters(ctx context.Context, n NetworkBackend, pro } _, _, err = w.txStore.CFilterV2(dbtx, &hash) if err == nil { - height += span - cont = true + // If there is a gap for some reason, continue from the end of the gap. + height, err = udb.MissingCFiltersHeight(dbtx, height) + if err != nil { + if errors.Is(err, errors.NotExist) { + // We have all the filters. + missing = false + markHave = true + return nil + } + return err + } + alreadyHave = true return nil } storage = storage[:cap(storage)] @@ -1354,10 +1369,18 @@ func (w *Wallet) fetchMissingCFilters(ctx context.Context, n NetworkBackend, pro if err != nil { return err } + if markHave { + if err := walletdb.Update(ctx, w.db, func(dbtx walletdb.ReadWriteTx) error { + return w.txStore.SetHaveMainChainCFilters(dbtx, true) + }); err != nil { + return err + } + return nil + } if !missing { return nil } - if cont { + if alreadyHave { continue } @@ -1391,24 +1414,26 @@ func (w *Wallet) fetchMissingCFilters(ctx context.Context, n NetworkBackend, pro } err = walletdb.Update(ctx, w.db, func(dbtx walletdb.ReadWriteTx) error { - _, _, err := w.txStore.CFilterV2(dbtx, get[len(get)-1]) - if err == nil { - cont = true - return nil + if err := w.txStore.InsertMissingCFilters(dbtx, get, filters); err != nil { + return err } - return w.txStore.InsertMissingCFilters(dbtx, get, filters) + missing = w.txStore.IsMissingMainChainCFilters(dbtx) + return nil }) if err != nil { return err } - if cont { - continue + endHeight := height + int32(len(filters)) - 1 + if progress != nil { + progress <- MissingCFilterProgress{BlockHeightStart: height, BlockHeightEnd: endHeight} } + log.Infof("Fetched cfilters for blocks %v-%v", height, endHeight) - if progress != nil { - progress <- MissingCFilterProgress{BlockHeightStart: height, BlockHeightEnd: height + span - 1} + if !missing { + return nil } - log.Infof("Fetched cfilters for blocks %v-%v", height, height+span-1) + + height = endHeight + 1 } } diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go index de5f2eabd..d2aaf7149 100644 --- a/wallet/wallet_test.go +++ b/wallet/wallet_test.go @@ -229,3 +229,264 @@ func TestSetBirthStateAndScan(t *testing.T) { }) } } + +func TestRescanFromHeight(t *testing.T) { + t.Parallel() + ctx := context.Background() + + cfg := basicWalletConfig + w := testWallet(ctx, t, &cfg, nil) + + tg := maketg(t, cfg.Params) + tw := &tw{t, w} + forest := new(SidechainForest) + + for i := 1; i < 10; i++ { + name := fmt.Sprintf("%va", i) + b := tg.nextBlock(name, nil, nil) + mustAddBlockNode(t, forest, b.BlockNode) + t.Logf("Generated block %v name %q", b.Hash, name) + } + b9aHash := tg.blockHashByName("9a") + bestChain := tw.evaluateBestChain(ctx, forest, 9, b9aHash) + tw.chainSwitch(ctx, forest, bestChain) + tw.assertNoBetterChain(ctx, forest) + + b5Time := tg.BlockByName("5a").Header.Timestamp + + tests := []struct { + name string + bs *udb.BirthdayState + }{{ + name: "ok no birthday", + }, { + name: "ok birthday from height", + bs: &udb.BirthdayState{ + SetFromHeight: true, + Height: 5, + }, + }, { + name: "ok birthday from time", + bs: &udb.BirthdayState{ + SetFromTime: true, + Time: b5Time, + }, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.bs != nil { + err := w.SetBirthState(ctx, test.bs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } + err := w.RescanFromHeight(ctx, mockNetwork{}, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestRescanFromHeightMovesBirthday(t *testing.T) { + t.Parallel() + ctx := context.Background() + + cfg := basicWalletConfig + w := testWallet(ctx, t, &cfg, nil) + + tg := maketg(t, cfg.Params) + tw := &tw{t, w} + forest := new(SidechainForest) + + for i := 1; i < 10; i++ { + name := fmt.Sprintf("%va", i) + b := tg.nextBlock(name, nil, nil) + mustAddBlockNode(t, forest, b.BlockNode) + } + b9aHash := tg.blockHashByName("9a") + bestChain := tw.evaluateBestChain(ctx, forest, 9, b9aHash) + tw.chainSwitch(ctx, forest, bestChain) + tw.assertNoBetterChain(ctx, forest) + + // Set birthday at height 7. + err := w.SetBirthState(ctx, &udb.BirthdayState{ + SetFromHeight: true, + Height: 7, + }) + if err != nil { + t.Fatal(err) + } + + // Rescan from height 3 should move birthday back to 3. + err = w.RescanFromHeight(ctx, mockNetwork{}, 3) + if err != nil { + t.Fatal(err) + } + + bs, err := w.BirthState(ctx) + if err != nil { + t.Fatal(err) + } + if bs == nil { + t.Fatal("expected birthday state to be set") + } + if bs.Height > 3 { + t.Fatalf("expected birthday height <= 3, got %v", bs.Height) + } +} + +func TestRescanFromHeightAfterBirthday(t *testing.T) { + t.Parallel() + ctx := context.Background() + + cfg := basicWalletConfig + w := testWallet(ctx, t, &cfg, nil) + + tg := maketg(t, cfg.Params) + tw := &tw{t, w} + forest := new(SidechainForest) + + for i := 1; i < 10; i++ { + name := fmt.Sprintf("%va", i) + b := tg.nextBlock(name, nil, nil) + mustAddBlockNode(t, forest, b.BlockNode) + } + b9aHash := tg.blockHashByName("9a") + bestChain := tw.evaluateBestChain(ctx, forest, 9, b9aHash) + tw.chainSwitch(ctx, forest, bestChain) + tw.assertNoBetterChain(ctx, forest) + + // Set birthday at height 3. + err := w.SetBirthState(ctx, &udb.BirthdayState{ + SetFromHeight: true, + Height: 3, + }) + if err != nil { + t.Fatal(err) + } + + // Rescan from height 5 (after birthday) should not move birthday. + err = w.RescanFromHeight(ctx, mockNetwork{}, 5) + if err != nil { + t.Fatal(err) + } + + bs, err := w.BirthState(ctx) + if err != nil { + t.Fatal(err) + } + if bs == nil { + t.Fatal("expected birthday state to be set") + } + if bs.Height != 3 { + t.Fatalf("expected birthday height 3, got %v", bs.Height) + } +} + +func TestRescanFromHeightMovesBirthdayFromTime(t *testing.T) { + t.Parallel() + ctx := context.Background() + + cfg := basicWalletConfig + w := testWallet(ctx, t, &cfg, nil) + + tg := maketg(t, cfg.Params) + tw := &tw{t, w} + forest := new(SidechainForest) + + for i := 1; i < 10; i++ { + name := fmt.Sprintf("%va", i) + b := tg.nextBlock(name, nil, nil) + mustAddBlockNode(t, forest, b.BlockNode) + } + b9aHash := tg.blockHashByName("9a") + bestChain := tw.evaluateBestChain(ctx, forest, 9, b9aHash) + tw.chainSwitch(ctx, forest, bestChain) + tw.assertNoBetterChain(ctx, forest) + + // Set a time-based birthday using block 7's timestamp. + b7Time := tg.BlockByName("7a").Header.Timestamp + err := w.SetBirthState(ctx, &udb.BirthdayState{ + SetFromTime: true, + Time: b7Time, + }) + if err != nil { + t.Fatal(err) + } + + // Rescan from height 3 (before the time-based birthday) should + // move the birthday back to 3. + err = w.RescanFromHeight(ctx, mockNetwork{}, 3) + if err != nil { + t.Fatal(err) + } + + bs, err := w.BirthState(ctx) + if err != nil { + t.Fatal(err) + } + if bs == nil { + t.Fatal("expected birthday state to be set") + } + if bs.Height > 3 { + t.Fatalf("expected birthday height <= 3, got %v", bs.Height) + } +} + +func TestRescanFromHeightAfterBirthdayFromTime(t *testing.T) { + t.Parallel() + ctx := context.Background() + + cfg := basicWalletConfig + w := testWallet(ctx, t, &cfg, nil) + + tg := maketg(t, cfg.Params) + tw := &tw{t, w} + forest := new(SidechainForest) + + for i := 1; i < 10; i++ { + name := fmt.Sprintf("%va", i) + b := tg.nextBlock(name, nil, nil) + mustAddBlockNode(t, forest, b.BlockNode) + } + b9aHash := tg.blockHashByName("9a") + bestChain := tw.evaluateBestChain(ctx, forest, 9, b9aHash) + tw.chainSwitch(ctx, forest, bestChain) + tw.assertNoBetterChain(ctx, forest) + + // Set a time-based birthday using block 3's timestamp. + b3Time := tg.BlockByName("3a").Header.Timestamp + err := w.SetBirthState(ctx, &udb.BirthdayState{ + SetFromTime: true, + Time: b3Time, + }) + if err != nil { + t.Fatal(err) + } + + // Rescan from height 5 (after the time-based birthday) should + // not move the birthday. + err = w.RescanFromHeight(ctx, mockNetwork{}, 5) + if err != nil { + t.Fatal(err) + } + + bs, err := w.BirthState(ctx) + if err != nil { + t.Fatal(err) + } + if bs == nil { + t.Fatal("expected birthday state to be set") + } + if bs.SetFromTime { + // Birthday should have been resolved during rescan setup, + // but if it wasn't touched it stays as time-based. Either + // way, it should not have been moved earlier than block 3. + return + } + if bs.Height < 3 { + t.Fatalf("expected birthday height >= 3, got %v", bs.Height) + } +}