diff --git a/config.go b/config.go index 7119df13c..7b5a47f5b 100644 --- a/config.go +++ b/config.go @@ -41,6 +41,7 @@ const ( const ( defaultCAFilename = "dcrd.cert" defaultConfigFilename = "dcrwallet.conf" + defaultDBDriver = "bdb" defaultLogLevel = "info" defaultLogDirname = "logs" defaultLogFilename = "dcrwallet.log" @@ -91,6 +92,7 @@ type config struct { AppDataDir *cfgutil.ExplicitString `short:"A" long:"appdata" description:"Application data directory for wallet config, databases and logs"` TestNet bool `long:"testnet" description:"Use the test network"` SimNet bool `long:"simnet" description:"Use the simulation test network"` + DBDriver string `long:"dbdriver" description:"Name of the database driver for the wallet db"` NoInitialLoad bool `long:"noinitialload" description:"Defer wallet creation/opening on startup and enable loading wallets over RPC"` DebugLevel string `short:"d" long:"debuglevel" description:"Logging level {trace, debug, info, warn, error, critical}"` LogDir *cfgutil.ExplicitString `long:"logdir" description:"Directory to log output."` @@ -346,6 +348,7 @@ func loadConfig(ctx context.Context) (*config, []string, error) { DebugLevel: defaultLogLevel, ConfigFile: cfgutil.NewExplicitString(defaultConfigFile), AppDataDir: cfgutil.NewExplicitString(defaultAppDataDir), + DBDriver: defaultDBDriver, LogDir: cfgutil.NewExplicitString(defaultLogDir), LogSize: defaultLogSize, WalletPass: wallet.InsecurePubPassphrase, diff --git a/dcrwallet.go b/dcrwallet.go index 2bb680448..18b28adc4 100644 --- a/dcrwallet.go +++ b/dcrwallet.go @@ -165,7 +165,7 @@ func run(ctx context.Context) error { // wallet. Otherwise, loading is deferred so it can be performed over RPC. dbDir := networkDir(cfg.AppDataDir.Value, activeNet.Params) - loader := ldr.NewLoader(activeNet.Params, dbDir, cfg.EnableVoting, + loader := ldr.NewLoader(activeNet.Params, dbDir, cfg.DBDriver, cfg.EnableVoting, cfg.GapLimit, cfg.WatchLast, cfg.AllowHighFees, cfg.RelayFee.Amount, cfg.VSPOpts.MaxFee.Amount, cfg.AccountGapLimit, cfg.DisableCoinTypeUpgrades, cfg.MixingEnabled, cfg.ManualTickets, diff --git a/go.mod b/go.mod index 889f5bbff..cb5162ecc 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/decred/slog v1.2.0 github.com/decred/vspd/client/v4 v4.0.1 github.com/decred/vspd/types/v3 v3.0.0 + github.com/dgraph-io/badger/v4 v4.8.0 github.com/gorilla/websocket v1.5.1 github.com/jessevdk/go-flags v1.5.0 github.com/jrick/bitset v1.0.0 @@ -45,13 +46,24 @@ require ( require ( github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/companyzero/sntrup4591761 v0.0.0-20220309191932-9e0f3af2f07a // indirect github.com/dchest/siphash v1.2.3 // indirect github.com/decred/base58 v1.0.6 // indirect github.com/decred/dcrd/container/lru v1.0.0 // indirect github.com/decred/dcrd/database/v3 v3.0.3 // indirect github.com/decred/dcrd/dcrec/edwards/v2 v2.0.4 // indirect + github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/flatbuffers v25.2.10+incompatible // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect diff --git a/go.sum b/go.sum index 333126a0f..725c4f059 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ decred.org/cspp/v2 v2.4.0 h1:whb0YW+UELHJS/UfT5MBXSJXrKUVw5omhgKNhjzYix4= decred.org/cspp/v2 v2.4.0/go.mod h1:9nO3bfvCheOPIFZw5f6sRQ42CjBFB5RKSaJ9Iq6G4MA= github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 h1:w1UutsfOrms1J05zt7ISrnJIXKzwaspym5BTKGx93EI= github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412/go.mod h1:WPjqKcmVOxf0XSf3YxCJs6N6AOSrOx3obionmG7T0y0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/companyzero/sntrup4591761 v0.0.0-20220309191932-9e0f3af2f07a h1:clYxJ3Os0EQUKDDVU8M0oipllX0EkuFNBfhVQuIfyF0= github.com/companyzero/sntrup4591761 v0.0.0-20220309191932-9e0f3af2f07a/go.mod h1:z/9Ck1EDixEbBbZ2KH2qNHekEmDLTOZ+FyoIPWWSVOI= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -68,6 +70,15 @@ github.com/decred/vspd/client/v4 v4.0.1 h1:eoFWCoaqEMLBODRQrVABGcpFrFdOSPLiMWpPO github.com/decred/vspd/client/v4 v4.0.1/go.mod h1:jhqu4KGGOskQcPVZ3XZLVZ1Wgkc9GQo+oEipr3gGODg= github.com/decred/vspd/types/v3 v3.0.0 h1:jHlQIpp6aCjIcFs8WE3AaVCJe1kgepNTq+nkBKAyQxk= github.com/decred/vspd/types/v3 v3.0.0/go.mod h1:hwifRZu6tpkbhSg2jZCUwuPaO/oETgbSCWCYJd4XepY= +github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs= +github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w= +github.com/dgraph-io/ristretto/v2 v2.2.0 h1:bkY3XzJcXoMuELV8F+vS8kzNgicwQFAaGINAEJdWGOM= +github.com/dgraph-io/ristretto/v2 v2.2.0/go.mod h1:RZrm63UmcBAaYWC1DotLYBmTvgkrs0+XhBd7Npn7/zI= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -76,6 +87,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= +github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -90,12 +103,14 @@ github.com/jrick/logrotate v1.0.0 h1:lQ1bL/n9mBNeIXoTUoYRlK4dHuNJVofX9oWqBtPnSzI github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/jrick/wsrpc/v2 v2.3.8 h1:9vfM8o9g00HXQb/3D6+Y9Cy1uybjD7K1272vtdXXBps= github.com/jrick/wsrpc/v2 v2.3.8/go.mod h1:Ha6uT2AOjHkaiBWMjWfWUFvjDrppbfy0ghLKxPPYmY4= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= go.etcd.io/bbolt v1.3.12 h1:UAxZAIuJqzFwByP19gZC3zd5robK3FOangrGS+Fdczg= diff --git a/internal/loader/loader.go b/internal/loader/loader.go index 5df2f5069..8349d9b1e 100644 --- a/internal/loader/loader.go +++ b/internal/loader/loader.go @@ -13,14 +13,14 @@ import ( "decred.org/dcrwallet/v5/errors" "decred.org/dcrwallet/v5/wallet" - _ "decred.org/dcrwallet/v5/wallet/drivers/bdb" // driver loaded during init + _ "decred.org/dcrwallet/v5/wallet/drivers/badgerdb" // driver loaded during init + _ "decred.org/dcrwallet/v5/wallet/drivers/bdb" "github.com/decred/dcrd/chaincfg/v3" "github.com/decred/dcrd/dcrutil/v4" ) const ( walletDbName = "wallet.db" - driver = "bdb" ) // Loader implements the creating of new and opening of existing wallets, while @@ -34,6 +34,7 @@ type Loader struct { callbacks []func(*wallet.Wallet) chainParams *chaincfg.Params dbDirPath string + dbDriver string wallet *wallet.Wallet db wallet.DB @@ -54,13 +55,15 @@ type Loader struct { } // NewLoader constructs a Loader. -func NewLoader(chainParams *chaincfg.Params, dbDirPath string, votingEnabled bool, gapLimit uint32, +func NewLoader(chainParams *chaincfg.Params, dbDirPath string, dbDriver string, votingEnabled bool, gapLimit uint32, watchLast uint32, allowHighFees bool, relayFee dcrutil.Amount, vspMaxFee dcrutil.Amount, accountGapLimit int, - disableCoinTypeUpgrades bool, mixingEnabled bool, manualTickets bool, mixSplitLimit int, dialer wallet.DialFunc) *Loader { + disableCoinTypeUpgrades bool, mixingEnabled bool, manualTickets bool, mixSplitLimit int, + dialer wallet.DialFunc) *Loader { return &Loader{ chainParams: chainParams, dbDirPath: dbDirPath, + dbDriver: dbDriver, votingEnabled: votingEnabled, gapLimit: gapLimit, watchLast: watchLast, @@ -154,7 +157,7 @@ func (l *Loader) CreateWatchingOnlyWallet(ctx context.Context, extendedPubKey st if err != nil { return nil, errors.E(op, err) } - db, err := wallet.CreateDB(driver, dbPath) + db, err := wallet.CreateDB(l.dbDriver, dbPath) if err != nil { return nil, errors.E(op, err) } @@ -244,7 +247,7 @@ func (l *Loader) CreateNewWallet(ctx context.Context, pubPassphrase, privPassphr if err != nil { return nil, errors.E(op, err) } - db, err := wallet.CreateDB(driver, dbPath) + db, err := wallet.CreateDB(l.dbDriver, dbPath) if err != nil { return nil, errors.E(op, err) } @@ -298,7 +301,7 @@ func (l *Loader) OpenExistingWallet(ctx context.Context, pubPassphrase []byte) ( // Open the database using the boltdb backend. dbPath := filepath.Join(l.dbDirPath, walletDbName) l.mu.Unlock() - db, err := wallet.OpenDB(driver, dbPath) + db, err := wallet.OpenDB(l.dbDriver, dbPath) l.mu.Lock() if err != nil { diff --git a/wallet/drivers/badgerdb/driver.go b/wallet/drivers/badgerdb/driver.go new file mode 100644 index 000000000..bca1f785e --- /dev/null +++ b/wallet/drivers/badgerdb/driver.go @@ -0,0 +1,16 @@ +// Copyright (c) 2018-2025 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +// Package badgerdb registers the badgerdb driver at init time. Importing +// badgerdb allows the wallet.OpenDB and wallet.CreateDB functions to be +// called with the following arguments: +// +// var directory string +// db, err := wallet.CreateDB("badgerdb", directory) +// if err != nil { /* handle error */ } +// db, err = wallet.OpenDB("badgerdb", directory) +// if err != nil { /* handle error */ } +package badgerdb + +import _ "decred.org/dcrwallet/v5/wallet/internal/badgerdb" // Register badgerdb driver during init diff --git a/wallet/internal/badgerdb/db.go b/wallet/internal/badgerdb/db.go new file mode 100644 index 000000000..2fb6809ce --- /dev/null +++ b/wallet/internal/badgerdb/db.go @@ -0,0 +1,594 @@ +// Copyright (c) 2014 The btcsuite developers +// Copyright (c) 2015-2025 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package badgerdb + +import ( + "bytes" + "fmt" + "os" + + "decred.org/dcrwallet/v5/errors" + "decred.org/dcrwallet/v5/wallet/walletdb" + "github.com/dgraph-io/badger/v4" + "github.com/dgraph-io/badger/v4/options" +) + +const ( + metaBucket byte = 'b' +) + +// unsafely return value of an item. panics on value errors. +func unsafeValue(item *badger.Item) []byte { + var v []byte + // XXX: copy for testing + v, err := item.ValueCopy(nil) + // err := item.Value(func(value []byte) error { + // v = value + // return nil + // }) + if err != nil { + panic(fmt.Sprintf("item.Value: %v", err)) // XXX + } + return v +} + +// unsafely return stripped key and value for an item. +// panics on value errors. +func itemKV(stripPrefix []byte, item *badger.Item) (k, v []byte) { + k, v = item.Key(), unsafeValue(item) + if !bytes.HasPrefix(k, stripPrefix) { + return nil, nil + } + k = k[len(stripPrefix):] + // XXX: copy for testing + k = append(make([]byte, 0, len(k)), k...) + return k, v +} + +// creates the prefix for a top level bucket. +func topLevelPrefix(key []byte) []byte { + prefix := make([]byte, 0, len(key)+3) + prefix = append(prefix, '{') + prefix = append(prefix, key...) + prefix = append(prefix, "}/"...) + return prefix +} + +// append key to the bucket prefix, reusing the alloc when possible. +func reusePrefixedKey(prefix *[]byte, key []byte) []byte { + appendedKey := append(*prefix, key...) + *prefix = appendedKey[:len(*prefix)] + return appendedKey +} + +// append key to the bucket prefix, always creating a new allocation to do so. +func allocPrefixedKey(prefix []byte, key []byte) []byte { + return append(prefix[:len(prefix):len(prefix)], key...) +} + +// creates a new bucket prefix from a parent bucket prefix and the child +// bucket name. +func nestedBucketPrefix(parentPrefix, child []byte) []byte { + prefix := make([]byte, 0, len(parentPrefix)+1+len(child)) + prefix = append(prefix, parentPrefix[:len(parentPrefix)-2]...) + prefix = append(prefix, ',') + prefix = append(prefix, child...) + prefix = append(prefix, "}/"...) + return prefix +} + +// strips bucket prefix from a key. +// panics if the key does not begin with the prefix. +func strippedKey(prefix, key []byte) []byte { + if !bytes.HasPrefix(key, prefix) { + panic(fmt.Sprintf("key %q does not have prefix %q", key, prefix)) // XXX + } + return key[len(prefix):] +} + +// convertErr wraps a driver-specific error with an error code. +func convertErr(err error) error { + if err == nil { + return nil + } + var kind errors.Kind + switch err { + case badger.ErrBannedKey, badger.ErrBlockedWrites, badger.ErrDBClosed, badger.ErrDiscardedTxn, badger.ErrEmptyKey, + badger.ErrEncryptionKeyMismatch, badger.ErrGCInMemoryMode, badger.ErrInvalidDataKeyID, badger.ErrInvalidDump, + badger.ErrInvalidEncryptionKey, badger.ErrInvalidKey, badger.ErrInvalidRequest, badger.ErrManagedTxn, + badger.ErrNamespaceMode, badger.ErrNilCallback, badger.ErrNoRewrite, badger.ErrPlan9NotSupported, + badger.ErrReadOnlyTxn, badger.ErrRejected, badger.ErrThresholdZero, badger.ErrTruncateNeeded, + badger.ErrTxnTooBig, badger.ErrValueLogSize, badger.ErrWindowsNotSupported, badger.ErrZeroBandwidth: + kind = errors.Invalid + case badger.ErrKeyNotFound: + kind = errors.NotExist + case badger.ErrConflict: + kind = errors.IO + } + return errors.E(kind, err) +} + +// transaction represents a database transaction. It can either by read-only or +// read-write and implements the walletdb Tx interfaces. +type transaction struct { + txn *badger.Txn + closed bool +} + +func (tx *transaction) ReadBucket(key []byte) walletdb.ReadBucket { + return tx.ReadWriteBucket(key) +} + +func (tx *transaction) ReadWriteBucket(key []byte) walletdb.ReadWriteBucket { + // XXX: close enough + b, _ := tx.CreateTopLevelBucket(key) + return b +} + +func (tx *transaction) CreateTopLevelBucket(key []byte) (walletdb.ReadWriteBucket, error) { + prefix := topLevelPrefix(key) + b := &bucket{ + prefix: prefix, + txn: tx.txn, + } + return b, nil +} + +func (tx *transaction) DeleteTopLevelBucket(key []byte) error { + // XXX: this is not removing nested buckets. + prefix := topLevelPrefix(key) + opts := badger.DefaultIteratorOptions + opts.Prefix = prefix + iter := tx.txn.NewIterator(opts) + defer iter.Close() + for iter.Rewind(); iter.ValidForPrefix(prefix); iter.Next() { + item := iter.Item() + key := item.Key() + err := tx.txn.Delete(key) + if err != nil { + return convertErr(err) + } + } + return nil +} + +// Commit commits all changes that have been made through the root bucket and +// all of its sub-buckets to persistent storage. +// +// This function is part of the walletdb.Tx interface implementation. +func (tx *transaction) Commit() error { + if tx.closed { + return convertErr(badger.ErrDiscardedTxn) + } + err := tx.txn.Commit() + tx.closed = true + return convertErr(err) +} + +// Rollback undoes all changes that have been made to the root bucket and all of +// its sub-buckets. +// +// This function is part of the walletdb.Tx interface implementation. +func (tx *transaction) Rollback() error { + tx.txn.Discard() + if tx.closed { + return convertErr(badger.ErrDiscardedTxn) + } + tx.closed = true + return nil +} + +// bucket is an internal type used to represent a collection of key/value pairs +// and implements the walletdb Bucket interfaces. +type bucket struct { + // badger does not implement anything similar to bbolt's buckets, so + // this is done via a prefix on all keys. There is no built in + // namespace separation between nested buckets and keys in the outer + // bucket with the same prefix. + // + // To work around this, we use a "{key1,key2,...}/" prefix. An empty + // entry with special bucket metadata marks the existence of the + // bucket. Nested buckets can have up to two metadata entries: one + // for the nested bucket's full prefix, and another in the parent + // bucket (if any) to signal the bucket's existence during cursor + // iteration. + prefix []byte + txn *badger.Txn +} + +// Enforce bucket implements the walletdb Bucket interfaces. +var _ walletdb.ReadWriteBucket = (*bucket)(nil) + +// NestedReadWriteBucket retrieves a nested bucket with the given key. Returns +// nil if the bucket does not exist. +// +// This function is part of the walletdb.ReadWriteBucket interface implementation. +func (b *bucket) NestedReadWriteBucket(key []byte) walletdb.ReadWriteBucket { + prefix := nestedBucketPrefix(b.prefix, key) + item, err := b.txn.Get(prefix) + if errors.Is(err, badger.ErrKeyNotFound) { + return nil + } + if err == nil && item.UserMeta() != metaBucket { + return nil + } + nestedBucket := &bucket{ + prefix: prefix, + txn: b.txn, + } + return nestedBucket +} + +func (b *bucket) NestedReadBucket(key []byte) walletdb.ReadBucket { + return b.NestedReadWriteBucket(key) +} + +// CreateBucket creates and returns a new nested bucket with the given key. +// Errors with code Exist if the bucket already exists, and Invalid if the key +// is empty or otherwise invalid for the driver. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) CreateBucket(key []byte) (walletdb.ReadWriteBucket, error) { + if len(key) == 0 { + return nil, convertErr(badger.ErrEmptyKey) + } + prefix := nestedBucketPrefix(b.prefix, key) + _, err := b.txn.Get(prefix) + if !errors.Is(err, badger.ErrKeyNotFound) { + return nil, errors.E(errors.Exist, "CreateBucket: bucket exists") + } + e := badger.NewEntry(prefix, nil) + e.UserMeta = metaBucket + err = b.txn.SetEntry(e) + if err != nil { + return nil, convertErr(err) + } + nestedBucket := &bucket{ + prefix: prefix, + txn: b.txn, + } + return nestedBucket, nil +} + +// CreateBucketIfNotExists creates and returns a new nested bucket with the +// given key if it does not already exist. Errors with code Invalid if the key +// is empty or otherwise invalid for the driver. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) CreateBucketIfNotExists(key []byte) (walletdb.ReadWriteBucket, error) { + if len(key) == 0 { + return nil, convertErr(badger.ErrEmptyKey) + } + prefix := nestedBucketPrefix(b.prefix, key) + _, err := b.txn.Get(prefix) + if errors.Is(err, badger.ErrKeyNotFound) { + e := badger.NewEntry(prefix, nil) + e.UserMeta = metaBucket + err := b.txn.SetEntry(e) + if err != nil { + return nil, convertErr(err) + } + } + nestedBucket := &bucket{ + prefix: prefix, + txn: b.txn, + } + return nestedBucket, nil +} + +// DeleteNestedBucket removes a nested bucket with the given key. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) DeleteNestedBucket(key []byte) error { + // XXX: this is not removing buckets nested within this nested bucket. + if len(key) == 0 { + return convertErr(badger.ErrEmptyKey) + } + prefix := nestedBucketPrefix(b.prefix, key) + opts := badger.DefaultIteratorOptions + opts.Prefix = prefix + iter := b.txn.NewIterator(opts) + defer iter.Close() + var deleted bool + for iter.Rewind(); iter.ValidForPrefix(prefix); iter.Next() { + item := iter.Item() + key := item.Key() + err := b.txn.Delete(key) + if err != nil { + return convertErr(err) + } + deleted = true + } + if !deleted { + return errors.E(errors.NotExist, "DeletedNestedBucket: nested bucket does not exist") + } + return nil +} + +// ForEach invokes the passed function with every key/value pair in the bucket. +// This includes nested buckets, in which case the value is nil, but it does not +// include the key/value pairs within those nested buckets. +// XXX: above is very much untrue in the current impl. nested buckets will not be +// iterated over at all. +// +// NOTE: The values returned by this function are only valid during a +// transaction. Attempting to access them after a transaction has ended will +// likely result in an access violation. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) ForEach(fn func(k, v []byte) error) error { + opts := badger.DefaultIteratorOptions + opts.Prefix = b.prefix + iter := b.txn.NewIterator(opts) + defer iter.Close() + for iter.Rewind(); iter.ValidForPrefix(b.prefix); iter.Next() { + item := iter.Item() + k := item.Key() + // Ignore metadata for the (non-child) bucket. + if item.UserMeta() == metaBucket && len(k) == len(b.prefix) { + continue + } + k = strippedKey(b.prefix, k) + err := item.Value(func(v []byte) error { + return fn(k, v) + }) + if err != nil { + return convertErr(err) + } + } + return nil +} + +// Put saves the specified key/value pair to the bucket. Keys that do not +// already exist are added and keys that already exist are overwritten. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) Put(key, value []byte) error { + if len(key) == 0 { + return convertErr(badger.ErrEmptyKey) + } + err := b.txn.Set(allocPrefixedKey(b.prefix, key), value) + return convertErr(err) +} + +// Get returns the value for the given key. Returns nil if the key does +// not exist in this bucket (or nested buckets). +// +// NOTE: The value returned by this function is only valid during a +// transaction. Attempting to access it after a transaction has ended +// will likely result in an access violation. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) Get(key []byte) []byte { + item, err := b.txn.Get(reusePrefixedKey(&b.prefix, key)) + if err != nil { + if errors.Is(err, badger.ErrKeyNotFound) { + return nil + } + panic(fmt.Sprintf("badger.Txn.Get: %v", err)) // XXX + } + return unsafeValue(item) +} + +// Delete removes the specified key from the bucket. Deleting a key that does +// not exist does not return an error. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) Delete(key []byte) error { + if len(key) == 0 { + return convertErr(badger.ErrEmptyKey) + } + err := b.txn.Delete(allocPrefixedKey(b.prefix, key)) + return convertErr(err) +} + +// KeyN returns the number of keys and value pairs inside a bucket. +// +// This function is part of the walletdb.ReadBucket interface implementation. +func (b *bucket) KeyN() int { + opts := badger.DefaultIteratorOptions + opts.Prefix = b.prefix + iter := b.txn.NewIterator(opts) + defer iter.Close() + var count int + for iter.Rewind(); iter.ValidForPrefix(b.prefix); iter.Next() { + item := iter.Item() + // Skip the metadata entry for the bucket prefix. + if item.UserMeta() == metaBucket { + continue + } + count++ + } + return count +} + +func (b *bucket) ReadCursor() walletdb.ReadCursor { + return b.ReadWriteCursor() +} + +func (b *bucket) ReverseReadCursor() walletdb.ReadCursor { + return b.ReverseReadWriteCursor() +} + +// ReadWriteCursor returns a new cursor, allowing for iteration over the bucket's +// key/value pairs and nested buckets in forward order. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) ReadWriteCursor() walletdb.ReadWriteCursor { + opts := badger.DefaultIteratorOptions + opts.Prefix = b.prefix + iter := b.txn.NewIterator(opts) + c := &cursor{ + prefix: b.prefix, + txn: b.txn, + iter: iter, + } + return c +} + +// ReadWriteCursor returns a new cursor, allowing for iteration over the bucket's +// key/value pairs and nested buckets in reverse order. +// +// This function is part of the walletdb.Bucket interface implementation. +func (b *bucket) ReverseReadWriteCursor() walletdb.ReadWriteCursor { + opts := badger.DefaultIteratorOptions + opts.Prefix = b.prefix + opts.Reverse = true + iter := b.txn.NewIterator(opts) + c := &cursor{ + prefix: b.prefix, + txn: b.txn, + iter: iter, + } + return c +} + +// cursor represents a cursor over key/value pairs and nested buckets of a +// bucket. +// +// Note that open cursors are not tracked on bucket changes and any +// modifications to the bucket, with the exception of cursor.Delete, invalidate +// the cursor. After invalidation, the cursor must be repositioned, or the keys +// and values returned may be unpredictable. +type cursor struct { + prefix []byte + txn *badger.Txn + iter *badger.Iterator +} + +// Delete removes the current key/value pair the cursor is at without +// invalidating the cursor. +// +// This function is part of the walletdb.Cursor interface implementation. +func (c *cursor) Delete() error { + key := c.iter.Item().Key() + err := c.txn.Delete(key) + return convertErr(err) +} + +// First positions the cursor at the first key/value pair and returns the pair. +// +// This function is part of the walletdb.Cursor interface implementation. +func (c *cursor) First() (key, value []byte) { + c.iter.Rewind() + if !c.iter.ValidForPrefix(c.prefix) { + return nil, nil + } + // Skip the metadata entry for the bucket prefix. + if item := c.iter.Item(); item.UserMeta() == metaBucket && len(item.Key()) == len(c.prefix) { + c.iter.Next() + } + if !c.iter.ValidForPrefix(c.prefix) { + return nil, nil + } + item := c.iter.Item() + return itemKV(c.prefix, item) +} + +// Next moves the cursor one key/value pair forward and returns the new pair. +// +// This function is part of the walletdb.Cursor interface implementation. +func (c *cursor) Next() (key, value []byte) { + c.iter.Next() + if !c.iter.ValidForPrefix(c.prefix) { + return nil, nil + } + // Skip the metadata entry for the bucket prefix. + if item := c.iter.Item(); item.UserMeta() == metaBucket && len(item.Key()) == len(c.prefix) { + c.iter.Next() + } + if !c.iter.ValidForPrefix(c.prefix) { + return nil, nil + } + item := c.iter.Item() + return itemKV(c.prefix, item) +} + +// Seek positions the cursor at the passed seek key. If the key does not exist, +// the cursor is moved to the next key after seek. Returns the new pair. +// +// This function is part of the walletdb.Cursor interface implementation. +func (c *cursor) Seek(seek []byte) (key, value []byte) { + c.iter.Seek(reusePrefixedKey(&c.prefix, seek)) + if !c.iter.ValidForPrefix(c.prefix) { + return nil, nil + } + // Skip the metadata entry for the bucket prefix. + if item := c.iter.Item(); item.UserMeta() == metaBucket && len(item.Key()) == len(c.prefix) { + c.iter.Next() + } + if !c.iter.ValidForPrefix(c.prefix) { + return nil, nil + } + item := c.iter.Item() + k, v := itemKV(c.prefix, item) + return k, v +} + +// Closes the cursor +// +// This function is part of the walletdb.Cursor interface implementation. +func (c *cursor) Close() { + c.iter.Close() +} + +// db represents a collection of namespaces which are persisted and implements +// the walletdb.Db interface. All database access is performed through +// transactions which are obtained through the specific Namespace. +type db struct { + db *badger.DB +} + +// Enforce db implements the walletdb.Db interface. +var _ walletdb.DB = (*db)(nil) + +func (db *db) beginTx(writable bool) (*transaction, error) { + txn := db.db.NewTransaction(writable) + if db.db.IsClosed() { + return nil, convertErr(badger.ErrDBClosed) + } + return &transaction{txn: txn}, nil +} + +func (db *db) BeginReadTx() (walletdb.ReadTx, error) { + return db.beginTx(false) +} + +func (db *db) BeginReadWriteTx() (walletdb.ReadWriteTx, error) { + return db.beginTx(true) +} + +// Close cleanly shuts down the database and syncs all data. +// +// This function is part of the walletdb.Db interface implementation. +func (db *db) Close() error { + return convertErr(db.db.Close()) +} + +// dirExists returns whether the file with name exists and is a directory. +func dirExists(name string) bool { + if stat, err := os.Stat(name); err == nil { + return stat.IsDir() + } + return false +} + +// openDB opens the database at the provided path. +func openDB(dbPath string, create bool) (walletdb.DB, error) { + if !create && !dirExists(dbPath) { + return nil, errors.E(errors.NotExist, "missing database directory") + } + + opts := badger.DefaultOptions(dbPath) + opts.ChecksumVerificationMode = options.OnTableAndBlockRead + opts.VerifyValueChecksum = true + opts.Logger = nil // XXX + badgerDB, err := badger.Open(opts) + if err != nil { + return nil, convertErr(err) + } + return &db{badgerDB}, nil +} diff --git a/wallet/internal/badgerdb/doc.go b/wallet/internal/badgerdb/doc.go new file mode 100644 index 000000000..17bac31e7 --- /dev/null +++ b/wallet/internal/badgerdb/doc.go @@ -0,0 +1,26 @@ +// Copyright (c) 2014 The btcsuite developers +// Copyright (c) 2015-2025 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +/* +Package badgerdb implements an instance of walletdb that uses badger for the +backing datastore. + +# Usage + +This package is only a driver to the walletdb package and provides the database +type of "badgerdb". The only parameter the Open and Create functions take is the +database path as a string: + + db, err := walletdb.Open("badgerdb", "path/to/database.db") + if err != nil { + // Handle error + } + + db, err := walletdb.Create("badgerdb", "path/to/database.db") + if err != nil { + // Handle error + } +*/ +package badgerdb diff --git a/wallet/internal/badgerdb/driver.go b/wallet/internal/badgerdb/driver.go new file mode 100644 index 000000000..419ca68b5 --- /dev/null +++ b/wallet/internal/badgerdb/driver.go @@ -0,0 +1,68 @@ +// Copyright (c) 2014 The btcsuite developers +// Copyright (c) 2015-2025 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package badgerdb + +import ( + "fmt" + + "decred.org/dcrwallet/v5/errors" + "decred.org/dcrwallet/v5/wallet/walletdb" +) + +const ( + dbType = "badgerdb" +) + +// parseArgs parses the arguments from the walletdb Open/Create methods. +func parseArgs(funcName string, args ...any) (string, error) { + if len(args) != 1 { + return "", errors.Errorf("invalid arguments to %s.%s -- "+ + "expected database path", dbType, funcName) + } + + dbPath, ok := args[0].(string) + if !ok { + return "", errors.Errorf("first argument to %s.%s is invalid -- "+ + "expected database path string", dbType, funcName) + } + + return dbPath, nil +} + +// openDBDriver is the callback provided during driver registration that opens +// an existing database for use. +func openDBDriver(args ...any) (walletdb.DB, error) { + dbPath, err := parseArgs("Open", args...) + if err != nil { + return nil, err + } + + return openDB(dbPath, false) +} + +// createDBDriver is the callback provided during driver registration that +// creates, initializes, and opens a database for use. +func createDBDriver(args ...any) (walletdb.DB, error) { + dbPath, err := parseArgs("Create", args...) + if err != nil { + return nil, err + } + + return openDB(dbPath, true) +} + +func init() { + // Register the driver. + driver := walletdb.Driver{ + DbType: dbType, + Create: createDBDriver, + Open: openDBDriver, + } + if err := walletdb.RegisterDriver(driver); err != nil { + panic(fmt.Sprintf("Failed to register database driver '%s': %v", + dbType, err)) + } +} diff --git a/wallet/internal/badgerdb/driver_test.go b/wallet/internal/badgerdb/driver_test.go new file mode 100644 index 000000000..5f650debf --- /dev/null +++ b/wallet/internal/badgerdb/driver_test.go @@ -0,0 +1,159 @@ +// Copyright (c) 2014 The btcsuite developers +// Copyright (c) 2015-2025 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +// Test must be updated for API changes. +package badgerdb_test + +import ( + "bytes" + "context" + "os" + "testing" + + "decred.org/dcrwallet/v5/errors" + _ "decred.org/dcrwallet/v5/wallet/internal/badgerdb" + "decred.org/dcrwallet/v5/wallet/walletdb" +) + +// dbType is the database type name for this driver. +const dbType = "badgerdb" + +// TestCreateOpenFail ensures that errors related to creating and opening a +// database are handled properly. +func TestCreateOpenFail(t *testing.T) { + // Ensure that attempting to open a database that doesn't exist returns + // the expected error. + if _, err := walletdb.Open(dbType, "noexist.db"); !errors.Is(err, errors.NotExist) { + t.Errorf("Open: unexpected error: %v", err) + return + } + + // Ensure that attempting to open a database with the wrong number of + // parameters returns the expected error. + wantErr := errors.Errorf("invalid arguments to %s.Open -- expected "+ + "database path", dbType) + if _, err := walletdb.Open(dbType, 1, 2, 3); err.Error() != wantErr.Error() { + t.Errorf("Open: did not receive expected error - got %v, "+ + "want %v", err, wantErr) + return + } + + // Ensure that attempting to open a database with an invalid type for + // the first parameter returns the expected error. + wantErr = errors.Errorf("first argument to %s.Open is invalid -- "+ + "expected database path string", dbType) + if _, err := walletdb.Open(dbType, 1); err.Error() != wantErr.Error() { + t.Errorf("Open: did not receive expected error - got %v, "+ + "want %v", err, wantErr) + return + } + + // Ensure that attempting to create a database with the wrong number of + // parameters returns the expected error. + wantErr = errors.Errorf("invalid arguments to %s.Create -- expected "+ + "database path", dbType) + if _, err := walletdb.Create(dbType, 1, 2, 3); err.Error() != wantErr.Error() { + t.Errorf("Create: did not receive expected error - got %v, "+ + "want %v", err, wantErr) + return + } + + // Ensure that attempting to open a database with an invalid type for + // the first parameter returns the expected error. + wantErr = errors.Errorf("first argument to %s.Create is invalid -- "+ + "expected database path string", dbType) + if _, err := walletdb.Create(dbType, 1); err.Error() != wantErr.Error() { + t.Errorf("Create: did not receive expected error - got %v, "+ + "want %v", err, wantErr) + return + } + + // Ensure operations against a closed database return the expected + // error. + dbPath := "createfail.db" + db, err := walletdb.Create(dbType, dbPath) + if err != nil { + t.Errorf("Create: unexpected error: %v", err) + return + } + defer os.Remove(dbPath) + db.Close() + + if _, err := db.BeginReadTx(); !errors.Is(err, errors.Invalid) { + t.Errorf("BeginReadTx: unexpected error: %v", err) + return + } +} + +// TestPersistence ensures that values stored are still valid after closing and +// reopening the database. +func TestPersistence(t *testing.T) { + ctx := context.Background() + // Create a new database to run tests against. + dbPath := "persistencetest.db" + db, err := walletdb.Create(dbType, dbPath) + if err != nil { + t.Errorf("Failed to create test database (%s) %v", dbType, err) + return + } + defer os.Remove(dbPath) + defer db.Close() + + // Create a bucket and put some values into it so they can be tested + // for existence on re-open. + storeValues := map[string]string{ + "ns1key1": "foo1", + "ns1key2": "foo2", + "ns1key3": "foo3", + } + ns1Key := []byte("ns1") + + err = walletdb.Update(ctx, db, func(tx walletdb.ReadWriteTx) error { + ns1Bkt, err := tx.CreateTopLevelBucket(ns1Key) + if err != nil { + return errors.E(errors.IO, err) + } + + for k, v := range storeValues { + if err := ns1Bkt.Put([]byte(k), []byte(v)); err != nil { + return errors.Errorf("Put: unexpected error: %v", err) + } + } + + return nil + }) + if err != nil { + t.Errorf("ns1 Update: unexpected error: %v", err) + return + } + + // Close and reopen the database to ensure the values persist. + db.Close() + db, err = walletdb.Open(dbType, dbPath) + if err != nil { + t.Errorf("Failed to open test database (%s) %v", dbType, err) + return + } + defer db.Close() + + // Ensure the values previously stored in the bucket still exist + // and are correct. + err = walletdb.View(ctx, db, func(tx walletdb.ReadTx) error { + ns1Bkt := tx.ReadBucket(ns1Key) + for k, v := range storeValues { + val := ns1Bkt.Get([]byte(k)) + if !bytes.Equal([]byte(v), val) { + return errors.Errorf("Get: key '%s' does not "+ + "match expected value - got %s, want %s", + k, string(val), v) + } + } + + return nil + }) + if err != nil { + t.Fatalf("%v", err) + } +} diff --git a/wallet/internal/badgerdb/interface_test.go b/wallet/internal/badgerdb/interface_test.go new file mode 100644 index 000000000..40a2d8075 --- /dev/null +++ b/wallet/internal/badgerdb/interface_test.go @@ -0,0 +1,720 @@ +// Copyright (c) 2014 The btcsuite developers +// Copyright (c) 2015 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +// This file intended to be copied into each backend driver directory. Each +// driver should have their own driver_test.go file which creates a database and +// invokes the testInterface function in this file to ensure the driver properly +// implements the interface. See the bdb backend driver for a working example. +// +// NOTE: When copying this file into the backend driver folder, the package name +// will need to be changed accordingly. + +// Test must be updated for API changes. + +package badgerdb_test + +import ( + "bytes" + "context" + "fmt" + "os" + "testing" + + "decred.org/dcrwallet/v5/errors" + "decred.org/dcrwallet/v5/wallet/walletdb" +) + +// errSubTestFail is used to signal that a sub test returned false. +var errSubTestFail = errors.Errorf("sub test failure") + +// testContext is used to store context information about a running test which +// is passed into helper functions. +type testContext struct { + t *testing.T + db walletdb.DB + bucketDepth int + isWritable bool +} + +// rollbackValues returns a copy of the provided map with all values set to an +// empty string. This is used to test that values are properly rolled back. +func rollbackValues(values map[string]string) map[string]string { + retMap := make(map[string]string, len(values)) + for k := range values { + retMap[k] = "" + } + return retMap +} + +// testGetValues checks that all of the provided key/value pairs can be +// retrieved from the database and the retrieved values match the provided +// values. +func testGetValues(tc *testContext, bucket walletdb.ReadBucket, values map[string]string) bool { + for k, v := range values { + var vBytes []byte + if v != "" { + vBytes = []byte(v) + } + + gotValue := bucket.Get([]byte(k)) + if !bytes.Equal(gotValue, vBytes) { + tc.t.Errorf("Get: unexpected value - got %s, want %s", + gotValue, vBytes) + return false + } + } + + return true +} + +// testPutValues stores all of the provided key/value pairs in the provided +// bucket while checking for errors. +func testPutValues(tc *testContext, bucket walletdb.ReadWriteBucket, values map[string]string) bool { + for k, v := range values { + var vBytes []byte + if v != "" { + vBytes = []byte(v) + } + if err := bucket.Put([]byte(k), vBytes); err != nil { + tc.t.Errorf("Put: unexpected error: %v", err) + return false + } + } + + return true +} + +// testDeleteValues removes all of the provided key/value pairs from the +// provided bucket. +func testDeleteValues(tc *testContext, bucket walletdb.ReadWriteBucket, values map[string]string) bool { + for k := range values { + if err := bucket.Delete([]byte(k)); err != nil { + tc.t.Errorf("Delete: unexpected error: %v", err) + return false + } + } + + return true +} + +// testNestedReadWriteBucket reruns the testReadWriteBucketInterface against a +// nested bucket along with a counter to only test a couple of level deep. +func testNestedReadWriteBucket(tc *testContext, testBucket walletdb.ReadWriteBucket) bool { + // Don't go more than 2 nested level deep. + if tc.bucketDepth > 1 { + return true + } + + tc.bucketDepth++ + defer func() { + tc.bucketDepth-- + }() + + return testReadWriteBucketInterface(tc, testBucket) +} + +// testReadWriteBucketInterface ensures the bucket interface is working +// properly by exercising all of its functions. +func testReadWriteBucketInterface(tc *testContext, bucket walletdb.ReadWriteBucket) bool { + // keyValues holds the keys and values to use when putting + // values into the bucket. + var keyValues = map[string]string{ + "bucketkey1": "foo1", + "bucketkey2": "foo2", + "bucketkey3": "foo3", + } + if !testPutValues(tc, bucket, keyValues) { + return false + } + + if !testGetValues(tc, bucket, keyValues) { + return false + } + + // Iterate all of the keys using ForEach while making sure the + // stored values are the expected values. + keysFound := make(map[string]struct{}, len(keyValues)) + err := bucket.ForEach(func(k, v []byte) error { + kString := string(k) + wantV, ok := keyValues[kString] + if !ok { + return errors.Errorf("ForEach: key '%s' should "+ + "exist", kString) + } + + if !bytes.Equal(v, []byte(wantV)) { + return errors.Errorf("ForEach: value for key '%s' "+ + "does not match - got %s, want %s", + kString, v, wantV) + } + + keysFound[kString] = struct{}{} + return nil + }) + if err != nil { + tc.t.Errorf("%v", err) + return false + } + + // Ensure all keys were iterated. + for k := range keyValues { + if _, ok := keysFound[k]; !ok { + tc.t.Errorf("ForEach: key '%s' was not iterated "+ + "when it should have been", k) + return false + } + } + + // Delete the keys and ensure they were deleted. + if !testDeleteValues(tc, bucket, keyValues) { + return false + } + if !testGetValues(tc, bucket, rollbackValues(keyValues)) { + return false + } + + // Ensure creating a new bucket works as expected. + testBucketName := []byte("testbucket") + testBucket, err := bucket.CreateBucket(testBucketName) + if err != nil { + tc.t.Errorf("CreateBucket: unexpected error: %v", err) + return false + } + if !testNestedReadWriteBucket(tc, testBucket) { + return false + } + + // Ensure creating a bucket that already exists fails with the + // expected error. + if _, err := bucket.CreateBucket(testBucketName); !errors.Is(err, errors.Exist) { + tc.t.Errorf("CreateBucket: unexpected error: %v", err) + return false + } + + // Ensure CreateBucketIfNotExists returns an existing bucket. + testBucket, err = bucket.CreateBucketIfNotExists(testBucketName) + if err != nil { + tc.t.Errorf("CreateBucketIfNotExists: unexpected "+ + "error: %v", err) + return false + } + if !testNestedReadWriteBucket(tc, testBucket) { + return false + } + + // Ensure retrieving and existing bucket works as expected. + testBucket = bucket.NestedReadWriteBucket(testBucketName) + if !testNestedReadWriteBucket(tc, testBucket) { + return false + } + + // Ensure deleting a bucket works as intended. + if err := bucket.DeleteNestedBucket(testBucketName); err != nil { + tc.t.Errorf("DeleteBucket: unexpected error: %v", err) + return false + } + if b := bucket.NestedReadWriteBucket(testBucketName); b != nil { + tc.t.Errorf("DeleteBucket: bucket '%s' still exists", + testBucketName) + return false + } + + // Ensure deleting a bucket that doesn't exist returns the + // expected error. + if err := bucket.DeleteNestedBucket(testBucketName); !errors.Is(err, errors.NotExist) { + tc.t.Errorf("DeleteBucket: unexpected error: %v", err) + return false + } + + // Ensure CreateBucketIfNotExists creates a new bucket when + // it doesn't already exist. + testBucket, err = bucket.CreateBucketIfNotExists(testBucketName) + if err != nil { + tc.t.Errorf("CreateBucketIfNotExists: unexpected error: %v", err) + return false + } + if !testNestedReadWriteBucket(tc, testBucket) { + return false + } + + // Delete the test bucket to avoid leaving it around for future + // calls. + if err := bucket.DeleteNestedBucket(testBucketName); err != nil { + tc.t.Errorf("DeleteBucket: unexpected error: %v", err) + return false + } + if b := bucket.NestedReadWriteBucket(testBucketName); b != nil { + tc.t.Errorf("DeleteBucket: bucket '%s' still exists", + testBucketName) + return false + } + + return true +} + +// testManualTxInterface ensures that manual transactions work as expected. +func testManualTxInterface(tc *testContext, bucketKey []byte) bool { + db := tc.db + + // populateValues tests that populating values works as expected. + // + // When the writable flag is false, a read-only transaction is created, + // standard bucket tests for read-only transactions are performed, and + // the Commit function is checked to ensure it fails as expected. + // + // Otherwise, a read-write transaction is created, the values are + // written, standard bucket tests for read-write transactions are + // performed, and then the transaction is either committed or rolled + // back depending on the flag. + populateValues := func(writable, rollback bool, putValues map[string]string) bool { + var dbtx walletdb.ReadTx + var rootBucket walletdb.ReadBucket + var err error + if writable { + dbtx, err = db.BeginReadWriteTx() + if err != nil { + tc.t.Errorf("BeginReadWriteTx: unexpected error %v", err) + return false + } + rootBucket = dbtx.(walletdb.ReadWriteTx).ReadWriteBucket(bucketKey) + } else { + dbtx, err = db.BeginReadTx() + if err != nil { + tc.t.Errorf("BeginReadTx: unexpected error %v", err) + return false + } + rootBucket = dbtx.ReadBucket(bucketKey) + } + if rootBucket == nil { + tc.t.Errorf("ReadWriteBucket/ReadBucket: unexpected nil root bucket") + _ = dbtx.Rollback() + return false + } + + if writable { + tc.isWritable = writable + if !testReadWriteBucketInterface(tc, rootBucket.(walletdb.ReadWriteBucket)) { + _ = dbtx.Rollback() + return false + } + } + + if !writable { + // Rollback the transaction. + if err := dbtx.Rollback(); err != nil { + tc.t.Errorf("Commit: unexpected error %v", err) + return false + } + } else { + rootBucket := rootBucket.(walletdb.ReadWriteBucket) + if !testPutValues(tc, rootBucket, putValues) { + return false + } + + if rollback { + // Rollback the transaction. + if err := dbtx.Rollback(); err != nil { + tc.t.Errorf("Rollback: unexpected "+ + "error %v", err) + return false + } + } else { + // The commit should succeed. + if err := dbtx.(walletdb.ReadWriteTx).Commit(); err != nil { + tc.t.Errorf("Commit: unexpected error "+ + "%v", err) + return false + } + } + } + + return true + } + + // checkValues starts a read-only transaction and checks that all of + // the key/value pairs specified in the expectedValues parameter match + // what's in the database. + checkValues := func(expectedValues map[string]string) bool { + // Begin another read-only transaction to ensure... + dbtx, err := db.BeginReadTx() + if err != nil { + tc.t.Errorf("BeginReadTx: unexpected error %v", err) + return false + } + + rootBucket := dbtx.ReadBucket(bucketKey) + if rootBucket == nil { + tc.t.Errorf("ReadBucket: unexpected nil root bucket") + _ = dbtx.Rollback() + return false + } + + if !testGetValues(tc, rootBucket, expectedValues) { + _ = dbtx.Rollback() + return false + } + + // Rollback the read-only transaction. + if err := dbtx.Rollback(); err != nil { + tc.t.Errorf("Commit: unexpected error %v", err) + return false + } + + return true + } + + // deleteValues starts a read-write transaction and deletes the keys + // in the passed key/value pairs. + deleteValues := func(values map[string]string) bool { + dbtx, err := db.BeginReadWriteTx() + if err != nil { + tc.t.Errorf("BeginReadWriteTx: unexpected error %v", err) + _ = dbtx.Rollback() + return false + } + + rootBucket := dbtx.ReadWriteBucket(bucketKey) + if rootBucket == nil { + tc.t.Errorf("RootBucket: unexpected nil root bucket") + _ = dbtx.Rollback() + return false + } + + // Delete the keys and ensure they were deleted. + if !testDeleteValues(tc, rootBucket, values) { + _ = dbtx.Rollback() + return false + } + if !testGetValues(tc, rootBucket, rollbackValues(values)) { + _ = dbtx.Rollback() + return false + } + + // Commit the changes and ensure it was successful. + if err := dbtx.Commit(); err != nil { + tc.t.Errorf("Commit: unexpected error %v", err) + return false + } + + return true + } + + // keyValues holds the keys and values to use when putting values + // into a bucket. + var keyValues = map[string]string{ + "umtxkey1": "foo1", + "umtxkey2": "foo2", + "umtxkey3": "foo3", + } + + // Ensure that attempting populating the values using a read-only + // transaction fails as expected. + if !populateValues(false, true, keyValues) { + return false + } + if !checkValues(rollbackValues(keyValues)) { + return false + } + + // Ensure that attempting populating the values using a read-write + // transaction and then rolling it back yields the expected values. + if !populateValues(true, true, keyValues) { + return false + } + if !checkValues(rollbackValues(keyValues)) { + return false + } + + // Ensure that attempting populating the values using a read-write + // transaction and then committing it stores the expected values. + if !populateValues(true, false, keyValues) { + return false + } + if !checkValues(keyValues) { + return false + } + + // Clean up the keys. + if !deleteValues(keyValues) { + return false + } + + return true +} + +// testNamespaceAndTxInterfaces creates a namespace using the provided key and +// tests all facets of it interface as well as transaction and bucket +// interfaces under it. +func testNamespaceAndTxInterfaces(tc *testContext, namespaceKey string) bool { + ctx := context.Background() + namespaceKeyBytes := []byte(namespaceKey) + err := walletdb.Update(ctx, tc.db, func(tx walletdb.ReadWriteTx) error { + _, err := tx.CreateTopLevelBucket(namespaceKeyBytes) + return err + }) + if err != nil { + tc.t.Errorf("CreateTopLevelBucket: unexpected error: %v", err) + return false + } + defer func() { + // Remove the namespace now that the tests are done for it. + err := walletdb.Update(ctx, tc.db, func(tx walletdb.ReadWriteTx) error { + return tx.DeleteTopLevelBucket(namespaceKeyBytes) + }) + if err != nil { + tc.t.Errorf("DeleteTopLevelBucket: unexpected error: %v", err) + return + } + }() + + if !testManualTxInterface(tc, namespaceKeyBytes) { + return false + } + + // keyValues holds the keys and values to use when putting values + // into a bucket. + var keyValues = map[string]string{ + "mtxkey1": "foo1", + "mtxkey2": "foo2", + "mtxkey3": "foo3", + } + + // Test the bucket interface via a managed read-only transaction. + err = walletdb.View(ctx, tc.db, func(tx walletdb.ReadTx) error { + rootBucket := tx.ReadBucket(namespaceKeyBytes) + if rootBucket == nil { + return fmt.Errorf("ReadBucket: unexpected nil root bucket") + } + + return nil + }) + if err != nil { + if !errors.Is(err, errSubTestFail) { + tc.t.Errorf("%v", err) + } + return false + } + + // Test the bucket interface via a managed read-write transaction. + // Also, put a series of values and force a rollback so the following + // code can ensure the values were not stored. + forceRollbackError := fmt.Errorf("force rollback") + err = walletdb.Update(ctx, tc.db, func(tx walletdb.ReadWriteTx) error { + rootBucket := tx.ReadWriteBucket(namespaceKeyBytes) + if rootBucket == nil { + return fmt.Errorf("ReadWriteBucket: unexpected nil root bucket") + } + + tc.isWritable = true + if !testReadWriteBucketInterface(tc, rootBucket) { + return errSubTestFail + } + + if !testPutValues(tc, rootBucket, keyValues) { + return errSubTestFail + } + + // Return an error to force a rollback. + return forceRollbackError + }) + if !errors.Is(err, forceRollbackError) { + if errors.Is(err, errSubTestFail) { + return false + } + + tc.t.Errorf("Update: inner function error not returned - got "+ + "%v, want %v", err, forceRollbackError) + return false + } + + // Ensure the values that should have not been stored due to the forced + // rollback above were not actually stored. + err = walletdb.View(ctx, tc.db, func(tx walletdb.ReadTx) error { + rootBucket := tx.ReadBucket(namespaceKeyBytes) + if rootBucket == nil { + return fmt.Errorf("ReadBucket: unexpected nil root bucket") + } + + if !testGetValues(tc, rootBucket, rollbackValues(keyValues)) { + return errSubTestFail + } + + return nil + }) + if err != nil { + if !errors.Is(err, errSubTestFail) { + tc.t.Errorf("%v", err) + } + return false + } + + // Store a series of values via a managed read-write transaction. + err = walletdb.Update(ctx, tc.db, func(tx walletdb.ReadWriteTx) error { + rootBucket := tx.ReadWriteBucket(namespaceKeyBytes) + if rootBucket == nil { + return fmt.Errorf("ReadWriteBucket: unexpected nil root bucket") + } + + if !testPutValues(tc, rootBucket, keyValues) { + return errSubTestFail + } + + return nil + }) + if err != nil { + if !errors.Is(err, errSubTestFail) { + tc.t.Errorf("%v", err) + } + return false + } + + // Ensure the values stored above were committed as expected. + err = walletdb.View(ctx, tc.db, func(tx walletdb.ReadTx) error { + rootBucket := tx.ReadBucket(namespaceKeyBytes) + if rootBucket == nil { + return fmt.Errorf("ReadBucket: unexpected nil root bucket") + } + + if !testGetValues(tc, rootBucket, keyValues) { + return errSubTestFail + } + + return nil + }) + if err != nil { + if !errors.Is(err, errSubTestFail) { + tc.t.Errorf("%v", err) + } + return false + } + + // Clean up the values stored above in a managed read-write transaction. + err = walletdb.Update(ctx, tc.db, func(tx walletdb.ReadWriteTx) error { + rootBucket := tx.ReadWriteBucket(namespaceKeyBytes) + if rootBucket == nil { + return fmt.Errorf("ReadWriteBucket: unexpected nil root bucket") + } + + if !testDeleteValues(tc, rootBucket, keyValues) { + return errSubTestFail + } + + return nil + }) + if err != nil { + if !errors.Is(err, errSubTestFail) { + tc.t.Errorf("%v", err) + } + return false + } + + return true +} + +// testAdditionalErrors performs some tests for error cases not covered +// elsewhere in the tests and therefore improves negative test coverage. +func testAdditionalErrors(tc *testContext) bool { + ctx := context.Background() + ns3Key := []byte("ns3") + + err := walletdb.Update(ctx, tc.db, func(tx walletdb.ReadWriteTx) error { + // Create a new namespace + rootBucket, err := tx.CreateTopLevelBucket(ns3Key) + if err != nil { + return fmt.Errorf("CreateTopLevelBucket: unexpected error: %v", err) + } + + // Ensure CreateBucket returns the expected error when no bucket + // key is specified. + if _, err := rootBucket.CreateBucket(nil); !errors.Is(err, errors.Invalid) { + return fmt.Errorf("CreateBucket: unexpected error - "+ + "got %v, want %v", err, errors.Invalid) + } + + // Ensure DeleteNestedBucket returns the expected error when no bucket + // key is specified. + if err := rootBucket.DeleteNestedBucket(nil); !errors.Is(err, errors.Invalid) { + return fmt.Errorf("DeleteNestedBucket: unexpected error - "+ + "got %v, want %v", err, errors.Invalid) + } + + // Ensure Put returns the expected error when no key is + // specified. + if err := rootBucket.Put(nil, nil); !errors.Is(err, errors.Invalid) { + return fmt.Errorf("Put: unexpected error - got %v, "+ + "want %v", err, errors.Invalid) + } + + return nil + }) + if err != nil { + if !errors.Is(err, errSubTestFail) { + tc.t.Errorf("%v", err) + } + return false + } + + // Ensure that attempting to rollback or commit a transaction that is + // already closed returns the expected error. + tx, err := tc.db.BeginReadWriteTx() + if err != nil { + tc.t.Errorf("Begin: unexpected error: %v", err) + return false + } + if err := tx.Rollback(); err != nil { + tc.t.Errorf("Rollback: unexpected error: %v", err) + return false + } + if err := tx.Rollback(); !errors.Is(err, errors.Invalid) { + tc.t.Errorf("Rollback: unexpected error - got %v, want %v", err, + errors.Invalid) + return false + } + if err := tx.Commit(); !errors.Is(err, errors.Invalid) { + tc.t.Errorf("Commit: unexpected error - got %v, want %v", err, + errors.Invalid) + return false + } + + return true +} + +// testInterface tests performs tests for the various interfaces of walletdb +// which require state in the database for the given database type. +func testInterface(t *testing.T, db walletdb.DB) { + // Create a test context to pass around. + context := testContext{t: t, db: db} + + // Create a namespace and test the interface for it. + if !testNamespaceAndTxInterfaces(&context, "ns1") { + return + } + + // Create a second namespace and test the interface for it. + if !testNamespaceAndTxInterfaces(&context, "ns2") { + return + } + + // Check a few more error conditions not covered elsewhere. + if !testAdditionalErrors(&context) { + return + } +} + +// TestInterface performs all interfaces tests for this database driver. +func TestInterface(t *testing.T) { + // Create a new database to run tests against. + dbPath := "interfacetest.db" + db, err := walletdb.Create(dbType, dbPath) + if err != nil { + t.Errorf("Failed to create test database (%s) %v", dbType, err) + return + } + defer os.Remove(dbPath) + defer db.Close() + + // Run all of the interface tests against the database. + testInterface(t, db) +} diff --git a/wallet/internal/bdb/db.go b/wallet/internal/bdb/db.go index df5f9f57a..7c6bf5c28 100644 --- a/wallet/internal/bdb/db.go +++ b/wallet/internal/bdb/db.go @@ -6,7 +6,7 @@ package bdb import ( - "io" + "bytes" "os" "decred.org/dcrwallet/v5/errors" @@ -195,12 +195,25 @@ func (b *bucket) ReadCursor() walletdb.ReadCursor { return b.ReadWriteCursor() } +func (b *bucket) ReverseReadCursor() walletdb.ReadCursor { + return b.ReverseReadWriteCursor() +} + // ReadWriteCursor returns a new cursor, allowing for iteration over the bucket's // key/value pairs and nested buckets in forward or backward order. // // This function is part of the walletdb.Bucket interface implementation. func (b *bucket) ReadWriteCursor() walletdb.ReadWriteCursor { - return (*cursor)((*bolt.Bucket)(b).Cursor()) + return &cursor{ + cursor: (*bolt.Bucket)(b).Cursor(), + } +} + +func (b *bucket) ReverseReadWriteCursor() walletdb.ReadWriteCursor { + return &cursor{ + cursor: (*bolt.Bucket)(b).Cursor(), + reverse: true, + } } // cursor represents a cursor over key/value pairs and nested buckets of a @@ -210,42 +223,37 @@ func (b *bucket) ReadWriteCursor() walletdb.ReadWriteCursor { // modifications to the bucket, with the exception of cursor.Delete, invalidate // the cursor. After invalidation, the cursor must be repositioned, or the keys // and values returned may be unpredictable. -type cursor bolt.Cursor +type cursor struct { + cursor *bolt.Cursor + reverse bool +} // Delete removes the current key/value pair the cursor is at without // invalidating the cursor. // // This function is part of the walletdb.Cursor interface implementation. func (c *cursor) Delete() error { - return convertErr((*bolt.Cursor)(c).Delete()) + return convertErr(c.cursor.Delete()) } // First positions the cursor at the first key/value pair and returns the pair. // // This function is part of the walletdb.Cursor interface implementation. func (c *cursor) First() (key, value []byte) { - return (*bolt.Cursor)(c).First() -} - -// Last positions the cursor at the last key/value pair and returns the pair. -// -// This function is part of the walletdb.Cursor interface implementation. -func (c *cursor) Last() (key, value []byte) { - return (*bolt.Cursor)(c).Last() + if c.reverse { + return c.cursor.Last() + } + return c.cursor.First() } // Next moves the cursor one key/value pair forward and returns the new pair. // // This function is part of the walletdb.Cursor interface implementation. func (c *cursor) Next() (key, value []byte) { - return (*bolt.Cursor)(c).Next() -} - -// Prev moves the cursor one key/value pair backward and returns the new pair. -// -// This function is part of the walletdb.Cursor interface implementation. -func (c *cursor) Prev() (key, value []byte) { - return (*bolt.Cursor)(c).Prev() + if c.reverse { + return c.cursor.Prev() + } + return c.cursor.Next() } // Seek positions the cursor at the passed seek key. If the key does not exist, @@ -253,7 +261,11 @@ func (c *cursor) Prev() (key, value []byte) { // // This function is part of the walletdb.Cursor interface implementation. func (c *cursor) Seek(seek []byte) (key, value []byte) { - return (*bolt.Cursor)(c).Seek(seek) + k, v := c.cursor.Seek(seek) + if c.reverse && !bytes.Equal(k, seek) { + k, v = c.cursor.Prev() + } + return k, v } // Closes the cursor @@ -285,16 +297,6 @@ func (db *db) BeginReadWriteTx() (walletdb.ReadWriteTx, error) { return db.beginTx(true) } -// Copy writes a copy of the database to the provided writer. This call will -// start a read-only transaction to perform all operations. -// -// This function is part of the walletdb.Db interface implementation. -func (db *db) Copy(w io.Writer) error { - return convertErr((*bolt.DB)(db).View(func(tx *bolt.Tx) error { - return tx.Copy(w) - })) -} - // Close cleanly shuts down the database and syncs all data. // // This function is part of the walletdb.Db interface implementation. diff --git a/wallet/main_test.go b/wallet/main_test.go index eca649c56..597ce1dbc 100644 --- a/wallet/main_test.go +++ b/wallet/main_test.go @@ -12,7 +12,10 @@ import ( "github.com/decred/slog" ) -var logFlag = flag.Bool("log", false, "enable package logger") +var ( + logFlag = flag.Bool("log", false, "enable package logger") + driverFlag = flag.String("dbdriver", "bdb", "database driver (bdb or badgerdb)") +) func TestMain(m *testing.M) { flag.Parse() diff --git a/wallet/setup_test.go b/wallet/setup_test.go index a1cb7d958..03a31c40f 100644 --- a/wallet/setup_test.go +++ b/wallet/setup_test.go @@ -7,8 +7,10 @@ package wallet import ( "context" "os" + "path/filepath" "testing" + _ "decred.org/dcrwallet/v5/wallet/drivers/badgerdb" _ "decred.org/dcrwallet/v5/wallet/drivers/bdb" "decred.org/dcrwallet/v5/wallet/walletdb" "github.com/decred/dcrd/chaincfg/v3" @@ -26,30 +28,25 @@ var basicWalletConfig = Config{ } func testWallet(ctx context.Context, t *testing.T, cfg *Config, seed []byte) *Wallet { - f, err := os.CreateTemp(t.TempDir(), "dcrwallet.testdb") + dbDir, err := os.MkdirTemp(t.TempDir(), "dcrwallet.testdb") if err != nil { t.Fatal(err) } - f.Close() - db, err := walletdb.Create("bdb", f.Name()) + db, err := walletdb.Create(*driverFlag, filepath.Join(dbDir, "wallet.db")) if err != nil { t.Fatal(err) } - rm := func() { - db.Close() - os.Remove(f.Name()) - } err = Create(ctx, opaqueDB{db}, []byte(InsecurePubPassphrase), testPrivPass, seed, cfg.Params) if err != nil { - rm() t.Fatal(err) } cfg.DB = opaqueDB{db} w, err := Open(ctx, cfg) if err != nil { - rm() t.Fatal(err) } - t.Cleanup(rm) + t.Cleanup(func() { + db.Close() + }) return w } diff --git a/wallet/udb/addressmanager_test.go b/wallet/udb/addressmanager_test.go index 0c2dab9cf..40a9e8a9f 100644 --- a/wallet/udb/addressmanager_test.go +++ b/wallet/udb/addressmanager_test.go @@ -9,8 +9,6 @@ import ( "bytes" "context" "fmt" - "os" - "path/filepath" "reflect" "testing" @@ -1214,7 +1212,7 @@ func TestManager(t *testing.T) { } ctx := context.Background() - db, mgr, _, err := cloneDB(ctx, t, "mgr_watching_only.kv") + db, mgr, _, err := cloneDB(ctx, t, "mgr.kv") if err != nil { t.Fatal(err) } @@ -1247,28 +1245,3 @@ func TestManager(t *testing.T) { } testManagerAPI(ctx, tc) } - -func TestMain(m *testing.M) { - testDir, err := os.MkdirTemp("", "udb-") - if err != nil { - fmt.Printf("Unable to create temp directory: %v", err) - os.Exit(1) - } - - emptyDbPath = filepath.Join(testDir, "empty.kv") - teardown := func() { - os.RemoveAll(testDir) - } - - ctx := context.Background() - err = createEmptyDB(ctx) - if err != nil { - fmt.Printf("Unable to create empty test db: %v\n", err) - teardown() - os.Exit(1) - } - - exitCode := m.Run() - teardown() - os.Exit(exitCode) -} diff --git a/wallet/udb/common_test.go b/wallet/udb/common_test.go index 455cf108c..56720b373 100644 --- a/wallet/udb/common_test.go +++ b/wallet/udb/common_test.go @@ -8,10 +8,14 @@ package udb import ( "context" "encoding/hex" + "flag" "fmt" "os" + "path/filepath" "testing" + _ "decred.org/dcrwallet/v5/wallet/internal/badgerdb" + _ "decred.org/dcrwallet/v5/wallet/internal/bdb" "decred.org/dcrwallet/v5/wallet/walletdb" "github.com/decred/dcrd/chaincfg/v3" ) @@ -25,7 +29,11 @@ var ( 0xef, 0x8d, 0x64, 0x15, 0x67, } - emptyDbPath = "" + // rewritten to absolute paths by TestMain + emptyDBDir string + emptyDBPath string + + dbName = "test.db" pubPassphrase = []byte("_DJr{fL4H0O}*-0\n:V1izc)(6BomK") privPassphrase = []byte("81lUHXnOMZ@?XXd7O9xyDIWIbXX-lj") @@ -33,6 +41,10 @@ var ( privPassphrase2 = []byte("~{<]08%6!-?2s<$(8$8:f(5[4/!/{Y") ) +var ( + dbDriver = flag.String("dbdriver", "bdb", "database driver (bdb or badgerdb)") +) + // hexToBytes is a wrapper around hex.DecodeString that panics if there is an // error. It MUST only be used with hard coded values in the tests. func hexToBytes(origHex string) []byte { @@ -45,7 +57,7 @@ func hexToBytes(origHex string) []byte { // createEmptyDB is a helper function for creating an empty wallet db. func createEmptyDB(ctx context.Context) error { - db, err := walletdb.Create("bdb", emptyDbPath) + db, err := walletdb.Create(*dbDriver, emptyDBPath) if err != nil { return err } @@ -68,24 +80,21 @@ func createEmptyDB(ctx context.Context) error { // cloneDB makes a copy of an empty wallet db. It returns a wallet db, address // manager, and the tx store. func cloneDB(ctx context.Context, t *testing.T, cloneName string) (walletdb.DB, *Manager, *Store, error) { - file, err := os.ReadFile(emptyDbPath) + cloneDir := filepath.Join(filepath.Dir(emptyDBDir), cloneName) + err := os.CopyFS(cloneDir, os.DirFS(emptyDBDir)) if err != nil { - return nil, nil, nil, fmt.Errorf("unexpected error: %v", err) + t.Logf("%v %v", cloneDir, emptyDBDir) + return nil, nil, nil, fmt.Errorf("CopyFS unexpected error: %v", err) } - err = os.WriteFile(cloneName, file, 0644) + db, err := walletdb.Open(*dbDriver, filepath.Join(cloneDir, dbName)) if err != nil { - return nil, nil, nil, fmt.Errorf("unexpected error: %v", err) - } - - db, err := walletdb.Open("bdb", cloneName) - if err != nil { - return nil, nil, nil, fmt.Errorf("unexpected error: %v", err) + return nil, nil, nil, fmt.Errorf("walletdb.Open unexpected error: %v", err) } mgr, txStore, err := Open(ctx, db, chaincfg.TestNet3Params(), pubPassphrase) if err != nil { - return nil, nil, nil, fmt.Errorf("unexpected error: %v", err) + return nil, nil, nil, fmt.Errorf("udb.Open unexpected error: %v", err) } t.Cleanup(func() { @@ -95,3 +104,37 @@ func cloneDB(ctx context.Context, t *testing.T, cloneName string) (walletdb.DB, return db, mgr, txStore, err } + +func TestMain(m *testing.M) { + flag.Parse() + + testDir, err := os.MkdirTemp("", "udb-") + if err != nil { + fmt.Printf("Unable to create temp directory: %v", err) + os.Exit(1) + } + + emptyDBDir = filepath.Join(testDir, "empty-db") + err = os.Mkdir(emptyDBDir, 0o777) + if err != nil { + fmt.Printf("Unable to create empty-db directory: %v", err) + os.Exit(1) + } + + emptyDBPath = filepath.Join(emptyDBDir, dbName) + teardown := func() { + os.RemoveAll(testDir) + } + + ctx := context.Background() + err = createEmptyDB(ctx) + if err != nil { + fmt.Printf("Unable to create empty test db: %v\n", err) + teardown() + os.Exit(1) + } + + exitCode := m.Run() + teardown() + os.Exit(exitCode) +} diff --git a/wallet/udb/stakevalidation_test.go b/wallet/udb/stakevalidation_test.go index 08ebe1731..0e072fb83 100644 --- a/wallet/udb/stakevalidation_test.go +++ b/wallet/udb/stakevalidation_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - _ "decred.org/dcrwallet/v5/wallet/internal/bdb" "decred.org/dcrwallet/v5/wallet/walletdb" "github.com/decred/dcrd/dcrutil/v4" gcs2 "github.com/decred/dcrd/gcs/v4" diff --git a/wallet/udb/tx_test.go b/wallet/udb/tx_test.go index 96cd97380..16c744146 100644 --- a/wallet/udb/tx_test.go +++ b/wallet/udb/tx_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - _ "decred.org/dcrwallet/v5/wallet/drivers/bdb" "decred.org/dcrwallet/v5/wallet/walletdb" "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/dcrutil/v4" @@ -407,6 +406,7 @@ func TestCoinbases(t *testing.T) { } testMaturity := func(tests []coinbaseTest) error { + t.Helper() for i, tst := range tests { bal, err := s.AccountBalance(dbtx, 0, defaultAccount) if err != nil { diff --git a/wallet/udb/txcommon_test.go b/wallet/udb/txcommon_test.go index ce6a63544..233c68b37 100644 --- a/wallet/udb/txcommon_test.go +++ b/wallet/udb/txcommon_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - _ "decred.org/dcrwallet/v5/wallet/drivers/bdb" "decred.org/dcrwallet/v5/wallet/walletdb" "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/chaincfg/v3" diff --git a/wallet/udb/txdb.go b/wallet/udb/txdb.go index 78dd04969..9ceff1ee6 100644 --- a/wallet/udb/txdb.go +++ b/wallet/udb/txdb.go @@ -358,62 +358,31 @@ func makeReadBlockIterator(ns walletdb.ReadBucket, height int32) blockIterator { return blockIterator{c: readCursor{c}, seek: seek} } -// Works just like makeBlockIterator but will initially position the cursor at -// the last k/v pair. Use this with blockIterator.prev. -func makeReverseBlockIterator(ns walletdb.ReadWriteBucket) blockIterator { +// Works just like makeReadBlockIterator but will initially position the +// cursor at the last k/v pair, and next will iterate to previous elements. +func makeReverseBlockIterator(ns walletdb.ReadWriteBucket, height uint32) blockIterator { seek := make([]byte, 4) - byteOrder.PutUint32(seek, ^uint32(0)) - c := ns.NestedReadWriteBucket(bucketBlocks).ReadWriteCursor() + byteOrder.PutUint32(seek, height) + c := ns.NestedReadWriteBucket(bucketBlocks).ReverseReadWriteCursor() return blockIterator{c: c, seek: seek} } -func (it *blockIterator) next() bool { - if it.c == nil { - return false - } - - if it.ck == nil { - it.ck, it.cv = it.c.Seek(it.seek) - } else { - it.ck, it.cv = it.c.Next() - } - if it.ck == nil { - it.c.Close() - it.c = nil - return false - } - - err := readRawBlockRecord(it.ck, it.cv, &it.elem) - if err != nil { - it.c = nil - it.err = err - return false - } - - return true +func makeReverseReadBlockIterator(ns walletdb.ReadBucket, height uint32) blockIterator { + seek := make([]byte, 4) + byteOrder.PutUint32(seek, height) + c := ns.NestedReadBucket(bucketBlocks).ReverseReadCursor() + return blockIterator{c: readCursor{c}, seek: seek} } -func (it *blockIterator) prev() bool { +func (it *blockIterator) next() bool { if it.c == nil { return false } if it.ck == nil { it.ck, it.cv = it.c.Seek(it.seek) - // Seek positions the cursor at the next k/v pair if one with - // this prefix was not found. If this happened (the prefixes - // won't match in this case) move the cursor backward. - // - // This technically does not correct for multiple keys with - // matching prefixes by moving the cursor to the last matching - // key, but this doesn't need to be considered when dealing with - // block records since the key (and seek prefix) is just the - // block height. - if !bytes.HasPrefix(it.ck, it.seek) { - it.ck, it.cv = it.c.Prev() - } } else { - it.ck, it.cv = it.c.Prev() + it.ck, it.cv = it.c.Next() } if it.ck == nil { it.c.Close() diff --git a/wallet/udb/txmined.go b/wallet/udb/txmined.go index de20abcbe..2f9ef90d9 100644 --- a/wallet/udb/txmined.go +++ b/wallet/udb/txmined.go @@ -1900,9 +1900,9 @@ func (s *Store) Rollback(dbtx walletdb.ReadWriteTx, height int32) (map[chainhash var heightsToRemove []int32 removedTxs := make(map[chainhash.Hash][]*wire.MsgTx) - it := makeReverseBlockIterator(ns) + it := makeReverseBlockIterator(ns, ^uint32(0)) defer it.close() - for it.prev() { + for it.next() { b := &it.elem if it.elem.Height < height { break @@ -2364,11 +2364,18 @@ func (s *Store) randomUTXO(dbtx walletdb.ReadTx, skip func(k, v []byte) bool) (k prevFirst := r[32]&1 == 1 c := ns.NestedReadBucket(bucketUnspent).ReadCursor() - k, v = c.Seek(randKey) - iter := c.Next - if prevFirst { - iter = c.Prev - k, v = iter() + rc := ns.NestedReadBucket(bucketUnspent).ReverseReadCursor() + defer c.Close() + defer rc.Close() + var iter func() (k, v []byte) + if !prevFirst { + // Forwards + k, v = c.Seek(randKey) + iter = c.Next + } else { + // Reverse + k, v = rc.Seek(randKey) + iter = rc.Next } var keys [][]byte @@ -2383,19 +2390,21 @@ func (s *Store) randomUTXO(dbtx walletdb.ReadTx, skip func(k, v []byte) bool) (k } // Pick random output when at least one random transaction was found. if len(keys) > 0 { + // Seek key will match exactly, so does not matter which + // cursor is used. k, v = c.Seek(keys[rand.IntN(len(keys))]) - c.Close() return k, v } // Search the opposite direction from the random seek key. if prevFirst { + // Originally reverse, now forwards k, v = c.Seek(randKey) iter = c.Next } else { - c.Seek(randKey) - iter = c.Prev - k, v = iter() + // Originally forwards, now reverse + k, v = rc.Seek(randKey) + iter = rc.Next } for ; k != nil; k, v = iter() { if len(keys) > 0 && !bytes.Equal(keys[0][:32], k[:32]) { @@ -2408,11 +2417,9 @@ func (s *Store) randomUTXO(dbtx walletdb.ReadTx, skip func(k, v []byte) bool) (k } if len(keys) > 0 { k, v = c.Seek(keys[rand.IntN(len(keys))]) - c.Close() return k, v } - c.Close() return nil, nil } diff --git a/wallet/udb/txmined_test.go b/wallet/udb/txmined_test.go index d3815de2b..af4cff53f 100644 --- a/wallet/udb/txmined_test.go +++ b/wallet/udb/txmined_test.go @@ -31,7 +31,7 @@ func randomHash() chainhash.Hash { func TestSetBirthState(t *testing.T) { ctx := context.Background() - db, _, _, err := cloneDB(ctx, t, "mgr_watching_only.kv") + db, _, _, err := cloneDB(ctx, t, "set_birth_state.kv") if err != nil { t.Fatal(err) } diff --git a/wallet/udb/txquery.go b/wallet/udb/txquery.go index b51be0f75..3566d9d5b 100644 --- a/wallet/udb/txquery.go +++ b/wallet/udb/txquery.go @@ -410,10 +410,10 @@ func (s *Store) rangeBlockTransactions(ctx context.Context, ns walletdb.ReadBuck } } else { // Iterate in backwards order, from begin -> end. - blockIter = makeReadBlockIterator(ns, begin) + blockIter = makeReverseReadBlockIterator(ns, uint32(begin)) defer blockIter.close() advance = func(it *blockIterator) bool { - if !it.prev() { + if !it.next() { return false } return end <= it.elem.Height @@ -753,7 +753,7 @@ func (s *Store) RangeBlocks(ns walletdb.ReadBucket, begin, end int32, blockIter = makeReadBlockIterator(ns, begin) defer blockIter.close() advance = func(it *blockIterator) bool { - if !it.prev() { + if !it.next() { return false } return end <= it.elem.Height diff --git a/wallet/walletdb/interface.go b/wallet/walletdb/interface.go index ebd3c896a..3ce057b5f 100644 --- a/wallet/walletdb/interface.go +++ b/wallet/walletdb/interface.go @@ -10,7 +10,6 @@ package walletdb import ( "context" - "io" "runtime/trace" "decred.org/dcrwallet/v5/errors" @@ -84,6 +83,8 @@ type ReadBucket interface { KeyN() int ReadCursor() ReadCursor + + ReverseReadCursor() ReadCursor } // ReadWriteBucket represents a bucket (a hierarchical structure within the @@ -120,33 +121,34 @@ type ReadWriteBucket interface { // attempted against a read-only transaction. Delete(key []byte) error - // Cursor returns a new cursor, allowing for iteration over the bucket's - // key/value pairs and nested buckets in forward or backward order. - // Only one cursor can be opened at a time and should be closed before - // committing or rolling back the transaction. + // ReadWriteCursor returns a new cursor, allowing for iteration over + // the bucket's key/value pairs and nested buckets in forward or + // backward order. Only one cursor can be opened at a time and should + // be closed before committing or rolling back the transaction. ReadWriteCursor() ReadWriteCursor + + // ReverseReadWriteCursor returns a ReadWriteCursor iterating in + // reverse order. + ReverseReadWriteCursor() ReadWriteCursor } // ReadCursor represents a bucket cursor that can be positioned at the start or // end of the bucket's key/value pairs and iterate over pairs in the bucket. // This type is only allowed to perform database read operations. +// +// If the cursor is reversed, all operations occur in the opposite order +// (First returns the last element, Next moves the cursor to the previous, and +// Seek may advance the cursor to the previous element if the seeked key does +// not exist). type ReadCursor interface { // First positions the cursor at the first key/value pair and returns // the pair. First() (key, value []byte) - // Last positions the cursor at the last key/value pair and returns the - // pair. - Last() (key, value []byte) - // Next moves the cursor one key/value pair forward and returns the new // pair. Next() (key, value []byte) - // Prev moves the cursor one key/value pair backward and returns the new - // pair. - Prev() (key, value []byte) - // Seek positions the cursor at the passed seek key. If the key does // not exist, the cursor is moved to the next key after seek. Returns // the new pair. @@ -188,10 +190,6 @@ type DB interface { // BeginReadWriteTx opens a database read+write transaction. BeginReadWriteTx() (ReadWriteTx, error) - // Copy writes a copy of the database to the provided writer. This - // call will start a read-only transaction to perform all operations. - Copy(w io.Writer) error - // Close cleanly shuts down the database and syncs all data. Close() error } diff --git a/walletsetup.go b/walletsetup.go index 20df37f25..05f9465de 100644 --- a/walletsetup.go +++ b/walletsetup.go @@ -19,7 +19,6 @@ import ( "decred.org/dcrwallet/v5/internal/loader" "decred.org/dcrwallet/v5/internal/prompt" "decred.org/dcrwallet/v5/wallet" - _ "decred.org/dcrwallet/v5/wallet/drivers/bdb" "decred.org/dcrwallet/v5/wallet/udb" "decred.org/dcrwallet/v5/walletseed" "github.com/decred/dcrd/chaincfg/v3" @@ -112,9 +111,9 @@ func displaySimnetMiningAddrs(seed []byte, imported bool) error { // to do the initial sync. func createWallet(ctx context.Context, cfg *config) error { dbDir := networkDir(cfg.AppDataDir.Value, activeNet.Params) - loader := loader.NewLoader(activeNet.Params, dbDir, cfg.EnableVoting, - cfg.GapLimit, cfg.WatchLast, cfg.AllowHighFees, cfg.RelayFee.Amount, - cfg.VSPOpts.MaxFee.Amount, cfg.AccountGapLimit, + loader := loader.NewLoader(activeNet.Params, dbDir, cfg.DBDriver, + cfg.EnableVoting, cfg.GapLimit, cfg.WatchLast, cfg.AllowHighFees, + cfg.RelayFee.Amount, cfg.VSPOpts.MaxFee.Amount, cfg.AccountGapLimit, cfg.DisableCoinTypeUpgrades, cfg.MixingEnabled, cfg.ManualTickets, cfg.MixSplitLimit, cfg.dial) @@ -295,7 +294,7 @@ func createSimulationWallet(ctx context.Context, cfg *config) error { fmt.Println("Creating the wallet...") // Create the wallet database backed by bolt db. - db, err := wallet.CreateDB("bdb", dbPath) + db, err := wallet.CreateDB(cfg.DBDriver, dbPath) if err != nil { return err } @@ -347,7 +346,7 @@ func createWatchingOnlyWallet(ctx context.Context, cfg *config) error { fmt.Println("Creating the wallet...") // Create the wallet database backed by bolt db. - db, err := wallet.CreateDB("bdb", dbPath) + db, err := wallet.CreateDB(cfg.DBDriver, dbPath) if err != nil { return err }