diff --git a/btcutil/psbt/creator.go b/btcutil/psbt/creator.go index 58b9a54488..a44b4d3baa 100644 --- a/btcutil/psbt/creator.go +++ b/btcutil/psbt/creator.go @@ -5,6 +5,8 @@ package psbt import ( + "errors" + "github.com/btcsuite/btcd/wire" ) @@ -50,6 +52,9 @@ func New(inputs []*wire.OutPoint, // two lists, and each one must be of length matching the unsigned // transaction; the unknown list can be nil. pInputs := make([]PInput, len(unsignedTx.TxIn)) + for i := range pInputs { + pInputs[i].Sequence = nSequences[i] + } pOutputs := make([]POutput, len(unsignedTx.TxOut)) // This new Psbt is "raw" and contains no key-value fields, so sanity @@ -61,3 +66,62 @@ func New(inputs []*wire.OutPoint, Unknowns: nil, }, nil } + +// NewV2 creates a new, empty Packet that is pre-configured to adhere to the +// BIP-0370 PSBT Version 2 specification. +func NewV2(txVersion uint32, fallbackLocktime uint32, txModifiable uint8) (*Packet, error) { + + if txVersion < 2 { + return nil, errors.New("PSBTv2 requires a transaction version of at least 2") + } + return &Packet{ + Version: 2, + TxVersion: txVersion, + FallbackLocktime: fallbackLocktime, + TxModifiable: txModifiable, + Inputs: nil, + Outputs: nil, + XPubs: nil, + Unknowns: nil, + }, nil +} + +// AddInputV2 appends a new PInput to a Version 2 PSBT, incrementing the +// internal count. It returns an error if the PSBT is not Version 2. +func (p *Packet) AddInputV2(input PInput) error { + if p.Version != 2 { + return errors.New("cannot dynamically add inputs to a non-v2 PSBT") + } + p.Inputs = append(p.Inputs, input) + p.InputCount = uint32(len(p.Inputs)) + return nil +} + +// AddOutputV2 appends a new POutput to a Version 2 PSBT, incrementing the +// internal count. It returns an error if the PSBT is not Version 2. +func (p *Packet) AddOutputV2(output POutput) error { + if p.Version != 2 { + return errors.New("cannot dynamically add outputs to a non-v2 PSBT") + } + p.Outputs = append(p.Outputs, output) + p.OutputCount = uint32(len(p.Outputs)) + return nil +} + +// AddInput adds a new input to a Version 2 PSBT using a standard wire.OutPoint. +func (p *Packet) AddInput(outPoint wire.OutPoint, sequence uint32) error { + return p.AddInputV2(PInput{ + PreviousTxid: outPoint.Hash[:], + OutputIndex: outPoint.Index, + Sequence: sequence, + }) +} + +// AddOutput adds a new output to a Version 2 PSBT using a standard amount and +// script. +func (p *Packet) AddOutput(amount int64, script []byte) error { + return p.AddOutputV2(POutput{ + Amount: amount, + Script: script, + }) +} diff --git a/btcutil/psbt/extractor.go b/btcutil/psbt/extractor.go index 365e2f1bba..dafe4e6b73 100644 --- a/btcutil/psbt/extractor.go +++ b/btcutil/psbt/extractor.go @@ -27,9 +27,12 @@ func Extract(p *Packet) (*wire.MsgTx, error) { return nil, ErrIncompletePSBT } - // First, we'll make a copy of the underlying unsigned transaction (the - // initial template) so we don't mutate it during our activates below. - finalTx := p.UnsignedTx.Copy() + // First, we'll get a fresh copy of the underlying unsigned transaction + // (the initial template) so we don't mutate it during our activates below. + finalTx, err := p.GetUnsignedTx() + if err != nil { + return nil, err + } // For each input, we'll now populate any relevant witness and // sigScript data. diff --git a/btcutil/psbt/finalizer.go b/btcutil/psbt/finalizer.go index b1bf12d131..c15f8276bd 100644 --- a/btcutil/psbt/finalizer.go +++ b/btcutil/psbt/finalizer.go @@ -111,7 +111,17 @@ func isFinalizableLegacyInput(p *Packet, pInput *PInput, inIndex int) bool { // Otherwise, we'll verify that we only have a RedeemScript if the prev // output script is P2SH. - outIndex := p.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index + var outIndex uint32 + switch p.Version { + case 0: + if p.UnsignedTx == nil { + return false + } + outIndex = p.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index + default: + outIndex = pInput.OutputIndex + } + if txscript.IsPayToScriptHash(pInput.NonWitnessUtxo.TxOut[outIndex].PkScript) { if pInput.RedeemScript == nil { return false @@ -186,7 +196,8 @@ func MaybeFinalize(p *Packet, inIndex int) (bool, error) { // MaybeFinalizeAll attempts to finalize all inputs of the psbt.Packet that are // not already finalized, and returns an error if it fails to do so. func MaybeFinalizeAll(p *Packet) error { - for i := range p.UnsignedTx.TxIn { + numInputs := len(p.Inputs) + for i := 0; i < numInputs; i++ { success, err := MaybeFinalize(p, i) if err != nil || !success { return err @@ -351,6 +362,9 @@ func finalizeNonWitnessInput(p *Packet, inIndex int) error { newInput := NewPsbtInput(pInput.NonWitnessUtxo, nil) newInput.FinalScriptSig = sigScript + // Preserve required PSBTv2 fields and unknowns as mandated by BIP-370 + newInput.CopyInputFields(&pInput) + // Overwrite the entry in the input list at the correct index. Note // that this removes all the other entries in the list for this input // index. @@ -493,6 +507,9 @@ func finalizeWitnessInput(p *Packet, inIndex int) error { newInput.FinalScriptWitness = serializedWitness + // Preserve required PSBTv2 fields and unknowns as mandated by BIP-370 + newInput.CopyInputFields(&pInput) + // Finally, we overwrite the entry in the input list at the correct // index. p.Inputs[inIndex] = *newInput @@ -590,6 +607,9 @@ func finalizeTaprootInput(p *Packet, inIndex int) error { newInput := NewPsbtInput(nil, pInput.WitnessUtxo) newInput.FinalScriptWitness = serializedWitness + // Preserve required PSBTv2 fields and unknowns as mandated by BIP-370 + newInput.CopyInputFields(pInput) + // Finally, we overwrite the entry in the input list at the correct // index. p.Inputs[inIndex] = *newInput diff --git a/btcutil/psbt/partial_input.go b/btcutil/psbt/partial_input.go index 73595d2513..06d632e758 100644 --- a/btcutil/psbt/partial_input.go +++ b/btcutil/psbt/partial_input.go @@ -29,6 +29,11 @@ type PInput struct { TaprootInternalKey []byte TaprootMerkleRoot []byte Unknowns []*Unknown + PreviousTxid []byte + OutputIndex uint32 + TimeLocktime uint32 + HeightLocktime uint32 + Sequence uint32 } // NewPsbtInput creates an instance of PsbtInput given either a nonWitnessUtxo @@ -63,8 +68,53 @@ func (pi *PInput) IsSane() bool { return true } +func (pi *PInput) addUnknown(keyCode byte, keyData, value []byte) error { + return addUnknownField(&pi.Unknowns, keyCode, keyData, value) +} + +// CopyInputFields copies all relevant input fields and unknowns from another PInput. +// This preserves PSBTv2 transaction fields and unknown fields that must be retained +// during finalization as mandated by BIP-370. +// +// For PSBTv0: Only unknowns and sequence are relevant (other fields are zero) +// For PSBTv2: All fields contain important data that must be preserved +// +// It performs a deep copy of the Unknowns slice to ensure the new PInput +// is memory-independent from the source. +// +// BIP-370 Requirement: +// +// "The PSBT_IN_PREVIOUS_TXID, PSBT_IN_OUTPUT_INDEX, PSBT_IN_SEQUENCE, +// PSBT_IN_REQUIRED_TIME_LOCKTIME, and PSBT_IN_REQUIRED_HEIGHT_LOCKTIME +// fields must be retained." +func (pi *PInput) CopyInputFields(from *PInput) { + pi.PreviousTxid = from.PreviousTxid + pi.OutputIndex = from.OutputIndex + pi.Sequence = from.Sequence + pi.TimeLocktime = from.TimeLocktime + pi.HeightLocktime = from.HeightLocktime + + // Deep copy Unknowns (applies to both v0 and v2) + if len(from.Unknowns) > 0 { + pi.Unknowns = make([]*Unknown, len(from.Unknowns)) + for i, u := range from.Unknowns { + pi.Unknowns[i] = &Unknown{ + Key: append([]byte(nil), u.Key...), + Value: append([]byte(nil), u.Value...), + } + } + } +} + // deserialize attempts to deserialize a new PInput from the passed io.Reader. -func (pi *PInput) deserialize(r io.Reader) error { +func (pi *PInput) deserialize(r io.Reader, version uint32) error { + var ( + outputIndexSeen bool + sequenceSeen bool + timeLocktimeSeen bool + heightLockSeen bool + ) + for { keyCode, keyData, err := getKey(r) if err != nil { @@ -80,7 +130,17 @@ func (pi *PInput) deserialize(r io.Reader) error { if err != nil { return err } - + if version == 0 { + switch InputType(keyCode) { + case PreviousTxidInputType, OutputIndexInputType, + SequenceInputType, TimeLocktimeInputType, + HeightLocktimeInputType: + if err := pi.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } + continue + } + } switch InputType(keyCode) { case NonWitnessUtxoType: @@ -363,34 +423,111 @@ func (pi *PInput) deserialize(r io.Reader) error { pi.TaprootMerkleRoot = value - default: - // A fall through case for any proprietary types. - keyCodeAndData := append( - []byte{byte(keyCode)}, keyData..., - ) - newUnknown := &Unknown{ - Key: keyCodeAndData, - Value: value, + case PreviousTxidInputType: + if pi.PreviousTxid != nil { + return ErrDuplicateKey + } + if keyData != nil { + if err := pi.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } + break + } + if len(value) != 32 { + return ErrInvalidKeyData } - // Duplicate key+keyData are not allowed. - for _, x := range pi.Unknowns { - if bytes.Equal(x.Key, newUnknown.Key) && - bytes.Equal(x.Value, newUnknown.Value) { + if bytes.Equal(value, make([]byte, 32)) { + return ErrInvalidKeyData + } - return ErrDuplicateKey + pi.PreviousTxid = value + + case OutputIndexInputType: + if keyData != nil { + if err := pi.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } + break + } + if outputIndexSeen { + return ErrDuplicateKey + } + if len(value) != 4 { + return ErrInvalidKeyData + } + pi.OutputIndex = binary.LittleEndian.Uint32(value) + outputIndexSeen = true + + case TimeLocktimeInputType: + if keyData != nil { + if err := pi.addUnknown(byte(keyCode), keyData, value); err != nil { + return err } + break + } + if timeLocktimeSeen { + return ErrDuplicateKey + } + if len(value) != 4 { + return ErrInvalidKeyData + } + timeLock := binary.LittleEndian.Uint32(value) + if timeLock < 500000000 { + return ErrInvalidKeyData + } + pi.TimeLocktime = timeLock + timeLocktimeSeen = true + + case HeightLocktimeInputType: + if keyData != nil { + if err := pi.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } + break + } + if heightLockSeen { + return ErrDuplicateKey + } + if len(value) != 4 { + return ErrInvalidKeyData + } + heightLock := binary.LittleEndian.Uint32(value) + if heightLock == 0 || heightLock >= 500000000 { + return ErrInvalidKeyData } + pi.HeightLocktime = heightLock + heightLockSeen = true - pi.Unknowns = append(pi.Unknowns, newUnknown) + case SequenceInputType: + if keyData != nil { + if err := pi.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } + break + } + if sequenceSeen { + return ErrDuplicateKey + } + if len(value) != 4 { + return ErrInvalidKeyData + } + pi.Sequence = binary.LittleEndian.Uint32(value) + sequenceSeen = true + + default: + if err := pi.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } } + } return nil } // serialize attempts to serialize the target PInput into the passed io.Writer. -func (pi *PInput) serialize(w io.Writer) error { +func (pi *PInput) serialize(w io.Writer, version uint32) error { if !pi.IsSane() { return ErrInvalidPsbtFormat } @@ -423,7 +560,6 @@ func (pi *PInput) serialize(w io.Writer) error { return err } } - if pi.FinalScriptSig == nil && pi.FinalScriptWitness == nil { sort.Sort(PartialSigSorter(pi.PartialSigs)) for _, ps := range pi.PartialSigs { @@ -483,7 +619,90 @@ func (pi *PInput) serialize(w io.Writer) error { return err } } + } + + if pi.FinalScriptSig != nil { + err := serializeKVPairWithType( + w, uint8(FinalScriptSigType), nil, pi.FinalScriptSig, + ) + if err != nil { + return err + } + } + + if pi.FinalScriptWitness != nil { + err := serializeKVPairWithType( + w, uint8(FinalScriptWitnessType), nil, pi.FinalScriptWitness, + ) + if err != nil { + return err + } + } + + // PSBTv2 fields (0x0e-0x12) are serialized here, between + // FinalScriptWitness (0x08) and Taproot fields (0x13+), to maintain + // the ascending key order required by BIP-174. + if version == 2 { + if pi.PreviousTxid != nil { + err := serializeKVPairWithType( + w, uint8(PreviousTxidInputType), nil, pi.PreviousTxid, + ) + if err != nil { + return err + } + } + + var outIndexByte [4]byte + binary.LittleEndian.PutUint32(outIndexByte[:], pi.OutputIndex) + err := serializeKVPairWithType( + w, uint8(OutputIndexInputType), nil, outIndexByte[:], + ) + if err != nil { + return err + } + + if pi.Sequence != wire.MaxTxInSequenceNum { + var seqBytes [4]byte + binary.LittleEndian.PutUint32(seqBytes[:], pi.Sequence) + err := serializeKVPairWithType( + w, uint8(SequenceInputType), nil, seqBytes[:], + ) + if err != nil { + return err + } + } + + if pi.TimeLocktime != 0 { + var timeLockBytes [4]byte + binary.LittleEndian.PutUint32( + timeLockBytes[:], pi.TimeLocktime, + ) + err := serializeKVPairWithType( + w, uint8(TimeLocktimeInputType), nil, + timeLockBytes[:], + ) + if err != nil { + return err + } + } + if pi.HeightLocktime != 0 { + var heightLockBytes [4]byte + binary.LittleEndian.PutUint32( + heightLockBytes[:], pi.HeightLocktime, + ) + err := serializeKVPairWithType( + w, uint8(HeightLocktimeInputType), nil, + heightLockBytes[:], + ) + if err != nil { + return err + } + } + } + + // Taproot fields (0x13-0x18) are only written for non-finalized inputs. + if pi.FinalScriptSig == nil && pi.FinalScriptWitness == nil { if pi.TaprootKeySpendSig != nil { err := serializeKVPairWithType( w, uint8(TaprootKeySpendSignatureType), nil, @@ -573,25 +792,6 @@ func (pi *PInput) serialize(w io.Writer) error { } } } - - if pi.FinalScriptSig != nil { - err := serializeKVPairWithType( - w, uint8(FinalScriptSigType), nil, pi.FinalScriptSig, - ) - if err != nil { - return err - } - } - - if pi.FinalScriptWitness != nil { - err := serializeKVPairWithType( - w, uint8(FinalScriptWitnessType), nil, pi.FinalScriptWitness, - ) - if err != nil { - return err - } - } - // Unknown is a special case; we don't have a key type, only a key and // a value field. for _, kv := range pi.Unknowns { diff --git a/btcutil/psbt/partial_output.go b/btcutil/psbt/partial_output.go index 86e476457d..64ecb14c52 100644 --- a/btcutil/psbt/partial_output.go +++ b/btcutil/psbt/partial_output.go @@ -2,6 +2,7 @@ package psbt import ( "bytes" + "encoding/binary" "io" "sort" @@ -18,6 +19,8 @@ type POutput struct { TaprootTapTree []byte TaprootBip32Derivation []*TaprootBip32Derivation Unknowns []*Unknown + Amount int64 + Script []byte } // NewPsbtOutput creates an instance of PsbtOutput; the three parameters @@ -32,8 +35,12 @@ func NewPsbtOutput(redeemScript []byte, witnessScript []byte, } } +func (po *POutput) addUnknown(keyCode byte, keyData, value []byte) error { + return addUnknownField(&po.Unknowns, keyCode, keyData, value) +} + // deserialize attempts to recode a new POutput from the passed io.Reader. -func (po *POutput) deserialize(r io.Reader) error { +func (po *POutput) deserialize(r io.Reader, version uint32) error { for { keyCode, keyData, err := getKey(r) if err != nil { @@ -50,7 +57,15 @@ func (po *POutput) deserialize(r io.Reader) error { if err != nil { return err } - + if version == 0 { + switch OutputType(keyCode) { + case AmountOutputType, ScriptOutputType: + if err := po.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } + continue + } + } switch OutputType(keyCode) { case RedeemScriptOutputType: @@ -143,28 +158,34 @@ func (po *POutput) deserialize(r io.Reader) error { po.TaprootBip32Derivation = append( po.TaprootBip32Derivation, taprootDerivation, ) - - default: - // A fall through case for any proprietary types. - keyCodeAndData := append( - []byte{byte(keyCode)}, keyData..., - ) - newUnknown := &Unknown{ - Key: keyCodeAndData, - Value: value, + case AmountOutputType: + if keyData != nil { + if err := po.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } + break } + if len(value) != 8 { + return ErrInvalidKeyData + } + // BIP-370: 64-bit signed little endian integer. + po.Amount = int64(binary.LittleEndian.Uint64(value)) - // Duplicate key+keyData are not allowed. - for _, x := range po.Unknowns { - if bytes.Equal(x.Key, newUnknown.Key) && - bytes.Equal(x.Value, newUnknown.Value) { - - return ErrDuplicateKey + case ScriptOutputType: + if keyData != nil { + if err := po.addUnknown(byte(keyCode), keyData, value); err != nil { + return err } + break } + po.Script = value - po.Unknowns = append(po.Unknowns, newUnknown) + default: + if err := po.addUnknown(byte(keyCode), keyData, value); err != nil { + return err + } } + } return nil @@ -172,7 +193,7 @@ func (po *POutput) deserialize(r io.Reader) error { // serialize attempts to write out the target POutput into the passed // io.Writer. -func (po *POutput) serialize(w io.Writer) error { +func (po *POutput) serialize(w io.Writer, version uint32) error { if po.RedeemScript != nil { err := serializeKVPairWithType( w, uint8(RedeemScriptOutputType), nil, po.RedeemScript, @@ -204,6 +225,21 @@ func (po *POutput) serialize(w io.Writer) error { return err } } + if version == 2 { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(po.Amount)) + err := serializeKVPairWithType(w, uint8(AmountOutputType), nil, buf[:]) + if err != nil { + return err + } + + if po.Script != nil { + err := serializeKVPairWithType(w, uint8(ScriptOutputType), nil, po.Script) + if err != nil { + return err + } + } + } if po.TaprootInternalKey != nil { err := serializeKVPairWithType( diff --git a/btcutil/psbt/psbt.go b/btcutil/psbt/psbt.go index 5249aad4e1..6fa8f6c313 100644 --- a/btcutil/psbt/psbt.go +++ b/btcutil/psbt/psbt.go @@ -10,10 +10,12 @@ package psbt import ( "bytes" "encoding/base64" + "encoding/binary" "errors" "io" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" ) @@ -144,6 +146,27 @@ type Packet struct { // Unknowns are the set of custom types (global only) within this PSBT. Unknowns []*Unknown + + // Version is the PSBT packet version (0 for BIP-174, 2 for BIP-370). + Version uint32 + + // FallbackLocktime is the transaction locktime to use if no input-specific + // locktime constraints exist (PSBTv2 only). + FallbackLocktime uint32 + + // InputCount is the number of inputs in this PSBT (PSBTv2 only). + InputCount uint32 + + // OutputCount is the number of outputs in this PSBT (PSBTv2 only). + OutputCount uint32 + + // TxVersion is the Bitcoin transaction version for the constructed + // transaction (PSBTv2 only). + TxVersion uint32 + + // TxModifiable is a bitfield indicating which parts of the transaction + // can be modified by subsequent updaters (PSBTv2 only). + TxModifiable uint8 } // validateUnsignedTx returns true if the transaction is unsigned. Note that @@ -159,6 +182,129 @@ func validateUnsignedTX(tx *wire.MsgTx) bool { return true } +// DetermineLockTime implements the BIP-370 "Determining Lock Time" algorithm. +// Per BIP-370: "the field chosen is the one which is supported by all of the inputs +// which specify a locktime in either of those fields." +// +// Algorithm: +// 1. If no inputs have locktime constraints, use PSBT_GLOBAL_FALLBACK_LOCKTIME +// 2. Find the locktime type that ALL constrained inputs can satisfy: +// - Inputs with only TimeLocktime can ONLY be satisfied by time-based locks +// - Inputs with only HeightLocktime can ONLY be satisfied by height-based locks +// - Inputs with both (or neither) can be satisfied by either type +// +// 3. If both types are supported, BIP-370 mandates selecting height-based +// 4. Return the maximum value of the selected type +// +// Reference: https://github.com/bitcoin/bips/blob/master/bip-0370.mediawiki#determining-lock-time +func (p *Packet) DetermineLockTime() (uint32, error) { + var maxTime, maxHeight uint32 + timeSupported := true + heightSupported := true + hasAnyInputLocktime := false + + for _, pIn := range p.Inputs { + hasTimeReq := pIn.TimeLocktime != 0 + hasHeightReq := pIn.HeightLocktime != 0 + + if !hasTimeReq && !hasHeightReq { + continue + } + hasAnyInputLocktime = true + + // Update maximums + if pIn.TimeLocktime > maxTime { + maxTime = pIn.TimeLocktime + } + if pIn.HeightLocktime > maxHeight { + maxHeight = pIn.HeightLocktime + } + + // An input with only PSBT_IN_REQUIRED_TIME_LOCKTIME (0x11) cannot + // be satisfied by a height-based lock; mark height as unsupported. + if hasTimeReq && !hasHeightReq { + heightSupported = false + } + // An input with only PSBT_IN_REQUIRED_HEIGHT_LOCKTIME (0x12) cannot + // be satisfied by a time-based lock; mark time as unsupported. + if hasHeightReq && !hasTimeReq { + timeSupported = false + } + } + + // 1. Fallback Case: No inputs specified constraints + if !hasAnyInputLocktime { + if p.FallbackLocktime != 0 { + return p.FallbackLocktime, nil + } + return 0, nil + } + + // 2. Conflict Case: One input requires Time, another requires Height + if !timeSupported && !heightSupported { + return 0, ErrInvalidPsbtFormat + } + + // 3. Selection Case: BIP-370 tie-breaker mandates height-based when both supported + // "If a PSBT has both types of locktimes possible... then locktime determined + // by looking at the PSBT_IN_REQUIRED_HEIGHT_LOCKTIME fields of the inputs must be chosen" + if heightSupported { + return maxHeight, nil + } + + // 4. Otherwise, use Time + return maxTime, nil +} + +// GetUnsignedTx returns a copy of the underlying unsigned transaction for this +// PSBT. For version 0 PSBTs, this is a copy of the parsed unsigned transaction. +// For version 2 PSBTs, it dynamically constructs the transaction from the +// individual parsing fields per BIP-0370. +func (p *Packet) GetUnsignedTx() (*wire.MsgTx, error) { + if p.Version == 0 { + if p.UnsignedTx == nil { + return nil, ErrInvalidPsbtFormat + } + return p.UnsignedTx.Copy(), nil + } + + if p.Version != 2 { + return nil, ErrInvalidPsbtFormat + } + + tx := wire.NewMsgTx(int32(p.TxVersion)) + + for _, pIn := range p.Inputs { + if pIn.PreviousTxid == nil { + return nil, ErrInvalidPsbtFormat + } + hash, err := chainhash.NewHash(pIn.PreviousTxid) + if err != nil { + return nil, err + } + + outPoint := wire.NewOutPoint(hash, pIn.OutputIndex) + txIn := wire.NewTxIn(outPoint, nil, nil) + txIn.Sequence = pIn.Sequence + + tx.AddTxIn(txIn) + + } + + for _, pOut := range p.Outputs { + txOut := wire.NewTxOut(pOut.Amount, pOut.Script) + tx.AddTxOut(txOut) + } + + lockTime, err := p.DetermineLockTime() + if err != nil { + return nil, err + } + tx.LockTime = lockTime + + return tx, nil +} + // NewFromUnsignedTx creates a new Psbt struct, without any signatures (i.e. // only the global section is non-empty) using the passed unsigned transaction. func NewFromUnsignedTx(tx *wire.MsgTx) (*Packet, error) { @@ -168,6 +314,12 @@ func NewFromUnsignedTx(tx *wire.MsgTx) (*Packet, error) { inSlice := make([]PInput, len(tx.TxIn)) outSlice := make([]POutput, len(tx.TxOut)) + for i, txin := range tx.TxIn { + inSlice[i].PreviousTxid = txin.PreviousOutPoint.Hash[:] + inSlice[i].OutputIndex = txin.PreviousOutPoint.Index + inSlice[i].Sequence = txin.Sequence + } + xPubSlice := make([]XPub, 0) unknownSlice := make([]*Unknown, 0) @@ -206,49 +358,30 @@ func NewFromRawBytes(r io.Reader, b64 bool) (*Packet, error) { return nil, ErrInvalidMagicBytes } - // Next we parse the GLOBAL section. There is currently only 1 known - // key type, UnsignedTx. We insist this exists first; unknowns are - // allowed, but only after. - keyCode, keyData, err := getKey(r) - if err != nil { - return nil, err - } - if GlobalType(keyCode) != UnsignedTxType || keyData != nil { - return nil, ErrInvalidPsbtFormat - } - - // Now that we've verified the global type is present, we'll decode it - // into a proper unsigned transaction, and validate it. - value, err := wire.ReadVarBytes( - r, 0, MaxPsbtValueLength, "PSBT value", - ) - if err != nil { - return nil, err - } - msgTx := wire.NewMsgTx(2) - - // BIP-0174 states: "The transaction must be in the old serialization - // format (without witnesses)." - err = msgTx.DeserializeNoWitness(bytes.NewReader(value)) - if err != nil { - return nil, err - } - if !validateUnsignedTX(msgTx) { - return nil, ErrInvalidRawTxSigned - } - - // Next we parse any unknowns that may be present, making sure that we - // break at the separator. + // Next we parse the GLOBAL section. var ( - xPubSlice []XPub - unknownSlice []*Unknown + xPubSlice []XPub + unknownSlice []*Unknown + msgTx *wire.MsgTx + version uint32 + fallbackLocktime uint32 + inputCount uint32 + outputCount uint32 + txVersion uint32 + txModifiable uint8 + txVersionSeen bool + inputCountSeen bool + outputCountSeen bool + fallbackLocktimeSeen bool + txModifiableSeen bool ) + for { - keyint, keydata, err := getKey(r) + keyCode, keyData, err := getKey(r) if err != nil { return nil, ErrInvalidPsbtFormat } - if keyint == -1 { + if keyCode == -1 { break } @@ -259,26 +392,124 @@ func NewFromRawBytes(r io.Reader, b64 bool) (*Packet, error) { return nil, err } - switch GlobalType(keyint) { - case XPubType: - xPub, err := ReadXPub(keydata, value) + isUnknown := false + + switch GlobalType(keyCode) { + case UnsignedTxType: + if keyData != nil { + isUnknown = true + break + } + if msgTx != nil { + return nil, ErrDuplicateKey + } + msgTx = wire.NewMsgTx(2) + err = msgTx.DeserializeNoWitness(bytes.NewReader(value)) if err != nil { return nil, err } + if !validateUnsignedTX(msgTx) { + return nil, ErrInvalidRawTxSigned + } - // Duplicate keys are not allowed + case XPubType: + xPub, err := ReadXPub(keyData, value) + if err != nil { + return nil, err + } for _, x := range xPubSlice { if bytes.Equal(x.ExtendedKey, keyData) { return nil, ErrDuplicateKey } } - xPubSlice = append(xPubSlice, *xPub) + case VersionType: + if !isSaneKey(keyData, &isUnknown) { + break + } + if version != 0 { + return nil, ErrDuplicateKey + } + if len(value) != 4 { + return nil, ErrInvalidKeyData + } + version = binary.LittleEndian.Uint32(value) + + case TxVersionGlobalType: + if !isSaneKey(keyData, &isUnknown) { + break + } + if txVersionSeen { + return nil, ErrDuplicateKey + } + if len(value) != 4 { + return nil, ErrInvalidKeyData + } + txVersion = binary.LittleEndian.Uint32(value) + txVersionSeen = true + + case FallbackLocktimeGlobalType: + if !isSaneKey(keyData, &isUnknown) { + break + } + if fallbackLocktimeSeen { + return nil, ErrDuplicateKey + } + if len(value) != 4 { + return nil, ErrInvalidKeyData + } + fallbackLocktime = binary.LittleEndian.Uint32(value) + fallbackLocktimeSeen = true + + case InputCountGlobalType: + if !isSaneKey(keyData, &isUnknown) { + break + } + if inputCountSeen { + return nil, ErrDuplicateKey + } + if len(value) > 8 { + return nil, ErrInvalidKeyData + } + num, _ := wire.ReadVarInt(bytes.NewReader(value), 0) + inputCount = uint32(num) + inputCountSeen = true + + case OutputCountGlobalType: + if !isSaneKey(keyData, &isUnknown) { + break + } + if outputCountSeen { + return nil, ErrDuplicateKey + } + if len(value) > 8 { + return nil, ErrInvalidKeyData + } + num, _ := wire.ReadVarInt(bytes.NewReader(value), 0) + outputCount = uint32(num) + outputCountSeen = true + + case TxModifiableGlobalType: + if !isSaneKey(keyData, &isUnknown) { + break + } + if txModifiableSeen { + return nil, ErrDuplicateKey + } + if len(value) != 1 { + return nil, ErrInvalidKeyData + } + txModifiable = value[0] + txModifiableSeen = true + default: - keyintanddata := []byte{byte(keyint)} - keyintanddata = append(keyintanddata, keydata...) + isUnknown = true + } + if isUnknown { + keyintanddata := []byte{byte(keyCode)} + keyintanddata = append(keyintanddata, keyData...) newUnknown := &Unknown{ Key: keyintanddata, Value: value, @@ -287,42 +518,79 @@ func NewFromRawBytes(r io.Reader, b64 bool) (*Packet, error) { } } + switch version { + case 0: + if msgTx == nil { + return nil, ErrInvalidPsbtFormat + } + + case 2: + if msgTx != nil { + return nil, ErrInvalidPsbtFormat + } + if !txVersionSeen || !inputCountSeen || !outputCountSeen { + return nil, ErrInvalidPsbtFormat + } + + default: + return nil, ErrInvalidPsbtFormat + } + + var inCount, outCount int + if version == 0 { + inCount = len(msgTx.TxIn) + outCount = len(msgTx.TxOut) + } else { + inCount = int(inputCount) + outCount = int(outputCount) + } + // Next we parse the INPUT section. - inSlice := make([]PInput, len(msgTx.TxIn)) - for i := range msgTx.TxIn { - input := PInput{} - err = input.deserialize(r) + inSlice := make([]PInput, inCount) + for i := 0; i < inCount; i++ { + input := PInput{Sequence: wire.MaxTxInSequenceNum} + + err := input.deserialize(r, version) if err != nil { return nil, err } - inSlice[i] = input } // Next we parse the OUTPUT section. - outSlice := make([]POutput, len(msgTx.TxOut)) - for i := range msgTx.TxOut { + outSlice := make([]POutput, outCount) + for i := 0; i < outCount; i++ { output := POutput{} - err = output.deserialize(r) + err := output.deserialize(r, version) if err != nil { return nil, err } - outSlice[i] = output } + if version == 2 && txVersion < 2 { + return nil, ErrInvalidPsbtFormat + } + // Populate the new Packet object. newPsbt := Packet{ - UnsignedTx: msgTx, - Inputs: inSlice, - Outputs: outSlice, - XPubs: xPubSlice, - Unknowns: unknownSlice, + UnsignedTx: msgTx, + Inputs: inSlice, + Outputs: outSlice, + XPubs: xPubSlice, + Unknowns: unknownSlice, + Version: version, + FallbackLocktime: fallbackLocktime, + InputCount: inputCount, + OutputCount: outputCount, + TxVersion: txVersion, + TxModifiable: txModifiable, } // Extended sanity checking is applied here to make sure the // externally-passed Packet follows all the rules. - if err = newPsbt.SanityCheck(); err != nil { + err := newPsbt.SanityCheck() + if err != nil { return nil, err } @@ -338,35 +606,115 @@ func (p *Packet) Serialize(w io.Writer) error { return err } - // Next we prep to write out the unsigned transaction by first - // serializing it into an intermediate buffer. - serializedTx := bytes.NewBuffer( - make([]byte, 0, p.UnsignedTx.SerializeSize()), - ) - if err := p.UnsignedTx.SerializeNoWitness(serializedTx); err != nil { - return err - } - - // Now that we have the serialized transaction, we'll write it out to - // the proper global type. - err := serializeKVPairWithType( - w, uint8(UnsignedTxType), nil, serializedTx.Bytes(), - ) - if err != nil { - return err - } - - // Serialize the global xPubs. - for _, xPub := range p.XPubs { - pathBytes := SerializeBIP32Derivation( - xPub.MasterKeyFingerprint, xPub.Bip32Path, + switch p.Version { + case 0: + // Next we prep to write out the unsigned transaction by first + // serializing it into an intermediate buffer. + if p.UnsignedTx == nil { + return ErrInvalidPsbtFormat + } + serializedTx := bytes.NewBuffer( + make([]byte, 0, p.UnsignedTx.SerializeSize()), ) + if err := p.UnsignedTx.SerializeNoWitness(serializedTx); err != nil { + return err + } + // Now that we have the serialized transaction, we'll write it out to + // the proper global type. + // Key 0x00: UnsignedTxType err := serializeKVPairWithType( - w, uint8(XPubType), xPub.ExtendedKey, pathBytes, + w, uint8(UnsignedTxType), nil, serializedTx.Bytes(), ) if err != nil { return err } + + // Serialize the global xPubs. + // Key 0x01: XPubType + for _, xPub := range p.XPubs { + pathBytes := SerializeBIP32Derivation( + xPub.MasterKeyFingerprint, xPub.Bip32Path, + ) + err := serializeKVPairWithType( + w, uint8(XPubType), xPub.ExtendedKey, pathBytes, + ) + if err != nil { + return err + } + } + + case 2: + // Serialize the global xPubs. + // Key 0x01: XPubType + for _, xPub := range p.XPubs { + pathBytes := SerializeBIP32Derivation( + xPub.MasterKeyFingerprint, xPub.Bip32Path, + ) + err := serializeKVPairWithType( + w, uint8(XPubType), xPub.ExtendedKey, pathBytes, + ) + if err != nil { + return err + } + } + + var buf [4]byte + + // Key 0x02: TxVersion + binary.LittleEndian.PutUint32(buf[:], p.TxVersion) + err := serializeKVPairWithType(w, uint8(TxVersionGlobalType), nil, buf[:]) + if err != nil { + return err + } + + // Key 0x03: FallbackLocktime + if p.FallbackLocktime != 0 { + binary.LittleEndian.PutUint32(buf[:], p.FallbackLocktime) + err = serializeKVPairWithType(w, uint8(FallbackLocktimeGlobalType), nil, buf[:]) + if err != nil { + return err + } + } + + // Key 0x04: InputCount + // Input and Output counts are compact size uints + var countBuf bytes.Buffer + err = wire.WriteVarInt(&countBuf, 0, uint64(p.InputCount)) + if err != nil { + return err + } + err = serializeKVPairWithType(w, uint8(InputCountGlobalType), nil, countBuf.Bytes()) + if err != nil { + return err + } + + // Key 0x05: OutputCount + countBuf.Reset() + err = wire.WriteVarInt(&countBuf, 0, uint64(p.OutputCount)) + if err != nil { + return err + } + err = serializeKVPairWithType(w, uint8(OutputCountGlobalType), nil, countBuf.Bytes()) + if err != nil { + return err + } + + // Key 0x06: TxModifiable + if p.TxModifiable != 0 { + err = serializeKVPairWithType(w, uint8(TxModifiableGlobalType), nil, []byte{p.TxModifiable}) + if err != nil { + return err + } + } + + // Key 0xfb: PSBT Version + binary.LittleEndian.PutUint32(buf[:], 2) + if err := serializeKVPairWithType(w, uint8(VersionType), nil, buf[:]); err != nil { + return err + } + + default: + return ErrInvalidPsbtFormat } // Unknown is a special case; we don't have a key type, only a key and @@ -386,7 +734,7 @@ func (p *Packet) Serialize(w io.Writer) error { } for _, pInput := range p.Inputs { - err := pInput.serialize(w) + err := pInput.serialize(w, p.Version) if err != nil { return err } @@ -397,7 +745,7 @@ func (p *Packet) Serialize(w io.Writer) error { } for _, pOutput := range p.Outputs { - err := pOutput.serialize(w) + err := pOutput.serialize(w, p.Version) if err != nil { return err } @@ -426,7 +774,7 @@ func (p *Packet) B64Encode() (string, error) { // whether the final extraction to a network serialized signed // transaction will be possible. func (p *Packet) IsComplete() bool { - for i := 0; i < len(p.UnsignedTx.TxIn); i++ { + for i := 0; i < len(p.Inputs); i++ { if !isFinalized(p, i) { return false } @@ -434,11 +782,23 @@ func (p *Packet) IsComplete() bool { return true } -// SanityCheck checks conditions on a PSBT to ensure that it obeys the -// rules of BIP174, and returns true if so, false if not. +// SanityCheck checks conditions on a PSBT to ensure that it obeys the rules of +// BIP174 and BIP0370, and returns an error if not. func (p *Packet) SanityCheck() error { - if !validateUnsignedTX(p.UnsignedTx) { - return ErrInvalidRawTxSigned + switch p.Version { + case 0: + if p.UnsignedTx == nil { + return ErrInvalidPsbtFormat + } + if !validateUnsignedTX(p.UnsignedTx) { + return ErrInvalidRawTxSigned + } + case 2: + if p.UnsignedTx != nil { + return ErrInvalidPsbtFormat + } + default: + return ErrInvalidPsbtFormat } for _, tin := range p.Inputs { @@ -459,10 +819,50 @@ func (p *Packet) GetTxFee() (btcutil.Amount, error) { } var sumOutputs int64 - for _, txOut := range p.UnsignedTx.TxOut { - sumOutputs += txOut.Value + for _, pout := range p.Outputs { + sumOutputs += pout.Amount + } + // fallback for v0 PSBTs, which don't have explicit input UTXO information; + if p.Version < 2 { + sumOutputs = 0 + for _, txOut := range p.UnsignedTx.TxOut { + sumOutputs += txOut.Value + } } fee := sumInputs - sumOutputs return btcutil.Amount(fee), nil } + +// isSaneKey is a helper function that checks if a key that is expected to have +// no extra key data actually has some. If it does, it marks the key as unknown +// so that it can be processed as an unknown field rather than causing a +// validation error. +func isSaneKey(keyData []byte, isUnknown *bool) bool { + if keyData != nil { + *isUnknown = true + return false + } + return true +} + +// addUnknownField adds an unknown key-value pair to a slice after checking for duplicates. +// This function is used by both PInput and POutput to handle unknown fields consistently. +func addUnknownField(unknowns *[]*Unknown, keyCode byte, keyData, value []byte) error { + keyCodeAndData := append([]byte{keyCode}, keyData...) + newUnknown := &Unknown{ + Key: keyCodeAndData, + Value: value, + } + + // Duplicate key+keyData combinations are not allowed (per PSBT spec) + for _, x := range *unknowns { + if bytes.Equal(x.Key, newUnknown.Key) && + bytes.Equal(x.Value, newUnknown.Value) { + return ErrDuplicateKey + } + } + + *unknowns = append(*unknowns, newUnknown) + return nil +} diff --git a/btcutil/psbt/psbt_test.go b/btcutil/psbt/psbt_test.go index 0dfa44c56c..13bd73970d 100644 --- a/btcutil/psbt/psbt_test.go +++ b/btcutil/psbt/psbt_test.go @@ -1690,3 +1690,248 @@ func TestUnknowns(t *testing.T) { require.NoError(t, err) require.Equal(t, packetWithUnknowns, encoded) } + +// TestPsbtV2LifeCycle ensures that the full lifecycle of a PSBTv2 (creating, +// constructing, serializing, and extracting) works as expected. +func TestPsbtV2LifeCycle(t *testing.T) { + + // 1. Create a new V2 PSBT. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + // 2. Add an input with a specific sequence (e.g., 0 for RBF). + txid, _ := chainhash.NewHashFromStr("0102030405060708091011121314151617181920212223242526272829303132") + + outPoint := wire.NewOutPoint(txid, 1) + err = p.AddInput(*outPoint, 0) + require.NoError(t, err) + + // 3. Add an output. + script, _ := hex.DecodeString("76a914b6bc2c0ee5655a843d79afedd0ccc3f7dd64340988ac") + err = p.AddOutput(100000000, script) + require.NoError(t, err) + + // 4. Serialize and Parse back. + var b bytes.Buffer + err = p.Serialize(&b) + require.NoError(t, err) + + p2, err := NewFromRawBytes(&b, false) + require.NoError(t, err) + + // 5. Verify fields survived the round-trip. + require.Equal(t, uint32(2), p2.TxVersion) + require.Equal(t, uint32(1), p2.InputCount) + require.Equal(t, txid[:], p2.Inputs[0].PreviousTxid) + require.Equal(t, uint32(0), p2.Inputs[0].Sequence) + require.Equal(t, int64(100000000), p2.Outputs[0].Amount) + + // 6. Extract the transaction and verify it works. + // (Note: For extraction to work, we need to bypass the IsComplete check + // or finalize it. Since we are testing construction, we can test + // GetUnsignedTx directly). + msgTx, err := p2.GetUnsignedTx() + require.NoError(t, err) + require.Equal(t, int32(2), msgTx.Version) + require.Equal(t, uint32(0), msgTx.TxIn[0].Sequence) + require.Equal(t, int64(100000000), msgTx.TxOut[0].Value) +} + +// TestPsbtV2Validation verifies that PSBTv2 packets are validated correctly +// for strict versioning rules and mandatory field combinations. +func TestPsbtV2Validation(t *testing.T) { + t.Run("V2 cannot have global UnsignedTx", func(t *testing.T) { + // Construct raw bytes with both Version 2 AND UnsignedTx (0x00 forbidden in V2). + // Magic + Version (0xfb: 2) + UnsignedTx (0x00: minimal 1-byte tx) + raw := []byte{ + 0x70, 0x73, 0x62, 0x74, 0xff, // Magic + 0x01, 0xfb, 0x04, 0x02, 0x00, 0x00, 0x00, // Version 2 + 0x01, 0x00, 0x01, 0x01, // UnsignedTx (keyCode 0x00, minimal value) + 0x00, // Separator + } + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err) + }) + + t.Run("V2 must have InputCount and OutputCount", func(t *testing.T) { + // Create a raw V2 serialization but manually omit counts. + // Magic + Version (0xfb: 2) + TxVersion (0x02: 2) + raw := []byte{ + 0x70, 0x73, 0x62, 0x74, 0xff, // Magic + 0x01, 0xfb, 0x04, 0x02, 0x00, 0x00, 0x00, // Version 2 + 0x01, 0x02, 0x04, 0x02, 0x00, 0x00, 0x00, // TxVersion 2 + 0x00, // Separator + } + + // parsing should fail because InputCount/OutputCount are missing for V2. + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err) + }) + + t.Run("Unsupported version should fail", func(t *testing.T) { + // Version 3 (unsupported) + raw := []byte{ + 0x70, 0x73, 0x62, 0x74, 0xff, // Magic + 0x01, 0xfb, 0x04, 0x03, 0x00, 0x00, 0x00, // Version 3 + 0x00, // Separator + } + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err) + }) +} + +// TestPsbtV2Counts ensures that the number of inputs and outputs parsed +// matches the InputCount and OutputCount global fields. +func TestPsbtV2Counts(t *testing.T) { + // Create a V2 PSBT that claims 2 inputs but only provides 1. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid, _ := chainhash.NewHashFromStr("0102030405060708091011121314151617181920212223242526272829303132") + outPoint := wire.NewOutPoint(txid, 1) + err = p.AddInput(*outPoint, 0xffffffff) + require.NoError(t, err) + + script, _ := hex.DecodeString("00") + err = p.AddOutput(1000, script) + require.NoError(t, err) + + // Manually override counts to mismatch reality. + p.InputCount = 2 + p.OutputCount = 1 + + var b bytes.Buffer + err = p.Serialize(&b) + require.NoError(t, err) + + // Parsing should fail because we promised 2 inputs but only 1 followed. + _, err = NewFromRawBytes(&b, false) + require.Error(t, err) +} + +// TestPsbtV2Locktimes verifies that PSBTv2 locktime fields are correctly handled. +func TestPsbtV2Locktimes(t *testing.T) { + // Create a V2 PSBT with a fallback locktime. + p, err := NewV2(2, 500000, 0x01) // Fallback 500000, InputsModifiable + require.NoError(t, err) + + txid, _ := chainhash.NewHashFromStr("0102030405060708091011121314151617181920212223242526272829303132") + outPoint := wire.NewOutPoint(txid, 1) + err = p.AddInput(*outPoint, 0xffffffff) + require.NoError(t, err) + + p.Inputs[0].HeightLocktime = 600000 + + msgTx, err := p.GetUnsignedTx() + require.NoError(t, err) + + // Since we have a height locktime in an input, it should take precedence. + require.Equal(t, uint32(600000), msgTx.LockTime) + + // Now try with time locktime. + p.Inputs[0].HeightLocktime = 0 + p.Inputs[0].TimeLocktime = 1600000000 + msgTx, err = p.GetUnsignedTx() + require.NoError(t, err) + require.Equal(t, uint32(1600000000), msgTx.LockTime) + + // Test combined locktimes (BIP suggests highest value for same type, + // but here we just check our extraction logic). + err = p.AddInput(*outPoint, 0xffffffff) + require.NoError(t, err) + p.Inputs[1].TimeLocktime = 1700000000 + + msgTx, err = p.GetUnsignedTx() + require.NoError(t, err) + require.Equal(t, uint32(1700000000), msgTx.LockTime) + + // Test modifiability flag. + require.Equal(t, uint8(0x01), p.TxModifiable) +} + +// TestPSBTv2DetermineLockTimeAlgorithm tests the comprehensive BIP-370 lock time determination algorithm +func TestPSBTv2DetermineLockTimeAlgorithm(t *testing.T) { + t.Run("Height-based preference when both supported (BIP-370 tie-breaker)", func(t *testing.T) { + p, err := NewV2(2, 100000, 0) + require.NoError(t, err) + + txid, _ := chainhash.NewHashFromStr("1111111111111111111111111111111111111111111111111111111111111111") + + // Input 1: Both time and height (flexible) + p.AddInput(*wire.NewOutPoint(txid, 0), 0) + p.Inputs[0].TimeLocktime = 1600000000 + p.Inputs[0].HeightLocktime = 550000 + + // Input 2: Both time and height (flexible) + p.AddInput(*wire.NewOutPoint(txid, 1), 0) + p.Inputs[1].TimeLocktime = 1650000000 + p.Inputs[1].HeightLocktime = 600000 + + // BIP-370: "height-based must be chosen" when both supported + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(600000), lockTime) // Max height, NOT max time + }) + + t.Run("Conflicting requirements should error", func(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid, _ := chainhash.NewHashFromStr("2222222222222222222222222222222222222222222222222222222222222222") + + // Input 1: Time-only (cannot satisfy height) + p.AddInput(*wire.NewOutPoint(txid, 0), 0) + p.Inputs[0].TimeLocktime = 1600000000 + + // Input 2: Height-only (cannot satisfy time) + p.AddInput(*wire.NewOutPoint(txid, 1), 0) + p.Inputs[1].HeightLocktime = 500000 + + // Should fail - conflicting requirements + _, err = p.DetermineLockTime() + require.Error(t, err) + require.Equal(t, ErrInvalidPsbtFormat, err) + }) + + t.Run("Fallback locktime when no constraints", func(t *testing.T) { + fallback := uint32(123456) + p, err := NewV2(2, fallback, 0) + require.NoError(t, err) + + txid, _ := chainhash.NewHashFromStr("3333333333333333333333333333333333333333333333333333333333333333") + p.AddInput(*wire.NewOutPoint(txid, 0), 0) + // No TimeLocktime or HeightLocktime set + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, fallback, lockTime) + }) +} + +// TestPSBTv2AddUnknownFields tests the addUnknown field handling +func TestPSBTv2AddUnknownFields(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid, _ := chainhash.NewHashFromStr("4444444444444444444444444444444444444444444444444444444444444444") + p.AddInput(*wire.NewOutPoint(txid, 0), 0) + + // Test adding unknown field succeeds + err = p.Inputs[0].addUnknown(0xfc, []byte{0x01, 0x02}, []byte{0x03, 0x04}) + require.NoError(t, err) + require.Len(t, p.Inputs[0].Unknowns, 1) + + // Test duplicate detection + err = p.Inputs[0].addUnknown(0xfc, []byte{0x01, 0x02}, []byte{0x03, 0x04}) + require.Error(t, err) + require.Equal(t, ErrDuplicateKey, err) + + p.AddOutput(1000000, []byte{0x76, 0xa9, 0x14}) + + // Test output unknown fields + err = p.Outputs[0].addUnknown(0xfd, []byte{0x05}, []byte{0x06}) + require.NoError(t, err) + require.Len(t, p.Outputs[0].Unknowns, 1) +} diff --git a/btcutil/psbt/psbt_v2_test.go b/btcutil/psbt/psbt_v2_test.go new file mode 100644 index 0000000000..53abf42c9a --- /dev/null +++ b/btcutil/psbt/psbt_v2_test.go @@ -0,0 +1,1434 @@ +package psbt + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/require" +) + +// testTxid returns a deterministic txid for test use. +func testTxid(fill byte) *chainhash.Hash { + var h chainhash.Hash + for i := range h { + h[i] = fill + } + return &h +} + +// serializeV2Global is a helper that builds a raw v2 global section (after the +// magic bytes) from explicitly provided key-value pairs. Each pair is a key +// byte slice and a value byte slice, serialized per the PSBT wire format. +// A separator (0x00) is appended at the end. This allows tests to construct +// intentionally malformed PSBTs. +func serializeV2Global(t *testing.T, pairs ...[]byte) []byte { + t.Helper() + + require.True(t, len(pairs)%2 == 0, "pairs must be key, value, ...") + + var buf bytes.Buffer + // Magic bytes. + buf.Write(psbtMagic[:]) + + for i := 0; i < len(pairs); i += 2 { + key := pairs[i] + value := pairs[i+1] + // Write key length + key. + wire.WriteVarInt(&buf, 0, uint64(len(key))) + buf.Write(key) + // Write value length + value. + wire.WriteVarInt(&buf, 0, uint64(len(value))) + buf.Write(value) + } + + // Separator. + buf.WriteByte(0x00) + return buf.Bytes() +} + +// uint32LE returns a 4-byte little-endian encoding of v. +func uint32LE(v uint32) []byte { + b := make([]byte, 4) + binary.LittleEndian.PutUint32(b, v) + return b +} + +// compactSizeUint returns the compact-size encoding of v. +func compactSizeUint(v uint64) []byte { + var buf bytes.Buffer + wire.WriteVarInt(&buf, 0, v) + return buf.Bytes() +} + +// ========================================================================== +// 1. Creation & Round-Trip Tests +// ========================================================================== + +func TestV2CreateEmptyPSBT(t *testing.T) { + // Create a v2 PSBT with 0 inputs and 0 outputs. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + require.Equal(t, uint32(2), p.Version) + require.Equal(t, uint32(2), p.TxVersion) + require.Equal(t, uint32(0), p.InputCount) + require.Equal(t, uint32(0), p.OutputCount) + + // Round-trip serialize and parse. + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + require.Equal(t, uint32(2), p2.Version) + require.Equal(t, uint32(0), p2.InputCount) + require.Equal(t, uint32(0), p2.OutputCount) + require.Nil(t, p2.UnsignedTx) +} + +func TestV2RoundTripAllFields(t *testing.T) { + // Create a v2 PSBT with all fields populated. + p, err := NewV2(2, 700000, 0x03) + require.NoError(t, err) + + txid := testTxid(0xAA) + require.NoError(t, p.AddInput( + *wire.NewOutPoint(txid, 5), 0xFFFFFFFE, + )) + p.Inputs[0].TimeLocktime = 1600000000 + p.Inputs[0].HeightLocktime = 400000 + + script := []byte{0x00, 0x14, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14} + require.NoError(t, p.AddOutput(50000000, script)) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + + // Global fields. + require.Equal(t, uint32(2), p2.Version) + require.Equal(t, uint32(2), p2.TxVersion) + require.Equal(t, uint32(700000), p2.FallbackLocktime) + require.Equal(t, uint8(0x03), p2.TxModifiable) + require.Equal(t, uint32(1), p2.InputCount) + require.Equal(t, uint32(1), p2.OutputCount) + + // Input fields. + require.Equal(t, txid[:], p2.Inputs[0].PreviousTxid) + require.Equal(t, uint32(5), p2.Inputs[0].OutputIndex) + require.Equal(t, uint32(0xFFFFFFFE), p2.Inputs[0].Sequence) + require.Equal(t, uint32(1600000000), p2.Inputs[0].TimeLocktime) + require.Equal(t, uint32(400000), p2.Inputs[0].HeightLocktime) + + // Output fields. + require.Equal(t, int64(50000000), p2.Outputs[0].Amount) + require.Equal(t, script, p2.Outputs[0].Script) +} + +func TestV2RoundTripBase64(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0xBB) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), wire.MaxTxInSequenceNum)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + encoded, err := p.B64Encode() + require.NoError(t, err) + + p2, err := NewFromRawBytes(bytes.NewReader([]byte(encoded)), true) + require.NoError(t, err) + require.Equal(t, uint32(2), p2.Version) + require.Equal(t, uint32(1), p2.InputCount) + require.Equal(t, uint32(1), p2.OutputCount) + require.Equal(t, txid[:], p2.Inputs[0].PreviousTxid) + require.Equal(t, int64(1000), p2.Outputs[0].Amount) +} + +func TestV2SequenceDefaultNotSerialized(t *testing.T) { + // When sequence equals MaxTxInSequenceNum (the default), it should NOT + // be serialized per BIP-370. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0xCC) + // Use the default sequence. + require.NoError(t, p.AddInput( + *wire.NewOutPoint(txid, 0), wire.MaxTxInSequenceNum, + )) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + // The serialized form should NOT contain the Sequence key (0x10). + // We search the input section for the 0x10 key type. + serialized := buf.Bytes() + // We can verify by round-tripping and checking the default is restored. + p2, err := NewFromRawBytes(bytes.NewReader(serialized), false) + require.NoError(t, err) + require.Equal(t, wire.MaxTxInSequenceNum, p2.Inputs[0].Sequence) + + // Now use a non-default sequence. + p3, err := NewV2(2, 0, 0) + require.NoError(t, err) + require.NoError(t, p3.AddInput( + *wire.NewOutPoint(txid, 0), 0, + )) + require.NoError(t, p3.AddOutput(1000, []byte{0x51})) + + var buf2 bytes.Buffer + require.NoError(t, p3.Serialize(&buf2)) + + p4, err := NewFromRawBytes(&buf2, false) + require.NoError(t, err) + require.Equal(t, uint32(0), p4.Inputs[0].Sequence) +} + +func TestV2MultipleInputsOutputs(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + // Add 3 inputs and 2 outputs. + for i := byte(1); i <= 3; i++ { + txid := testTxid(i) + require.NoError(t, p.AddInput( + *wire.NewOutPoint(txid, uint32(i)), + wire.MaxTxInSequenceNum, + )) + } + require.NoError(t, p.AddOutput(10000, []byte{0x51})) + require.NoError(t, p.AddOutput(20000, []byte{0x00, 0x14, 0xaa})) + + require.Equal(t, uint32(3), p.InputCount) + require.Equal(t, uint32(2), p.OutputCount) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + require.Len(t, p2.Inputs, 3) + require.Len(t, p2.Outputs, 2) + require.Equal(t, uint32(3), p2.InputCount) + require.Equal(t, uint32(2), p2.OutputCount) + + for i := byte(1); i <= 3; i++ { + require.Equal(t, testTxid(i)[:], p2.Inputs[i-1].PreviousTxid) + require.Equal(t, uint32(i), p2.Inputs[i-1].OutputIndex) + } + require.Equal(t, int64(10000), p2.Outputs[0].Amount) + require.Equal(t, int64(20000), p2.Outputs[1].Amount) +} + +// ========================================================================== +// 2. Version Validation Tests +// ========================================================================== + +func TestV2CannotHaveUnsignedTx(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + // Force an UnsignedTx onto a v2 PSBT. + p.UnsignedTx = wire.NewMsgTx(2) + require.Error(t, p.SanityCheck()) +} + +func TestV2RequiredGlobalFields(t *testing.T) { + // A v2 PSBT without TxVersion should fail. + raw := serializeV2Global(t, + // Version = 2 + []byte{0xfb}, uint32LE(2), + // InputCount = 0 + []byte{0x04}, compactSizeUint(0), + // OutputCount = 0 + []byte{0x05}, compactSizeUint(0), + // Missing TxVersion (0x02)! + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "should fail without TxVersion") +} + +func TestV2RejectsVersion1(t *testing.T) { + // Version 1 is explicitly skipped per BIP-370. + raw := serializeV2Global(t, + []byte{0xfb}, uint32LE(1), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err) +} + +func TestV2RejectsVersion3(t *testing.T) { + raw := serializeV2Global(t, + []byte{0xfb}, uint32LE(3), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err) +} + +func TestV2AddInputToV0Fails(t *testing.T) { + // Create a v0 PSBT. + tx := wire.NewMsgTx(2) + tx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: *wire.NewOutPoint(testTxid(0x01), 0), + Sequence: wire.MaxTxInSequenceNum, + }) + tx.AddTxOut(wire.NewTxOut(1000, []byte{0x51})) + + p, err := NewFromUnsignedTx(tx) + require.NoError(t, err) + require.Equal(t, uint32(0), p.Version) + + // Adding input to v0 should fail. + err = p.AddInputV2(PInput{ + PreviousTxid: testTxid(0x02)[:], + OutputIndex: 0, + Sequence: wire.MaxTxInSequenceNum, + }) + require.Error(t, err) +} + +// ========================================================================== +// 3. Lock Time Algorithm Tests +// ========================================================================== + +func TestV2LockTimeFallbackDefault(t *testing.T) { + // No inputs, no fallback → locktime 0. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(0), lockTime) +} + +func TestV2LockTimeFallbackExplicit(t *testing.T) { + // No input locktime constraints → use fallback. + p, err := NewV2(2, 654321, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + // No TimeLocktime or HeightLocktime set on the input. + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(654321), lockTime) +} + +func TestV2LockTimeHeightOnly(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + p.Inputs[0].HeightLocktime = 300000 + + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 1), 0)) + p.Inputs[1].HeightLocktime = 400000 + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(400000), lockTime) // Max of heights. +} + +func TestV2LockTimeTimeOnly(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + p.Inputs[0].TimeLocktime = 1600000000 + + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 1), 0)) + p.Inputs[1].TimeLocktime = 1700000000 + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(1700000000), lockTime) // Max of times. +} + +func TestV2LockTimeBothSupportedPrefersHeight(t *testing.T) { + // BIP-370: When both types are supported, height must be chosen. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + // Input with both types → supports either. + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + p.Inputs[0].TimeLocktime = 1600000000 + p.Inputs[0].HeightLocktime = 300000 + + // Input with both types → supports either. + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 1), 0)) + p.Inputs[1].TimeLocktime = 1700000000 + p.Inputs[1].HeightLocktime = 400000 + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(400000), lockTime) // Height preferred. +} + +func TestV2LockTimeConflictErrors(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + // Input 1: Time-only → cannot satisfy height. + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + p.Inputs[0].TimeLocktime = 1600000000 + + // Input 2: Height-only → cannot satisfy time. + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 1), 0)) + p.Inputs[1].HeightLocktime = 300000 + + _, err = p.DetermineLockTime() + require.Error(t, err) + require.Equal(t, ErrInvalidPsbtFormat, err) +} + +func TestV2LockTimeMixedFlexibleAndFixed(t *testing.T) { + // One input requires time-only, another supports both → time wins. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + // Input 1: Time-only (forces time). + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + p.Inputs[0].TimeLocktime = 1600000000 + + // Input 2: Both → flexible. + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 1), 0)) + p.Inputs[1].TimeLocktime = 1700000000 + p.Inputs[1].HeightLocktime = 500000 + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(1700000000), lockTime) +} + +func TestV2LockTimeUnconstrainedInputsIgnored(t *testing.T) { + // Unconstrained inputs (no locktime fields) don't affect the choice. + p, err := NewV2(2, 99, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + // Input 1: No locktime → unconstrained. + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + + // Input 2: Height-only. + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 1), 0)) + p.Inputs[1].HeightLocktime = 400000 + + // Input 3: No locktime → unconstrained. + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 2), 0)) + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(400000), lockTime) // Not fallback. +} + +// ========================================================================== +// 4. GetUnsignedTx Tests +// ========================================================================== + +func TestV2GetUnsignedTx(t *testing.T) { + p, err := NewV2(2, 500000, 0) + require.NoError(t, err) + + txid := testTxid(0xDD) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 3), 42)) + require.NoError(t, p.AddOutput(100000, []byte{0x51})) + + // Set a height locktime. + p.Inputs[0].HeightLocktime = 600000 + + tx, err := p.GetUnsignedTx() + require.NoError(t, err) + + require.Equal(t, int32(2), tx.Version) + require.Len(t, tx.TxIn, 1) + require.Len(t, tx.TxOut, 1) + require.Equal(t, txid[:], tx.TxIn[0].PreviousOutPoint.Hash[:]) + require.Equal(t, uint32(3), tx.TxIn[0].PreviousOutPoint.Index) + require.Equal(t, uint32(42), tx.TxIn[0].Sequence) + require.Equal(t, int64(100000), tx.TxOut[0].Value) + require.Equal(t, uint32(600000), tx.LockTime) +} + +func TestV2GetUnsignedTxDoesNotMutate(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0xEE) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + tx1, err := p.GetUnsignedTx() + require.NoError(t, err) + + tx2, err := p.GetUnsignedTx() + require.NoError(t, err) + + // Mutating one should not affect the other. + tx1.TxIn[0].Sequence = 999 + require.NotEqual(t, tx1.TxIn[0].Sequence, tx2.TxIn[0].Sequence) +} + +func TestV0GetUnsignedTxStillWorks(t *testing.T) { + // Ensure the v0 path is not broken. + tx := wire.NewMsgTx(2) + tx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: *wire.NewOutPoint(testTxid(0x01), 0), + Sequence: wire.MaxTxInSequenceNum, + }) + tx.AddTxOut(wire.NewTxOut(1000, []byte{0x51})) + + p, err := NewFromUnsignedTx(tx) + require.NoError(t, err) + + tx2, err := p.GetUnsignedTx() + require.NoError(t, err) + require.Equal(t, tx.TxIn[0].PreviousOutPoint, tx2.TxIn[0].PreviousOutPoint) + require.Equal(t, tx.TxOut[0].Value, tx2.TxOut[0].Value) +} + +// ========================================================================== +// 5. Locktime Value Validation Tests +// ========================================================================== + +func TestV2TimeLocktimeMustBeGTE500M(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + // Set an invalid time locktime (< 500M). + p.Inputs[0].TimeLocktime = 499999999 + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + // Parsing should reject the invalid value. + _, err = NewFromRawBytes(&buf, false) + require.Error(t, err, "time locktime < 500000000 must be rejected") +} + +func TestV2TimeLocktimeBoundary(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + // Exactly 500M should be valid. + p.Inputs[0].TimeLocktime = 500000000 + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + require.Equal(t, uint32(500000000), p2.Inputs[0].TimeLocktime) +} + +func TestV2HeightLocktimeMustBeGTZeroAndLT500M(t *testing.T) { + tests := []struct { + name string + height uint32 + }{ + {name: "zero", height: 0}, + {name: "exactly 500M", height: 500000000}, + {name: "above 500M", height: 600000000}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Build raw PSBT with explicit HeightLocktime value + // to bypass the serialization skip for zero values. + txid := testTxid(0x01) + + raw := serializeV2WithInputKVPairs(t, + []byte{0x0e}, txid[:], + []byte{0x0f}, uint32LE(0), + []byte{0x12}, uint32LE(tc.height), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, + "height locktime %d must be rejected", + tc.height) + }) + } +} + +func TestV2HeightLocktimeValidBoundary(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + // Height 1 is the minimum valid value. + p.Inputs[0].HeightLocktime = 1 + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + require.Equal(t, uint32(1), p2.Inputs[0].HeightLocktime) + + // Height 499999999 is the maximum valid value. + p3, err := NewV2(2, 0, 0) + require.NoError(t, err) + require.NoError(t, p3.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p3.AddOutput(1000, []byte{0x51})) + p3.Inputs[0].HeightLocktime = 499999999 + + var buf2 bytes.Buffer + require.NoError(t, p3.Serialize(&buf2)) + p4, err := NewFromRawBytes(&buf2, false) + require.NoError(t, err) + require.Equal(t, uint32(499999999), p4.Inputs[0].HeightLocktime) +} + +// ========================================================================== +// 6. Duplicate Field Detection Tests +// ========================================================================== + +func TestV2DuplicateGlobalFallbackLocktime(t *testing.T) { + // Build a raw v2 PSBT with FallbackLocktime (0x03) appearing twice. + raw := serializeV2Global(t, + // TxVersion + []byte{0x02}, uint32LE(2), + // FallbackLocktime first + []byte{0x03}, uint32LE(0), + // FallbackLocktime duplicate + []byte{0x03}, uint32LE(0), + // InputCount + []byte{0x04}, compactSizeUint(0), + // OutputCount + []byte{0x05}, compactSizeUint(0), + // Version + []byte{0xfb}, uint32LE(2), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "duplicate FallbackLocktime must be rejected") +} + +func TestV2DuplicateGlobalTxModifiable(t *testing.T) { + raw := serializeV2Global(t, + []byte{0x02}, uint32LE(2), + []byte{0x04}, compactSizeUint(0), + []byte{0x05}, compactSizeUint(0), + // TxModifiable first + []byte{0x06}, []byte{0x00}, + // TxModifiable duplicate + []byte{0x06}, []byte{0x00}, + []byte{0xfb}, uint32LE(2), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "duplicate TxModifiable must be rejected") +} + +// serializeV2WithInputKVPairs builds a minimal v2 PSBT with one input, where +// the input section contains the given raw key-value pairs. +func serializeV2WithInputKVPairs(t *testing.T, pairs ...[]byte) []byte { + t.Helper() + require.True(t, len(pairs)%2 == 0) + + var buf bytes.Buffer + buf.Write(psbtMagic[:]) + + // Global: TxVersion=2, InputCount=1, OutputCount=1, Version=2. + for _, pair := range []struct{ key, val []byte }{ + {[]byte{0x02}, uint32LE(2)}, + {[]byte{0x04}, compactSizeUint(1)}, + {[]byte{0x05}, compactSizeUint(1)}, + {[]byte{0xfb}, uint32LE(2)}, + } { + wire.WriteVarInt(&buf, 0, uint64(len(pair.key))) + buf.Write(pair.key) + wire.WriteVarInt(&buf, 0, uint64(len(pair.val))) + buf.Write(pair.val) + } + buf.WriteByte(0x00) // global separator + + // Input section. + for i := 0; i < len(pairs); i += 2 { + key := pairs[i] + val := pairs[i+1] + wire.WriteVarInt(&buf, 0, uint64(len(key))) + buf.Write(key) + wire.WriteVarInt(&buf, 0, uint64(len(val))) + buf.Write(val) + } + buf.WriteByte(0x00) // input separator + + // Output section: Amount + Script. + for _, pair := range []struct{ key, val []byte }{ + {[]byte{0x03}, []byte{0xe8, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {[]byte{0x04}, []byte{0x51}}, + } { + wire.WriteVarInt(&buf, 0, uint64(len(pair.key))) + buf.Write(pair.key) + wire.WriteVarInt(&buf, 0, uint64(len(pair.val))) + buf.Write(pair.val) + } + buf.WriteByte(0x00) // output separator + + return buf.Bytes() +} + +func TestV2DuplicateInputOutputIndex(t *testing.T) { + txid := testTxid(0x01) + + raw := serializeV2WithInputKVPairs(t, + // PreviousTxid + []byte{0x0e}, txid[:], + // OutputIndex first + []byte{0x0f}, uint32LE(0), + // OutputIndex duplicate + []byte{0x0f}, uint32LE(1), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "duplicate OutputIndex must be rejected") +} + +func TestV2DuplicateInputSequence(t *testing.T) { + txid := testTxid(0x01) + + raw := serializeV2WithInputKVPairs(t, + []byte{0x0e}, txid[:], + []byte{0x0f}, uint32LE(0), + // Sequence first + []byte{0x10}, uint32LE(0), + // Sequence duplicate + []byte{0x10}, uint32LE(1), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "duplicate Sequence must be rejected") +} + +func TestV2DuplicateInputTimeLocktime(t *testing.T) { + txid := testTxid(0x01) + + raw := serializeV2WithInputKVPairs(t, + []byte{0x0e}, txid[:], + []byte{0x0f}, uint32LE(0), + // TimeLocktime first + []byte{0x11}, uint32LE(500000000), + // TimeLocktime duplicate + []byte{0x11}, uint32LE(600000000), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "duplicate TimeLocktime must be rejected") +} + +func TestV2DuplicateInputHeightLocktime(t *testing.T) { + txid := testTxid(0x01) + + raw := serializeV2WithInputKVPairs(t, + []byte{0x0e}, txid[:], + []byte{0x0f}, uint32LE(0), + // HeightLocktime first + []byte{0x12}, uint32LE(100), + // HeightLocktime duplicate + []byte{0x12}, uint32LE(200), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "duplicate HeightLocktime must be rejected") +} + +// ========================================================================== +// 7. Input Serialization Key Ordering Tests +// ========================================================================== + +func TestV2InputSerializationKeyOrder(t *testing.T) { + // Build a v2 input with various fields and verify keys are in ascending + // order after serialization. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0xFF) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 42)) + p.Inputs[0].HeightLocktime = 100000 + + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + // Parse back and extract the serialized input section to verify order. + // We'll re-serialize the parsed packet and check key order in the input + // section by scanning for key types. + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + + // Serialize the input and extract key types. + var inputBuf bytes.Buffer + require.NoError(t, p2.Inputs[0].serialize(&inputBuf, 2)) + + keyTypes := extractKeyTypes(t, inputBuf.Bytes()) + for i := 1; i < len(keyTypes); i++ { + require.True(t, keyTypes[i] >= keyTypes[i-1], + "key type 0x%02x must come after 0x%02x", + keyTypes[i], keyTypes[i-1]) + } +} + +// extractKeyTypes reads the serialized key-value pairs and returns just the key +// type bytes in order. +func extractKeyTypes(t *testing.T, data []byte) []byte { + t.Helper() + + r := bytes.NewReader(data) + var keyTypes []byte + + for { + keyLen, err := wire.ReadVarInt(r, 0) + if err != nil { + break + } + if keyLen == 0 { + break + } + + key := make([]byte, keyLen) + _, err = r.Read(key) + require.NoError(t, err) + + keyTypes = append(keyTypes, key[0]) + + // Read and discard value. + valLen, err := wire.ReadVarInt(r, 0) + require.NoError(t, err) + val := make([]byte, valLen) + _, err = r.Read(val) + require.NoError(t, err) + } + + return keyTypes +} + +// ========================================================================== +// 8. SumUtxoInputValues & GetTxFee v2 Tests +// ========================================================================== + +func TestV2SumUtxoInputValuesWitness(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 1), 0)) + + // Set witness UTXOs. + p.Inputs[0].WitnessUtxo = wire.NewTxOut(50000, []byte{0x51}) + p.Inputs[1].WitnessUtxo = wire.NewTxOut(30000, []byte{0x51}) + + sum, err := SumUtxoInputValues(p) + require.NoError(t, err) + require.Equal(t, int64(80000), sum) +} + +func TestV2SumUtxoInputValuesNonWitness(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + // Create a "previous transaction" with outputs. + prevTx := wire.NewMsgTx(2) + prevTx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: *wire.NewOutPoint(testTxid(0xFF), 0), + }) + prevTx.AddTxOut(wire.NewTxOut(10000, []byte{0x51})) + prevTx.AddTxOut(wire.NewTxOut(20000, []byte{0x51})) + prevTx.AddTxOut(wire.NewTxOut(30000, []byte{0x51})) + + // Input spending output index 2 of prevTx. + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 2), 0)) + p.Inputs[0].NonWitnessUtxo = prevTx + + sum, err := SumUtxoInputValues(p) + require.NoError(t, err) + require.Equal(t, int64(30000), sum) +} + +func TestV2SumUtxoInputValuesNoUtxoError(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + // No UTXO set. + + _, err = SumUtxoInputValues(p) + require.Error(t, err) +} + +func TestV2GetTxFee(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + p.Inputs[0].WitnessUtxo = wire.NewTxOut(100000, []byte{0x51}) + + require.NoError(t, p.AddOutput(90000, []byte{0x51})) + + fee, err := p.GetTxFee() + require.NoError(t, err) + require.Equal(t, int64(10000), int64(fee)) +} + +// ========================================================================== +// 9. CopyInputFields / Finalization Preservation Tests +// ========================================================================== + +func TestCopyInputFieldsPreservesV2Fields(t *testing.T) { + src := &PInput{ + PreviousTxid: testTxid(0xAA)[:], + OutputIndex: 7, + Sequence: 42, + TimeLocktime: 1600000000, + HeightLocktime: 300000, + Unknowns: []*Unknown{ + {Key: []byte{0xfc, 0x01}, Value: []byte{0x02}}, + }, + } + + dst := NewPsbtInput(nil, nil) + dst.CopyInputFields(src) + + require.Equal(t, src.PreviousTxid, dst.PreviousTxid) + require.Equal(t, src.OutputIndex, dst.OutputIndex) + require.Equal(t, src.Sequence, dst.Sequence) + require.Equal(t, src.TimeLocktime, dst.TimeLocktime) + require.Equal(t, src.HeightLocktime, dst.HeightLocktime) + require.Len(t, dst.Unknowns, 1) + require.Equal(t, src.Unknowns[0].Key, dst.Unknowns[0].Key) + require.Equal(t, src.Unknowns[0].Value, dst.Unknowns[0].Value) + + // Verify deep copy: mutating dst should not affect src. + dst.Unknowns[0].Value[0] = 0xFF + require.NotEqual(t, src.Unknowns[0].Value[0], dst.Unknowns[0].Value[0]) +} + +// ========================================================================== +// 10. SanityCheck Tests +// ========================================================================== + +func TestV2SanityCheckRejectsUnsignedTx(t *testing.T) { + p := &Packet{ + Version: 2, + TxVersion: 2, + UnsignedTx: wire.NewMsgTx(2), + } + require.Error(t, p.SanityCheck()) +} + +func TestV0SanityCheckRequiresUnsignedTx(t *testing.T) { + p := &Packet{ + Version: 0, + UnsignedTx: nil, + } + require.Error(t, p.SanityCheck()) +} + +// ========================================================================== +// 11. PreviousTxid Validation Tests +// ========================================================================== + +func TestV2RejectsAllZeroPreviousTxid(t *testing.T) { + zeroTxid := make([]byte, 32) + + raw := serializeV2WithInputKVPairs(t, + []byte{0x0e}, zeroTxid, + []byte{0x0f}, uint32LE(0), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "all-zero PreviousTxid must be rejected") +} + +func TestV2RejectsWrongLengthPreviousTxid(t *testing.T) { + shortTxid := make([]byte, 16) // Should be 32 bytes. + shortTxid[0] = 0x01 + + raw := serializeV2WithInputKVPairs(t, + []byte{0x0e}, shortTxid, + []byte{0x0f}, uint32LE(0), + ) + + _, err := NewFromRawBytes(bytes.NewReader(raw), false) + require.Error(t, err, "wrong-length PreviousTxid must be rejected") +} + +// ========================================================================== +// 12. Input/Output Count Mismatch Tests +// ========================================================================== + +func TestV2InputCountMismatchFails(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + // Override count to claim more inputs. + p.InputCount = 3 + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + // Should fail because we claimed 3 inputs but only provided 1. + _, err = NewFromRawBytes(&buf, false) + require.Error(t, err) +} + +func TestV2OutputCountMismatchFails(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + p.OutputCount = 2 + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + _, err = NewFromRawBytes(&buf, false) + require.Error(t, err) +} + +// ========================================================================== +// 13. Amount Type Tests +// ========================================================================== + +func TestV2AmountSignedInt64(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + + // Use a large but valid amount. + require.NoError(t, p.AddOutput(2100000000000000, []byte{0x51})) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + require.Equal(t, int64(2100000000000000), p2.Outputs[0].Amount) + + // Verify it converts correctly to a wire transaction. + tx, err := p2.GetUnsignedTx() + require.NoError(t, err) + require.Equal(t, int64(2100000000000000), tx.TxOut[0].Value) +} + +// ========================================================================== +// 14. Unknown Fields Round-Trip Tests +// ========================================================================== + +func TestV2UnknownFieldsRoundTrip(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + // Add unknown fields to input and output. + require.NoError(t, p.Inputs[0].addUnknown( + 0xfc, []byte{0x01, 0x02}, []byte{0x03, 0x04}, + )) + require.NoError(t, p.Outputs[0].addUnknown( + 0xf1, []byte{0x05}, []byte{0x06, 0x07}, + )) + + // Global unknown (use key type < 0xfd to avoid varint prefix issues). + p.Unknowns = append(p.Unknowns, &Unknown{ + Key: []byte{0xf0, 0x01}, + Value: []byte{0x02, 0x03}, + }) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + + require.Len(t, p2.Inputs[0].Unknowns, 1) + require.Equal(t, []byte{0xfc, 0x01, 0x02}, p2.Inputs[0].Unknowns[0].Key) + require.Equal(t, []byte{0x03, 0x04}, p2.Inputs[0].Unknowns[0].Value) + + require.Len(t, p2.Outputs[0].Unknowns, 1) + require.Equal(t, []byte{0xf1, 0x05}, p2.Outputs[0].Unknowns[0].Key) + require.Equal(t, []byte{0x06, 0x07}, p2.Outputs[0].Unknowns[0].Value) + + require.Len(t, p2.Unknowns, 1) + require.Equal(t, []byte{0xf0, 0x01}, p2.Unknowns[0].Key) +} + +// ========================================================================== +// 15. Signer / Updater v2 Compatibility Tests +// ========================================================================== + +func TestV2UpdaterCreation(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + u, err := NewUpdater(p) + require.NoError(t, err) + require.NotNil(t, u) +} + +func TestV2UpdaterAddWitnessUtxo(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + u, err := NewUpdater(p) + require.NoError(t, err) + + utxo := wire.NewTxOut(50000, []byte{0x00, 0x14, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14}) + require.NoError(t, u.AddInWitnessUtxo(utxo, 0)) + require.Equal(t, utxo, p.Inputs[0].WitnessUtxo) +} + +// ========================================================================== +// 16. IsComplete / Extraction Tests +// ========================================================================== + +func TestV2IsCompleteReturnsFalseWhenNotFinalized(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + require.False(t, p.IsComplete()) +} + +func TestV2ExtractRejectsIncomplete(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + _, err = Extract(p) + require.Error(t, err) + require.Equal(t, ErrIncompletePSBT, err) +} + +// ========================================================================== +// 17. Edge Cases +// ========================================================================== + +func TestV2ZeroFallbackLocktime(t *testing.T) { + // Explicitly set fallback locktime to 0 (the default). + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + require.NoError(t, p.AddInput( + *wire.NewOutPoint(testTxid(0x01), 0), 0, + )) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + require.Equal(t, uint32(0), p2.FallbackLocktime) +} + +func TestV2TxModifiableFlags(t *testing.T) { + tests := []struct { + name string + flags uint8 + }{ + {name: "none", flags: 0x00}, + {name: "inputs modifiable", flags: 0x01}, + {name: "outputs modifiable", flags: 0x02}, + {name: "both modifiable", flags: 0x03}, + {name: "sighash single", flags: 0x04}, + {name: "all flags", flags: 0x07}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p, err := NewV2(2, 0, tc.flags) + require.NoError(t, err) + require.NoError(t, p.AddInput( + *wire.NewOutPoint(testTxid(0x01), 0), 0, + )) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + require.Equal(t, tc.flags, p2.TxModifiable) + }) + } +} + +func TestV2NewFromUnsignedTxPopulatesV2Fields(t *testing.T) { + // Verify that NewFromUnsignedTx pre-populates v2-compatible fields on + // PInput, even though the packet is v0. + txid := testTxid(0x01) + tx := wire.NewMsgTx(2) + tx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: *wire.NewOutPoint(txid, 7), + Sequence: 42, + }) + tx.AddTxOut(wire.NewTxOut(1000, []byte{0x51})) + + p, err := NewFromUnsignedTx(tx) + require.NoError(t, err) + require.Equal(t, uint32(0), p.Version) + + // v2-compatible fields should be populated. + require.Equal(t, txid[:], p.Inputs[0].PreviousTxid) + require.Equal(t, uint32(7), p.Inputs[0].OutputIndex) + require.Equal(t, uint32(42), p.Inputs[0].Sequence) +} + +func TestV2LockTimeInGetUnsignedTx(t *testing.T) { + // Verify that the locktime in the extracted transaction matches the + // DetermineLockTime result. + p, err := NewV2(2, 100, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, 0), 0)) + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + // No input locktimes → fallback. + tx, err := p.GetUnsignedTx() + require.NoError(t, err) + require.Equal(t, uint32(100), tx.LockTime) +} + +func TestV2MultipleInputsLockTimeMax(t *testing.T) { + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + + txid := testTxid(0x01) + for i := uint32(0); i < 5; i++ { + require.NoError(t, p.AddInput(*wire.NewOutPoint(txid, i), 0)) + p.Inputs[i].HeightLocktime = 100000 + i*50000 + } + require.NoError(t, p.AddOutput(1000, []byte{0x51})) + + lockTime, err := p.DetermineLockTime() + require.NoError(t, err) + require.Equal(t, uint32(300000), lockTime) // 100000 + 4*50000 +} + +// ========================================================================== +// 11. Creator Validation Tests +// ========================================================================== + +// TestNewV2RejectsBadTxVersion verifies that the PSBTv2 Creator rejects a +// transaction version below 2, as required by BIP-370. +func TestNewV2RejectsBadTxVersion(t *testing.T) { + tests := []struct { + name string + txVersion uint32 + }{ + {name: "version 0", txVersion: 0}, + {name: "version 1", txVersion: 1}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewV2(tc.txVersion, 0, 0) + require.Error(t, err, + "NewV2 with txVersion %d must be rejected", tc.txVersion) + }) + } + + // Version 2 is the minimum valid value. + p, err := NewV2(2, 0, 0) + require.NoError(t, err) + require.Equal(t, uint32(2), p.TxVersion) +} + +// ========================================================================== +// 12. Updater Role Modifiable Flag Tests +// ========================================================================== + +// TestUpdaterAddInputV2RespectsTxModifiable verifies that the Updater-role +// AddInputV2 enforces the PSBT_GLOBAL_TX_MODIFIABLE Inputs Modifiable flag +// (Bit 0) per BIP-370, unlike the Creator-role Packet.AddInputV2. +func TestUpdaterAddInputV2RespectsTxModifiable(t *testing.T) { + // Build a valid base packet to start from. + makePkt := func(modifiable uint8) *Packet { + p, err := NewV2(2, 0, modifiable) + require.NoError(t, err) + return p + } + + txid := testTxid(0xAB) + input := PInput{ + PreviousTxid: txid[:], + OutputIndex: 0, + Sequence: wire.MaxTxInSequenceNum, + } + + t.Run("fails when inputs not modifiable (bit 0 clear)", func(t *testing.T) { + p := makePkt(0x00) // No bits set. + u := &Updater{Upsbt: p} + err := u.AddInputV2(input) + require.Error(t, err) + }) + + t.Run("fails when only outputs modifiable (bit 1 set, bit 0 clear)", func(t *testing.T) { + p := makePkt(0x02) // Bit 1 only — outputs modifiable, not inputs. + u := &Updater{Upsbt: p} + err := u.AddInputV2(input) + require.Error(t, err) + }) + + t.Run("succeeds when inputs modifiable (bit 0 set)", func(t *testing.T) { + p := makePkt(0x01) // Bit 0 set — inputs modifiable. + u := &Updater{Upsbt: p} + require.NoError(t, u.AddInputV2(input)) + require.Len(t, p.Inputs, 1) + require.Equal(t, uint32(1), p.InputCount) + }) + + t.Run("creator-role AddInputV2 ignores modifiable flag", func(t *testing.T) { + // Packet.AddInputV2 is the Creator role — no flag restriction. + p := makePkt(0x00) + require.NoError(t, p.AddInputV2(input)) + require.Len(t, p.Inputs, 1) + }) +} + +// TestUpdaterAddOutputV2RespectsTxModifiable verifies that the Updater-role +// AddOutputV2 enforces the PSBT_GLOBAL_TX_MODIFIABLE Outputs Modifiable flag +// (Bit 1) per BIP-370. +func TestUpdaterAddOutputV2RespectsTxModifiable(t *testing.T) { + makePkt := func(modifiable uint8) *Packet { + p, err := NewV2(2, 0, modifiable) + require.NoError(t, err) + return p + } + + output := POutput{ + Amount: 100000, + Script: []byte{0x51}, + } + + t.Run("fails when outputs not modifiable (bit 1 clear)", func(t *testing.T) { + p := makePkt(0x00) + u := &Updater{Upsbt: p} + err := u.AddOutputV2(output) + require.Error(t, err) + }) + + t.Run("fails when only inputs modifiable (bit 0 set, bit 1 clear)", func(t *testing.T) { + p := makePkt(0x01) // Bit 0 only — inputs modifiable, not outputs. + u := &Updater{Upsbt: p} + err := u.AddOutputV2(output) + require.Error(t, err) + }) + + t.Run("succeeds when outputs modifiable (bit 1 set)", func(t *testing.T) { + p := makePkt(0x02) // Bit 1 set — outputs modifiable. + u := &Updater{Upsbt: p} + require.NoError(t, u.AddOutputV2(output)) + require.Len(t, p.Outputs, 1) + require.Equal(t, uint32(1), p.OutputCount) + }) + + t.Run("creator-role AddOutputV2 ignores modifiable flag", func(t *testing.T) { + // Packet.AddOutputV2 is the Creator role — no flag restriction. + p := makePkt(0x00) + require.NoError(t, p.AddOutputV2(output)) + require.Len(t, p.Outputs, 1) + }) +} + +// ========================================================================== +// 13. V0 Rejects V2-Only Fields Tests +// ========================================================================== + +// TestV0RejectsV2InputFields verifies that when a v0 PSBT contains v2-only +// input fields (0x0e–0x12), they are routed to the Unknowns list instead of +// being parsed as named fields, as required by BIP-370. +func TestV0RejectsV2InputFields(t *testing.T) { + // Build a v0 PSBT from an unsigned tx. + tx := wire.NewMsgTx(2) + txid := testTxid(0x01) + tx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: *wire.NewOutPoint(txid, 0), + Sequence: wire.MaxTxInSequenceNum, + }) + tx.AddTxOut(wire.NewTxOut(50000, []byte{0x51})) + + p, err := NewFromUnsignedTx(tx) + require.NoError(t, err) + require.Equal(t, uint32(0), p.Version) + + // Inject a v2-only field (PreviousTxid = 0x0e) directly into the + // unknowns, simulating a PSBT that has this field embedded. + // After a round-trip through serialize/parse, PreviousTxid must NOT + // be populated — it must live in Unknowns. + p.Inputs[0].Unknowns = append(p.Inputs[0].Unknowns, &Unknown{ + Key: []byte{byte(PreviousTxidInputType)}, + Value: txid[:], + }) + + var buf bytes.Buffer + require.NoError(t, p.Serialize(&buf)) + + p2, err := NewFromRawBytes(&buf, false) + require.NoError(t, err) + + // PreviousTxid must NOT have been parsed as a named field in v0. + require.Nil(t, p2.Inputs[0].PreviousTxid, + "PreviousTxid must not be parsed for v0 PSBTs") + + // It must appear in Unknowns instead. + require.NotEmpty(t, p2.Inputs[0].Unknowns, + "v2-only field must be routed to Unknowns in v0") +} diff --git a/btcutil/psbt/signer.go b/btcutil/psbt/signer.go index dcbcf93fa3..e73c0e3110 100644 --- a/btcutil/psbt/signer.go +++ b/btcutil/psbt/signer.go @@ -4,7 +4,7 @@ package psbt -// signer encapsulates the role 'Signer' as specified in BIP174; it controls +// signer encapsulates the role 'Signer' as specified in BIP174 and BIP0370; it controls // the insertion of signatures; the Sign() function will attempt to insert // signatures using Updater.addPartialSignature, after first ensuring the Psbt // is in the correct state. @@ -115,8 +115,13 @@ func (u *Updater) Sign(inIndex int, sig []byte, pubKey []byte, // output. default: if pInput.WitnessUtxo == nil { - txIn := u.Upsbt.UnsignedTx.TxIn[inIndex] - outIndex := txIn.PreviousOutPoint.Index + var outIndex uint32 + switch u.Upsbt.Version { + case 2: + outIndex = pInput.OutputIndex + default: + outIndex = u.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index + } script := pInput.NonWitnessUtxo.TxOut[outIndex].PkScript if txscript.IsWitnessProgram(script) { @@ -141,7 +146,12 @@ func (u *Updater) Sign(inIndex int, sig []byte, pubKey []byte, // NonWitnessUtxo field with a WitnessUtxo field. See // https://github.com/bitcoin/bitcoin/pull/14197. func nonWitnessToWitness(p *Packet, inIndex int) error { - outIndex := p.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index + var outIndex uint32 + if p.Version == 2 { + outIndex = p.Inputs[inIndex].OutputIndex + } else { + outIndex = p.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index + } txout := p.Inputs[inIndex].NonWitnessUtxo.TxOut[outIndex] // TODO(guggero): For segwit v1, we'll want to remove the NonWitnessUtxo diff --git a/btcutil/psbt/types.go b/btcutil/psbt/types.go index ca555101b9..33c6f5d971 100644 --- a/btcutil/psbt/types.go +++ b/btcutil/psbt/types.go @@ -44,6 +44,31 @@ const ( // // The value is any data as defined by the proprietary type user. ProprietaryGlobalType = 0xFC + + // TxVersion is the PSBT version number. + // The key is {0x02}. + // The value is a 32-bit little endian unsigned integer for the version number. + TxVersionGlobalType GlobalType = 0x02 + + // FallbackLocktime is the fallback locktime for the transaction. + // The key is {0x03}. + // The value is a 32-bit little endian unsigned integer. + FallbackLocktimeGlobalType GlobalType = 0x03 + + // InputCount is the number of inputs in this PSBT. + // The key is {0x04}. + // The value is a compact size unsigned integer. + InputCountGlobalType GlobalType = 0x04 + + // OutputCount is the number of outputs in this PSBT. + // The key is {0x05}. + // The value is a compact size unsigned integer. + OutputCountGlobalType GlobalType = 0x05 + + // TxModifiable is a bitfield denoting the modifiability of the transaction. + // The key is {0x06}. + // The value is an 8-bit unsigned integer. + TxModifiableGlobalType GlobalType = 0x06 ) // InputType is the set of types that are defined for each input included @@ -159,6 +184,31 @@ const ( // // The value is any value data as defined by the proprietary type user. ProprietaryInputType InputType = 0xFC + + // PreviousTxid is the txid of the previous transaction. + // The key is {0x0E}. + // The value is a 32-byte txid. + PreviousTxidInputType InputType = 0x0E + + // OutputIndex is the output index of the previous transaction. + // The key is {0x0F}. + // The value is a 32-bit little endian unsigned integer. + OutputIndexInputType InputType = 0x0F + + // TimeLocktime is the time-based locktime for this input. + // The key is {0x11}. + // The value is a 32-bit little endian unsigned integer. + TimeLocktimeInputType InputType = 0x11 + + // HeightLocktime is the height-based locktime for this input. + // The key is {0x12}. + // The value is a 32-bit little endian unsigned integer. + HeightLocktimeInputType InputType = 0x12 + + // Sequence is the sequence number for this input. + // The key is {0x10}. + // The value is a 32-bit little endian unsigned integer. + SequenceInputType InputType = 0x10 ) // OutputType is the set of types defined per output within the PSBT. @@ -200,4 +250,14 @@ const ( // followed by said number of 32-byte leaf hashes. The rest of the value // is then identical to the Bip32DerivationInputType value. TaprootBip32DerivationOutputType OutputType = 7 + + // Amount is the value of this output. + // The key is {0x03}. + // The value is an 8-byte little endian unsigned integer. + AmountOutputType OutputType = 0x03 + + // Script is the locking script for this output. + // The key is {0x04}. + // The value is the scriptPubkey. + ScriptOutputType OutputType = 0x04 ) diff --git a/btcutil/psbt/updater.go b/btcutil/psbt/updater.go index 66c8d1d83c..ef66e1ed3e 100644 --- a/btcutil/psbt/updater.go +++ b/btcutil/psbt/updater.go @@ -13,13 +13,15 @@ package psbt import ( "bytes" "crypto/sha256" + "errors" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" ) -// Updater encapsulates the role 'Updater' as specified in BIP174; it accepts +// Updater encapsulates the role 'Updater' as specified in BIP174 and BIP0370; it accepts // Psbt structs and has methods to add fields to the inputs and outputs. type Updater struct { Upsbt *Packet @@ -109,12 +111,21 @@ func (u *Updater) addPartialSignature(inIndex int, sig []byte, // Next, we perform a series of additional sanity checks. if pInput.NonWitnessUtxo != nil { - if len(u.Upsbt.UnsignedTx.TxIn) < inIndex+1 { + numInputs := len(u.Upsbt.Inputs) + if u.Upsbt.Version < 2 { + numInputs = len(u.Upsbt.UnsignedTx.TxIn) + } + + if numInputs < inIndex+1 { return ErrInvalidPrevOutNonWitnessTransaction } - if pInput.NonWitnessUtxo.TxHash() != - u.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Hash { + var prevHash chainhash.Hash + copy(prevHash[:], pInput.PreviousTxid) + if u.Upsbt.Version < 2 { + prevHash = u.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Hash + } + if pInput.NonWitnessUtxo.TxHash() != prevHash { return ErrInvalidSignatureForInput } @@ -123,7 +134,11 @@ func (u *Updater) addPartialSignature(inIndex int, sig []byte, // that with the P2SH scriptPubKey that is generated by // redeemScript. if pInput.RedeemScript != nil { - outIndex := u.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index + + outIndex := pInput.OutputIndex + if u.Upsbt.Version < 2 { + outIndex = u.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint.Index + } scriptPubKey := pInput.NonWitnessUtxo.TxOut[outIndex].PkScript scriptHash := btcutil.Hash160(pInput.RedeemScript) @@ -375,3 +390,33 @@ func (u *Updater) AddOutWitnessScript(witnessScript []byte, return nil } + +// AddInputV2 appends a new PInput to a Version 2 PSBT as an Updater. +// Unlike the Creator-role Packet.AddInputV2, this enforces the +// PSBT_GLOBAL_TX_MODIFIABLE Inputs-Modifiable flag (Bit 0) per BIP-370. +func (u *Updater) AddInputV2(input PInput) error { + if u.Upsbt.Version != 2 { + return errors.New("cannot dynamically add inputs to a non-v2 PSBT") + } + if u.Upsbt.TxModifiable&1 == 0 { + return errors.New("inputs are not modifiable in this PSBT") + } + u.Upsbt.Inputs = append(u.Upsbt.Inputs, input) + u.Upsbt.InputCount = uint32(len(u.Upsbt.Inputs)) + return nil +} + +// AddOutputV2 appends a new POutput to a Version 2 PSBT as an Updater. +// Unlike the Creator-role Packet.AddOutputV2, this enforces the +// PSBT_GLOBAL_TX_MODIFIABLE Outputs-Modifiable flag (Bit 1) per BIP-370. +func (u *Updater) AddOutputV2(output POutput) error { + if u.Upsbt.Version != 2 { + return errors.New("cannot dynamically add outputs to a non-v2 PSBT") + } + if u.Upsbt.TxModifiable&2 == 0 { + return errors.New("outputs are not modifiable in this PSBT") + } + u.Upsbt.Outputs = append(u.Upsbt.Outputs, output) + u.Upsbt.OutputCount = uint32(len(u.Upsbt.Outputs)) + return nil +} diff --git a/btcutil/psbt/utils.go b/btcutil/psbt/utils.go index c47f6afd4d..6cad2d0aa2 100644 --- a/btcutil/psbt/utils.go +++ b/btcutil/psbt/utils.go @@ -295,12 +295,15 @@ func readTxOut(txout []byte) (*wire.TxOut, error) { // UTXO fields of the PSBT. An error is returned if an input is specified that // does not contain any UTXO information. func SumUtxoInputValues(packet *Packet) (int64, error) { - // We take the TX ins of the unsigned TX as the truth for how many - // inputs there should be, as the fields in the extra data part of the - // PSBT can be empty. - if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) { - return 0, fmt.Errorf("TX input length doesn't match PSBT " + - "input length") + // For v0 PSBTs we cross-check against the unsigned transaction. + if packet.Version < 2 { + if packet.UnsignedTx == nil { + return 0, fmt.Errorf("v0 PSBT missing unsigned tx") + } + if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) { + return 0, fmt.Errorf("TX input length doesn't match " + + "PSBT input length") + } } inputSum := int64(0) @@ -314,17 +317,22 @@ func SumUtxoInputValues(packet *Packet) (int64, error) { // Non-witness UTXOs reference to the whole transaction // the UTXO resides in. utxOuts := in.NonWitnessUtxo.TxOut - txIn := packet.UnsignedTx.TxIn[idx] - // Check that utxOuts actually has enough space to - // contain the previous outpoint's index. - opIdx := txIn.PreviousOutPoint.Index + // For v2, the output index is stored directly in the + // PInput. For v0, it comes from the unsigned tx. + var opIdx uint32 + if packet.Version >= 2 { + opIdx = in.OutputIndex + } else { + opIdx = packet.UnsignedTx.TxIn[idx].PreviousOutPoint.Index + } + if opIdx >= uint32(len(utxOuts)) { return 0, fmt.Errorf("input %d has malformed "+ "TxOut field", idx) } - inputSum += utxOuts[txIn.PreviousOutPoint.Index].Value + inputSum += utxOuts[opIdx].Value default: return 0, fmt.Errorf("input %d has no UTXO information",