diff --git a/go.mod b/go.mod index 2530d47e2..70aba7976 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/decred/dcrd/rpc/jsonrpc/types/v4 v4.4.0 github.com/decred/dcrd/rpcclient/v8 v8.1.0 github.com/decred/dcrd/txscript/v4 v4.1.2 - github.com/decred/dcrd/wire v1.7.2 + github.com/decred/dcrd/wire v1.7.5 github.com/decred/dcrtest/dcrdtest v1.0.1-0.20240404170936-a2529e936df1 github.com/decred/go-socks v1.1.0 github.com/decred/slog v1.2.0 diff --git a/go.sum b/go.sum index 073ead5ed..65ae510e3 100644 --- a/go.sum +++ b/go.sum @@ -66,8 +66,8 @@ github.com/decred/dcrd/rpcclient/v8 v8.1.0 h1:FLZ1j4ub7+O9oCIcKf+frYCrZW++G3FSzk github.com/decred/dcrd/rpcclient/v8 v8.1.0/go.mod h1:iTHqLrHnS2VLJPHQk7guy0BP3jKvMew9STDqWWhFNA4= github.com/decred/dcrd/txscript/v4 v4.1.2 h1:1EP7ZmBDl2LBeAMTEygxY8rVNN3+lkGqrsb4u64x+II= github.com/decred/dcrd/txscript/v4 v4.1.2/go.mod h1:r5/8qfCnl6TFrE369gggUayVIryM1oC7BLoRfa27Ckw= -github.com/decred/dcrd/wire v1.7.2 h1:04vpHHE3t78rDztjZx82JV2EEOMDUtUUB1347H32kho= -github.com/decred/dcrd/wire v1.7.2/go.mod h1:eP9XRsMloy+phlntkTAaAm611JgLv8NqY1YJoRxkNKU= +github.com/decred/dcrd/wire v1.7.5 h1:fRaaB5CrwYWGI3YVv50XHm54lsU1TB40WnnIJ4W6aGM= +github.com/decred/dcrd/wire v1.7.5/go.mod h1:NZK8QD5W2ObX6p+Q0TUzYNpQtk4Ov3pBIvc6ZUK88FU= github.com/decred/dcrtest/dcrdtest v1.0.1-0.20240404170936-a2529e936df1 h1:RbUvO7dsxdNgb2DvP2/h34eS2Ej1T4a7opzvOzz+YFI= github.com/decred/dcrtest/dcrdtest v1.0.1-0.20240404170936-a2529e936df1/go.mod h1:kbRQzyWu1IfukYZqCioVyJokzu1ifvIXzVDrXReeOsQ= github.com/decred/go-socks v1.1.0 h1:dnENcc0KIqQo3HSXdgboXAHgqsCIutkqq6ntQjYtm2U= diff --git a/wire/bench_test.go b/wire/bench_test.go index 14fad50c3..833eb52f7 100644 --- a/wire/bench_test.go +++ b/wire/bench_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2025 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -14,6 +14,7 @@ import ( "time" "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/crypto/blake256" ) // genesisCoinbaseTx is the coinbase transaction for the genesis blocks for @@ -634,6 +635,28 @@ func BenchmarkTxHash(b *testing.B) { } } +// BenchmarkTxHashReuseHasher performs a benchmark on how long it takes to hash a +// transaction by reusing the *blake256.Hasher256 object. +func BenchmarkTxHashReuseHasher(b *testing.B) { + h := blake256.NewHasher256() + + txHash := func(h *blake256.Hasher256, tx *MsgTx) chainhash.Hash { + txCopy := *tx + txCopy.SerType = TxSerializeNoWitness + err := txCopy.Serialize(h) + if err != nil { + panic(err) + } + return h.Sum256() + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + h.Reset() + _ = txHash(h, &genesisCoinbaseTx) + } +} + // BenchmarkHashB performs a benchmark on how long it takes to perform a hash // returning a byte slice. func BenchmarkHashB(b *testing.B) { @@ -676,3 +699,25 @@ func BenchmarkWriteMessageN(b *testing.B) { } } } + +// BenchmarkReadMessageN benchmarks the genesis coinbase deserialization using +// the ReadMessageN function. +func BenchmarkReadMessageN(b *testing.B) { + var buf bytes.Buffer + _, err := WriteMessageN(&buf, &genesisCoinbaseTx, ProtocolVersion, MainNet) + if err != nil { + b.Fatal(err) + } + msgBytes := buf.Bytes() + + r := new(bytes.Reader) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.Reset(msgBytes) + _, _, _, err := ReadMessageN(r, ProtocolVersion, MainNet) + if err != nil { + b.Fatalf("ReadMessageN: unexpected error: %v", err) + } + } +} diff --git a/wire/blockheader.go b/wire/blockheader.go index e420194a8..8b2add955 100644 --- a/wire/blockheader.go +++ b/wire/blockheader.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2023 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -11,6 +11,7 @@ import ( "time" "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/crypto/blake256" "lukechampine.com/blake3" ) @@ -88,19 +89,19 @@ type BlockHeader struct { // header. const blockHeaderLen = 180 -// BlockHash computes the block identifier hash for the given block header. +// BlockHash computes the BLAKE-256 block identifier hash for the given block +// header. func (h *BlockHeader) BlockHash() chainhash.Hash { // Encode the header and hash everything prior to the number of // transactions. Ignore the error returns since there is no way the encode // could fail except being out of memory which would cause a run-time panic. - buf := bytes.NewBuffer(make([]byte, 0, MaxBlockHeaderPayload)) - _ = writeBlockHeader(buf, 0, h) - - return chainhash.HashH(buf.Bytes()) + hasher := blake256.NewHasher256() + _ = writeBlockHeader(hasher, 0, h) + return hasher.Sum256() } -// PowHashV1 calculates and returns the version 1 proof of work hash for the -// block header. +// PowHashV1 calculates and returns the version 1 proof of work BLAKE-256 hash +// for the block header. // // NOTE: This is the original proof of work hash function used at Decred launch // and applies to all blocks prior to the activation of DCP0011. @@ -108,16 +109,17 @@ func (h *BlockHeader) PowHashV1() chainhash.Hash { return h.BlockHash() } -// PowHashV2 calculates and returns the version 2 proof of work hash as defined -// in DCP0011 for the block header. +// PowHashV2 calculates and returns the version 2 proof of work BLAKE3 hash as +// defined in DCP0011 for the block header. func (h *BlockHeader) PowHashV2() chainhash.Hash { // Encode the header and hash everything prior to the number of // transactions. Ignore the error returns since there is no way the encode // could fail except being out of memory which would cause a run-time panic. - buf := bytes.NewBuffer(make([]byte, 0, MaxBlockHeaderPayload)) - _ = writeBlockHeader(buf, 0, h) - - return blake3.Sum256(buf.Bytes()) + var digest chainhash.Hash + hasher := blake3.New(len(digest), nil) + _ = writeBlockHeader(hasher, 0, h) + hasher.Sum(digest[:0]) + return digest } // BtcDecode decodes r using the bitcoin protocol encoding into the receiver. diff --git a/wire/common.go b/wire/common.go index 51a84ff64..bd0c89475 100644 --- a/wire/common.go +++ b/wire/common.go @@ -1,11 +1,12 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2024 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. package wire import ( + "bytes" "crypto/rand" "encoding/binary" "fmt" @@ -14,6 +15,8 @@ import ( "time" "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/crypto/blake256" + "lukechampine.com/blake3" ) const ( @@ -29,6 +32,10 @@ const ( // strictAsciiRangeUpper is the upper limit of the strict ASCII range. strictAsciiRangeUpper = 0x7e + + // unixToInternal is the number of seconds between year 1 of the Go time + // value and the unix epoch. + unixToInternal = 62135596800 ) var ( @@ -76,104 +83,6 @@ func (l binaryFreeList) Return(buf []byte) { } } -// Uint8 reads a single byte from the provided reader using a buffer from the -// free list and returns it as a uint8. -func (l binaryFreeList) Uint8(r io.Reader) (uint8, error) { - buf := l.Borrow()[:1] - if _, err := io.ReadFull(r, buf); err != nil { - l.Return(buf) - return 0, err - } - rv := buf[0] - l.Return(buf) - return rv, nil -} - -// Uint16 reads two bytes from the provided reader using a buffer from the -// free list, converts it to a number using the provided byte order, and returns -// the resulting uint16. -func (l binaryFreeList) Uint16(r io.Reader, byteOrder binary.ByteOrder) (uint16, error) { - buf := l.Borrow()[:2] - if _, err := io.ReadFull(r, buf); err != nil { - l.Return(buf) - return 0, err - } - rv := byteOrder.Uint16(buf) - l.Return(buf) - return rv, nil -} - -// Uint32 reads four bytes from the provided reader using a buffer from the -// free list, converts it to a number using the provided byte order, and returns -// the resulting uint32. -func (l binaryFreeList) Uint32(r io.Reader, byteOrder binary.ByteOrder) (uint32, error) { - buf := l.Borrow()[:4] - if _, err := io.ReadFull(r, buf); err != nil { - l.Return(buf) - return 0, err - } - rv := byteOrder.Uint32(buf) - l.Return(buf) - return rv, nil -} - -// Uint64 reads eight bytes from the provided reader using a buffer from the -// free list, converts it to a number using the provided byte order, and returns -// the resulting uint64. -func (l binaryFreeList) Uint64(r io.Reader, byteOrder binary.ByteOrder) (uint64, error) { - buf := l.Borrow()[:8] - if _, err := io.ReadFull(r, buf); err != nil { - l.Return(buf) - return 0, err - } - rv := byteOrder.Uint64(buf) - l.Return(buf) - return rv, nil -} - -// PutUint8 copies the provided uint8 into a buffer from the free list and -// writes the resulting byte to the given writer. -func (l binaryFreeList) PutUint8(w io.Writer, val uint8) error { - buf := l.Borrow()[:1] - buf[0] = val - _, err := w.Write(buf) - l.Return(buf) - return err -} - -// PutUint16 serializes the provided uint16 using the given byte order into a -// buffer from the free list and writes the resulting two bytes to the given -// writer. -func (l binaryFreeList) PutUint16(w io.Writer, byteOrder binary.ByteOrder, val uint16) error { - buf := l.Borrow()[:2] - byteOrder.PutUint16(buf, val) - _, err := w.Write(buf) - l.Return(buf) - return err -} - -// PutUint32 serializes the provided uint32 using the given byte order into a -// buffer from the free list and writes the resulting four bytes to the given -// writer. -func (l binaryFreeList) PutUint32(w io.Writer, byteOrder binary.ByteOrder, val uint32) error { - buf := l.Borrow()[:4] - byteOrder.PutUint32(buf, val) - _, err := w.Write(buf) - l.Return(buf) - return err -} - -// PutUint64 serializes the provided uint64 using the given byte order into a -// buffer from the free list and writes the resulting eight bytes to the given -// writer. -func (l binaryFreeList) PutUint64(w io.Writer, byteOrder binary.ByteOrder, val uint64) error { - buf := l.Borrow()[:8] - byteOrder.PutUint64(buf, val) - _, err := w.Write(buf) - l.Return(buf) - return err -} - // binarySerializer provides a free list of buffers to use for serializing and // deserializing primitive integer values to and from io.Readers and io.Writers. var binarySerializer binaryFreeList = make(chan []byte, binaryFreeListMaxItems) @@ -190,9 +99,108 @@ type uint32Time time.Time // int64Time represents a unix timestamp encoded with an int64. It is used as // a way to signal the readElement function how to decode a timestamp into a Go -// time.Time since it is otherwise ambiguous. +// time.Time since it is otherwise ambiguous. The value is rejected if it is +// larger than the maximum usable seconds for a Go time value for worry-free +// comparisons. type int64Time time.Time +// shortRead optimizes short (<= 8 byte) reads from r by special casing +// buffer allocations for specific reader types. +// +// The callback is called with a short buffer of 8 bytes in length, and only +// size bytes should be read from this array. +// +// For longer reads and reads of byte arrays, io.ReadFull should be used +// instead. +// +// This function will panic if called with a size greater than 8. +func shortRead(r io.Reader, size int, cb func(p [8]byte)) error { + var data [8]byte + + switch r := r.(type) { + // wireBuffer is the reader used by ReadMessageN. + case *wireBuffer: + n, _ := r.Read(data[:size]) + if n == 0 { + return io.EOF + } + if n != size { + return io.ErrUnexpectedEOF + } + cb(data) + + // A *bytes.Buffer is a common case of a reader type that callers may + // provide to BtcDecode. + case *bytes.Buffer: + n, _ := r.Read(data[:size]) + if n == 0 { + return io.EOF + } + if n != size { + return io.ErrUnexpectedEOF + } + cb(data) + + // A *bytes.Reader is a common case of a reader type that callers may + // provide to BtcDecode. + case *bytes.Reader: + n, _ := r.Read(data[:size]) + if n == 0 { + return io.EOF + } + if n != size { + return io.ErrUnexpectedEOF + } + cb(data) + + default: + p := binarySerializer.Borrow() + _, err := io.ReadFull(r, p[:size]) + if err != nil { + return err + } + cb(*(*[8]byte)(p)) + binarySerializer.Return(p) + } + + return nil +} + +// readUint8 reads a byte and stores it to *value. +func readUint8(r io.Reader, value *uint8) error { + return shortRead(r, 1, func(p [8]byte) { + *value = p[0] + }) +} + +// readUint16LE reads the little endian encoding of a uint16 and stores it to *value. +func readUint16LE(r io.Reader, value *uint16) error { + return shortRead(r, 2, func(p [8]byte) { + *value = littleEndian.Uint16(p[:]) + }) +} + +// readUint16BE reads the big endian encoding of a uint16 and stores it to *value. +func readUint16BE(r io.Reader, value *uint16) error { + return shortRead(r, 2, func(p [8]byte) { + *value = bigEndian.Uint16(p[:]) + }) +} + +// readUint32LE reads the little endian encoding of a uint32 and stores it to *value. +func readUint32LE(r io.Reader, value *uint32) error { + return shortRead(r, 4, func(p [8]byte) { + *value = littleEndian.Uint32(p[:]) + }) +} + +// readUint64LE reads the little endian encoding of a uint64 and stores it to *value. +func readUint64LE(r io.Reader, value *uint64) error { + return shortRead(r, 8, func(p [8]byte) { + *value = littleEndian.Uint64(p[:]) + }) +} + // readElement reads the next sequence of bytes from r using little endian // depending on the concrete type of element pointed to. func readElement(r io.Reader, element interface{}) error { @@ -200,59 +208,58 @@ func readElement(r io.Reader, element interface{}) error { // type assertions first. switch e := element.(type) { case *uint8: - rv, err := binarySerializer.Uint8(r) + err := readUint8(r, e) if err != nil { return err } - *e = rv return nil case *uint16: - rv, err := binarySerializer.Uint16(r, littleEndian) + err := readUint16LE(r, e) if err != nil { return err } - *e = rv return nil case *int32: - rv, err := binarySerializer.Uint32(r, littleEndian) + var value uint32 + err := readUint32LE(r, &value) if err != nil { return err } - *e = int32(rv) + *e = int32(value) return nil case *uint32: - rv, err := binarySerializer.Uint32(r, littleEndian) + err := readUint32LE(r, e) if err != nil { return err } - *e = rv return nil case *int64: - rv, err := binarySerializer.Uint64(r, littleEndian) + var value uint64 + err := readUint64LE(r, &value) if err != nil { return err } - *e = int64(rv) + *e = int64(value) return nil case *uint64: - rv, err := binarySerializer.Uint64(r, littleEndian) + err := readUint64LE(r, e) if err != nil { return err } - *e = rv return nil case *bool: - rv, err := binarySerializer.Uint8(r) + var value uint8 + err := readUint8(r, &value) if err != nil { return err } - if rv == 0x00 { + if value == 0x00 { *e = false } else { *e = true @@ -261,20 +268,29 @@ func readElement(r io.Reader, element interface{}) error { // Unix timestamp encoded as a uint32. case *uint32Time: - rv, err := binarySerializer.Uint32(r, binary.LittleEndian) + var ts uint32 + err := readUint32LE(r, &ts) if err != nil { return err } - *e = uint32Time(time.Unix(int64(rv), 0)) + *e = uint32Time(time.Unix(int64(ts), 0)) return nil // Unix timestamp encoded as an int64. case *int64Time: - rv, err := binarySerializer.Uint64(r, binary.LittleEndian) + var ts uint64 + err := readUint64LE(r, &ts) if err != nil { return err } - *e = int64Time(time.Unix(int64(rv), 0)) + + // Reject timestamps that would overflow the maximum usable number of + // seconds for worry-free comparisons. + if ts > math.MaxInt64-unixToInternal { + const str = "timestamp exceeds maximum allowed value" + return messageError("readElement", ErrInvalidTimestamp, str) + } + *e = int64Time(time.Unix(int64(ts), 0)) return nil // Message header checksum. @@ -292,14 +308,6 @@ func readElement(r io.Reader, element interface{}) error { } return nil - // Message header command. - case *[CommandSize]uint8: - _, err := io.ReadFull(r, e[:]) - if err != nil { - return err - } - return nil - // IP address. case *[16]byte: _, err := io.ReadFull(r, e[:]) @@ -363,35 +371,31 @@ func readElement(r io.Reader, element interface{}) error { return nil case *ServiceFlag: - rv, err := binarySerializer.Uint64(r, littleEndian) + err := readUint64LE(r, (*uint64)(e)) if err != nil { return err } - *e = ServiceFlag(rv) return nil case *InvType: - rv, err := binarySerializer.Uint32(r, littleEndian) + err := readUint32LE(r, (*uint32)(e)) if err != nil { return err } - *e = InvType(rv) return nil case *CurrencyNet: - rv, err := binarySerializer.Uint32(r, littleEndian) + err := readUint32LE(r, (*uint32)(e)) if err != nil { return err } - *e = CurrencyNet(rv) return nil case *RejectCode: - rv, err := binarySerializer.Uint8(r) + err := readUint8(r, (*uint8)(e)) if err != nil { return err } - *e = RejectCode(rv) return nil } @@ -412,68 +416,154 @@ func readElements(r io.Reader, elements ...interface{}) error { return nil } +// shortWrite optimizes short (<= 8 byte) writes to w by special casing +// buffer allocations for specific writer types. +// +// The callback returns a short buffer to 8 bytes in length and a size +// specifying how much of the buffer to write. +// +// For longer writes and writes of byte arrays, dynamic dispatch to w.Write +// should be used instead. +func shortWrite(w io.Writer, cb func() (data [8]byte, size int)) error { + data, size := cb() + + switch w := w.(type) { + // The most common case (called through WriteMessageN) is that the writer is a + // *bytes.Buffer. Optimize for that case by appending binary serializations + // to its existing capacity instead of paying the synchronization cost to + // serialize to temporary buffers pulled from the binary freelist. + case *bytes.Buffer: + w.Write(data[:size]) + return nil + + // Hashing transactions can be optimized by writing directly to the + // BLAKE-256 hasher. + case *blake256.Hasher256: + w.Write(data[:size]) + return nil + + // Hashing block headers can be optimized by writing directly to the + // BLAKE-3 hasher. + case *blake3.Hasher: + w.Write(data[:size]) + return nil + + default: + p := binarySerializer.Borrow()[:size] + copy(p, data[:size]) + _, err := w.Write(p) + return err + } +} + +// writeUint8 writes the byte value to the writer. +func writeUint8(w io.Writer, value uint8) error { + return shortWrite(w, func() (buf [8]byte, size int) { + buf[0] = value + return buf, 1 + }) +} + +// writeUint16LE writes the little endian encoding of value to the writer. +func writeUint16LE(w io.Writer, value uint16) error { + return shortWrite(w, func() (buf [8]byte, size int) { + littleEndian.PutUint16(buf[:], value) + return buf, 2 + }) +} + +// writeUint16BE writes the big endian encoding of value to the writer. +func writeUint16BE(w io.Writer, value uint16) error { + return shortWrite(w, func() (buf [8]byte, size int) { + bigEndian.PutUint16(buf[:], value) + return buf, 2 + }) +} + +// writeUint32LE writes the little endian encoding of value to the writer. +func writeUint32LE(w io.Writer, value uint32) error { + return shortWrite(w, func() (buf [8]byte, size int) { + littleEndian.PutUint32(buf[:], value) + return buf, 4 + }) +} + +// writeUint64LE writes the little endian encoding of value to the writer. +func writeUint64LE(w io.Writer, value uint64) error { + return shortWrite(w, func() (buf [8]byte, size int) { + littleEndian.PutUint64(buf[:], value) + return buf, 8 + }) +} + // writeElement writes the little endian representation of element to w. func writeElement(w io.Writer, element interface{}) error { // Attempt to write the element based on the concrete type via fast // type assertions first. switch e := element.(type) { case *uint8: - err := binarySerializer.PutUint8(w, *e) + err := writeUint8(w, *e) if err != nil { return err } return nil case *uint16: - err := binarySerializer.PutUint16(w, littleEndian, *e) + err := writeUint16LE(w, *e) if err != nil { return err } return nil case *int32: - err := binarySerializer.PutUint32(w, littleEndian, uint32(*e)) + err := writeUint32LE(w, uint32(*e)) if err != nil { return err } return nil case *uint32: - err := binarySerializer.PutUint32(w, littleEndian, *e) + err := writeUint32LE(w, *e) if err != nil { return err } return nil case *int64: - err := binarySerializer.PutUint64(w, littleEndian, uint64(*e)) + err := writeUint64LE(w, uint64(*e)) if err != nil { return err } return nil case *uint64: - err := binarySerializer.PutUint64(w, littleEndian, *e) + err := writeUint64LE(w, *e) if err != nil { return err } return nil - case *bool: - var err error - if *e { - err = binarySerializer.PutUint8(w, 0x01) - } else { - err = binarySerializer.PutUint8(w, 0x00) + case *int64Time: + // Reject timestamps that would overflow the maximum usable number of + // seconds for worry-free comparisons. + secs := uint64(time.Time(*e).Unix()) + if secs > math.MaxInt64-unixToInternal { + const str = "timestamp exceeds maximum allowed value" + return messageError("writeElement", ErrInvalidTimestamp, str) } + err := writeUint64LE(w, secs) if err != nil { return err } return nil - // Message header checksum. - case *[4]byte: - _, err := w.Write(e[:]) + case *bool: + var err error + if *e { + err = writeUint8(w, 0x01) + } else { + err = writeUint8(w, 0x00) + } if err != nil { return err } @@ -487,14 +577,6 @@ func writeElement(w io.Writer, element interface{}) error { } return nil - // Message header command. - case *[CommandSize]uint8: - _, err := w.Write(e[:]) - if err != nil { - return err - } - return nil - // IP address. case *[16]byte: _, err := w.Write(e[:]) @@ -558,28 +640,28 @@ func writeElement(w io.Writer, element interface{}) error { return nil case *ServiceFlag: - err := binarySerializer.PutUint64(w, littleEndian, uint64(*e)) + err := writeUint64LE(w, uint64(*e)) if err != nil { return err } return nil case *InvType: - err := binarySerializer.PutUint32(w, littleEndian, uint32(*e)) + err := writeUint32LE(w, uint32(*e)) if err != nil { return err } return nil case *CurrencyNet: - err := binarySerializer.PutUint32(w, littleEndian, uint32(*e)) + err := writeUint32LE(w, uint32(*e)) if err != nil { return err } return nil case *RejectCode: - err := binarySerializer.PutUint8(w, uint8(*e)) + err := writeUint8(w, uint8(*e)) if err != nil { return err } @@ -606,7 +688,8 @@ func writeElements(w io.Writer, elements ...interface{}) error { // ReadVarInt reads a variable length integer from r and returns it as a uint64. func ReadVarInt(r io.Reader, pver uint32) (uint64, error) { const op = "ReadVarInt" - discriminant, err := binarySerializer.Uint8(r) + var discriminant uint8 + err := readUint8(r, &discriminant) if err != nil { return 0, err } @@ -614,7 +697,8 @@ func ReadVarInt(r io.Reader, pver uint32) (uint64, error) { var rv uint64 switch discriminant { case 0xff: - sv, err := binarySerializer.Uint64(r, littleEndian) + var sv uint64 + err := readUint64LE(r, &sv) if err != nil { return 0, err } @@ -629,7 +713,8 @@ func ReadVarInt(r io.Reader, pver uint32) (uint64, error) { } case 0xfe: - sv, err := binarySerializer.Uint32(r, littleEndian) + var sv uint32 + err := readUint32LE(r, &sv) if err != nil { return 0, err } @@ -644,7 +729,8 @@ func ReadVarInt(r io.Reader, pver uint32) (uint64, error) { } case 0xfd: - sv, err := binarySerializer.Uint16(r, littleEndian) + var sv uint16 + err := readUint16LE(r, &sv) if err != nil { return 0, err } @@ -669,30 +755,31 @@ func ReadVarInt(r io.Reader, pver uint32) (uint64, error) { // on its value. func WriteVarInt(w io.Writer, pver uint32, val uint64) error { if val < 0xfd { - return binarySerializer.PutUint8(w, uint8(val)) + return writeUint8(w, uint8(val)) } if val <= math.MaxUint16 { - err := binarySerializer.PutUint8(w, 0xfd) - if err != nil { - return err - } - return binarySerializer.PutUint16(w, littleEndian, uint16(val)) + return shortWrite(w, func() (p [8]byte, size int) { + p[0] = 0xfd + littleEndian.PutUint16(p[1:], uint16(val)) + return p, 3 + }) } if val <= math.MaxUint32 { - err := binarySerializer.PutUint8(w, 0xfe) - if err != nil { - return err - } - return binarySerializer.PutUint32(w, littleEndian, uint32(val)) + return shortWrite(w, func() (p [8]byte, size int) { + p[0] = 0xfe + littleEndian.PutUint32(p[1:], uint32(val)) + return p, 5 + }) } - err := binarySerializer.PutUint8(w, 0xff) + // shortWrite is not designed for writes > 8 bytes. + err := writeUint8(w, 0xff) if err != nil { return err } - return binarySerializer.PutUint64(w, littleEndian, val) + return writeUint64LE(w, val) } // VarIntSerializeSize returns the number of bytes it would take to serialize @@ -798,7 +885,15 @@ func WriteVarString(w io.Writer, pver uint32, str string) error { if err != nil { return err } - _, err = w.Write([]byte(str)) + + switch w := w.(type) { + case *bytes.Buffer: + _, err = w.WriteString(str) + case *blake256.Hasher256: + w.WriteString(str) + default: + _, err = w.Write([]byte(str)) + } return err } @@ -851,7 +946,8 @@ func WriteVarBytes(w io.Writer, pver uint32, bytes []byte) error { // unexported version takes a reader primarily to ensure the error paths // can be properly tested by passing a fake reader in the tests. func randomUint64(r io.Reader) (uint64, error) { - rv, err := binarySerializer.Uint64(r, bigEndian) + var rv uint64 + err := readUint64LE(r, &rv) if err != nil { return 0, err } diff --git a/wire/common_test.go b/wire/common_test.go index 0f0bb0d81..5fbd681a1 100644 --- a/wire/common_test.go +++ b/wire/common_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2025 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -7,11 +7,15 @@ package wire import ( "bytes" + "encoding/binary" + "encoding/hex" "errors" "io" "reflect" "strings" "testing" + "testing/iotest" + "time" "github.com/davecgh/go-spew/spew" "github.com/decred/dcrd/chaincfg/chainhash" @@ -52,6 +56,18 @@ func (r *fakeRandReader) Read(p []byte) (int, error) { return n, r.err } +// hexToBytes converts the passed hex string into bytes and will panic if there +// is an error. This is only provided for the hard-coded constants so errors in +// the source code can be detected. It will only (and must only) be called with +// hard-coded values. +func hexToBytes(s string) []byte { + b, err := hex.DecodeString(s) + if err != nil { + panic("invalid hex in source file: " + s) + } + return b +} + func newInt32(v int32) *int32 { return &v } @@ -89,52 +105,29 @@ func newCurrencyNet(v CurrencyNet) *CurrencyNet { // type assertions to avoid reflection when possible. func TestElementWire(t *testing.T) { type writeElementReflect int32 + newUint8 := func(v uint8) *uint8 { return &v } + newUint16 := func(v uint16) *uint16 { return &v } + newInt64Time := func(v time.Time) *int64Time { return (*int64Time)(&v) } tests := []struct { in interface{} // Value to encode buf []byte // Wire encoding }{ - {newInt32(1), []byte{0x01, 0x00, 0x00, 0x00}}, - {newUint32(256), []byte{0x00, 0x01, 0x00, 0x00}}, - { - newInt64(65536), - []byte{0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}, - }, - { - newUint64(4294967296), - []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00}, - }, - { - newBool(true), - []byte{0x01}, - }, - { - newBool(false), - []byte{0x00}, - }, - { - &[4]byte{0x01, 0x02, 0x03, 0x04}, - []byte{0x01, 0x02, 0x03, 0x04}, - }, - { - &[CommandSize]byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, - }, - []byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, - }, - }, + {newUint8(240), hexToBytes("f0")}, + {newUint16(61423), hexToBytes("efef")}, + {newInt32(1), hexToBytes("01000000")}, + {newUint32(256), hexToBytes("00010000")}, + {newInt64(65536), hexToBytes("0000010000000000")}, + {newUint64(4294967296), hexToBytes("0000000001000000")}, + {newBool(true), hexToBytes("01")}, + {newBool(false), hexToBytes("00")}, + {newInt64Time(time.Unix(1772075804, 0)), hexToBytes("1cbb9f6900000000")}, { &[16]byte{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, }, - []byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, - }, + hexToBytes("0102030405060708090a0b0c0d0e0f10"), }, { (*chainhash.Hash)(&[chainhash.HashSize]byte{ // Make go vet happy. @@ -143,30 +136,14 @@ func TestElementWire(t *testing.T) { 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, }), - []byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, - 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, - 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, - }, - }, - { - newServiceFlag(SFNodeNetwork), - []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - }, - { - newInvType(InvTypeTx), - []byte{0x01, 0x00, 0x00, 0x00}, - }, - { - newCurrencyNet(MainNet), - []byte{0xf9, 0x00, 0xb4, 0xd9}, + hexToBytes("0102030405060708090a0b0c0d0e0f101112131415161718191a" + + "1b1c1d1e1f20"), }, + {newServiceFlag(SFNodeNetwork), hexToBytes("0100000000000000")}, + {newInvType(InvTypeTx), hexToBytes("01000000")}, + {newCurrencyNet(MainNet), hexToBytes("f900b4d9")}, // Type not supported by the "fast" path and requires reflection. - { - writeElementReflect(1), - []byte{0x01, 0x00, 0x00, 0x00}, - }, + {writeElementReflect(1), hexToBytes("01000000")}, } t.Logf("Running %d tests", len(tests)) @@ -204,6 +181,27 @@ func TestElementWire(t *testing.T) { spew.Sdump(ival), spew.Sdump(test.in)) continue } + + // Read from wire format again, but this time with a one byte reader. + obr := iotest.OneByteReader(bytes.NewReader(test.buf)) + val = test.in + if reflect.ValueOf(test.in).Kind() != reflect.Ptr { + val = reflect.New(reflect.TypeOf(test.in)).Interface() + } + err = readElement(obr, val) + if err != nil { + t.Errorf("readElement #%d error %v", i, err) + continue + } + ival = val + if reflect.ValueOf(test.in).Kind() != reflect.Ptr { + ival = reflect.Indirect(reflect.ValueOf(val)).Interface() + } + if !reflect.DeepEqual(ival, test.in) { + t.Errorf("readElement #%d\n got: %s want: %s", i, spew.Sdump(ival), + spew.Sdump(test.in)) + continue + } } } @@ -275,6 +273,79 @@ func TestElementWireErrors(t *testing.T) { } } +// TestShortReads ensures that all short reads work as expected with the various +// supported readers and a couple of readers for the default path including a +// one byte reader. +func TestShortReads(t *testing.T) { + var buf bytes.Buffer + buf.WriteByte(0x05) + binary.Write(&buf, binary.LittleEndian, uint16(61355)) + binary.Write(&buf, binary.BigEndian, uint16(61355)) + binary.Write(&buf, binary.LittleEndian, uint32(16777216)) + binary.Write(&buf, binary.LittleEndian, uint64(8589934592)) + testWithReader := func(r io.Reader) { + t.Helper() + + // Ensure readUint8 produces the expected value with no errors. + var u8 uint8 + err := readUint8(r, &u8) + if err != nil { + t.Fatalf("%T: readUint8 err: %v", r, err) + } + if u8 != 5 { + t.Fatalf("%T: readUint8 val: got %v, want %v", r, u8, 5) + } + + // Ensure readUint16LE produces the expected value with no errors. + var u16 uint16 + err = readUint16LE(r, &u16) + if err != nil { + t.Fatalf("%T: readUint16LE err: %v", r, err) + } + if u16 != 61355 { + t.Fatalf("%T: readUint16LE val: got %v, want %v", r, u16, + 61355) + } + + // Ensure readUint16BE produces the expected value with no errors. + err = readUint16BE(r, &u16) + if err != nil { + t.Fatalf("%T: readUint16BE err: %v", r, err) + } + if u16 != 61355 { + t.Fatalf("%T: readUint16BE val: got %v, want %v", r, u16, + 61355) + } + + // Ensure readUint32LE produces the expected value with no errors. + var u32 uint32 + err = readUint32LE(r, &u32) + if err != nil { + t.Fatalf("%T: readUint32LE err: %v", r, err) + } + if u32 != 16777216 { + t.Fatalf("%T: readUint32LE val: got %v, want %v", r, u32, + 16777216) + } + + // Ensure readUint64LE produces the expected value with no errors. + var u64 uint64 + err = readUint64LE(r, &u64) + if err != nil { + t.Fatalf("%T: readUint64LE err: %v", r, err) + } + if u64 != 8589934592 { + t.Fatalf("%T: readUint64LE val: got %v, want %v", r, u64, + 8589934592) + } + } + testWithReader(bytes.NewBuffer(buf.Bytes())) + testWithReader((*wireBuffer)(bytes.NewBuffer(buf.Bytes()))) + testWithReader(bytes.NewReader(buf.Bytes())) + testWithReader(io.LimitReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))) + testWithReader(iotest.OneByteReader(&buf)) +} + // TestVarIntWire tests wire encode and decode for variable length integers. func TestVarIntWire(t *testing.T) { pver := ProtocolVersion diff --git a/wire/error.go b/wire/error.go index cdb62800a..a7b6f1f66 100644 --- a/wire/error.go +++ b/wire/error.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2015 The btcsuite developers -// Copyright (c) 2015-2025 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -48,6 +48,10 @@ const ( // is received. ErrPayloadChecksum + // ErrTrailingBytes is returned when a message is received that is valid + // enough to fully decode, but also contains additional trailing bytes. + ErrTrailingBytes + // ErrTooManyAddrs is returned when an address list exceeds the maximum // allowed. ErrTooManyAddrs @@ -153,6 +157,22 @@ const ( // ErrTooManyCFilters is returned when the number of committed filters // exceeds the maximum allowed in a batch. ErrTooManyCFilters + + // ErrTooFewAddrs is returned when an address list contains fewer addresses + // than the minimum required. + ErrTooFewAddrs + + // ErrUnknownNetAddrType is returned when a network address type is not + // recognized or supported. + ErrUnknownNetAddrType + + // ErrInvalidTimestamp is returned when a message that involves a timestamp + // is not in the allowable range. + ErrInvalidTimestamp + + // numErrorCodes is the total number of error codes defined above. This + // entry MUST be the last entry in the enum. + numErrorCodes ) // Map of ErrorCode values back to their constant names for pretty printing. @@ -166,6 +186,7 @@ var errorCodeStrings = map[ErrorCode]string{ ErrMalformedCmd: "ErrMalformedCmd", ErrUnknownCmd: "ErrUnknownCmd", ErrPayloadChecksum: "ErrPayloadChecksum", + ErrTrailingBytes: "ErrTrailingBytes", ErrTooManyAddrs: "ErrTooManyAddrs", ErrTooManyTxs: "ErrTooManyTxs", ErrMsgInvalidForPVer: "ErrMsgInvalidForPVer", @@ -193,6 +214,9 @@ var errorCodeStrings = map[ErrorCode]string{ ErrTooManyMixPairReqUTXOs: "ErrTooManyMixPairReqUTXOs", ErrTooManyPrevMixMsgs: "ErrTooManyPrevMixMsgs", ErrTooManyCFilters: "ErrTooManyCFilters", + ErrTooFewAddrs: "ErrTooFewAddrs", + ErrUnknownNetAddrType: "ErrUnknownNetAddrType", + ErrInvalidTimestamp: "ErrInvalidTimestamp", } // String returns the ErrorCode as a human-readable name. diff --git a/wire/error_test.go b/wire/error_test.go index 452042504..8cbe13d5b 100644 --- a/wire/error_test.go +++ b/wire/error_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2017 The btcsuite developers -// Copyright (c) 2015-2024 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -28,6 +28,7 @@ func TestMessageErrorCodeStringer(t *testing.T) { {ErrMalformedCmd, "ErrMalformedCmd"}, {ErrUnknownCmd, "ErrUnknownCmd"}, {ErrPayloadChecksum, "ErrPayloadChecksum"}, + {ErrTrailingBytes, "ErrTrailingBytes"}, {ErrTooManyAddrs, "ErrTooManyAddrs"}, {ErrTooManyTxs, "ErrTooManyTxs"}, {ErrMsgInvalidForPVer, "ErrMsgInvalidForPVer"}, @@ -55,10 +56,19 @@ func TestMessageErrorCodeStringer(t *testing.T) { {ErrTooManyMixPairReqUTXOs, "ErrTooManyMixPairReqUTXOs"}, {ErrTooManyPrevMixMsgs, "ErrTooManyPrevMixMsgs"}, {ErrTooManyCFilters, "ErrTooManyCFilters"}, + {ErrTooFewAddrs, "ErrTooFewAddrs"}, + {ErrUnknownNetAddrType, "ErrUnknownNetAddrType"}, + {ErrInvalidTimestamp, "ErrInvalidTimestamp"}, {0xffff, "Unknown ErrorCode (65535)"}, } + // Detect additional defines that don't have the stringer added. + if len(tests)-1 != int(numErrorCodes) { + t.Fatal("It appears an error code was added without adding an " + + "associated stringer test") + } + t.Logf("Running %d tests", len(tests)) for i, test := range tests { result := test.in.String() diff --git a/wire/go.mod b/wire/go.mod index 85fa1d68b..1884d72a4 100644 --- a/wire/go.mod +++ b/wire/go.mod @@ -5,10 +5,10 @@ go 1.17 require ( github.com/davecgh/go-spew v1.1.1 github.com/decred/dcrd/chaincfg/chainhash v1.0.5 + github.com/decred/dcrd/crypto/blake256 v1.1.0 lukechampine.com/blake3 v1.3.0 ) -require ( - github.com/decred/dcrd/crypto/blake256 v1.1.0 // indirect - github.com/klauspost/cpuid/v2 v2.0.9 // indirect -) +require github.com/klauspost/cpuid/v2 v2.0.9 // indirect + +retract v1.7.3 // Short read errors diff --git a/wire/invvect_test.go b/wire/invvect_test.go index 751e718e3..0c944c7fc 100644 --- a/wire/invvect_test.go +++ b/wire/invvect_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2024 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. diff --git a/wire/message.go b/wire/message.go index b86c09aaf..edfad633f 100644 --- a/wire/message.go +++ b/wire/message.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2024 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -103,11 +103,11 @@ type Message interface { // makeEmptyMessage creates a message of the appropriate concrete type based // on the command. -func makeEmptyMessage(command string) (Message, error) { +func makeEmptyMessage(command []byte) (Message, error) { const op = "makeEmptyMessage" var msg Message - switch command { + switch string(command) { case CmdVersion: msg = &MsgVersion{} @@ -226,133 +226,94 @@ func makeEmptyMessage(command string) (Message, error) { msg = &MsgCFiltersV2{} default: - str := fmt.Sprintf("unhandled command [%s]", command) + str := fmt.Sprintf("unhandled command [%s]", string(command)) return nil, messageError(op, ErrUnknownCmd, str) } return msg, nil } -// messageHeader defines the header structure for all Decred protocol messages. -type messageHeader struct { - magic CurrencyNet // 4 bytes - command string // 12 bytes - length uint32 // 4 bytes - checksum [4]byte // 4 bytes -} - -// readMessageHeader reads a Decred message header from r. -func readMessageHeader(r io.Reader) (int, *messageHeader, error) { - // Since readElements doesn't return the amount of bytes read, attempt - // to read the entire header into a buffer first in case there is a - // short read so the proper amount of read bytes are known. This works - // since the header is a fixed size. - var headerBytes [MessageHeaderSize]byte - n, err := io.ReadFull(r, headerBytes[:]) - if err != nil { - return n, nil, err - } - hr := bytes.NewReader(headerBytes[:]) - - // Create and populate a messageHeader struct from the raw header bytes. - hdr := messageHeader{} - var command [CommandSize]byte - readElements(hr, &hdr.magic, &command, &hdr.length, &hdr.checksum) - - // Strip trailing zeros from command string. - hdr.command = string(bytes.TrimRight(command[:], string(rune(0)))) - - return n, &hdr, nil -} - -// discardInput reads n bytes from reader r in chunks and discards the read -// bytes. This is used to skip payloads when various errors occur and helps -// prevent rogue nodes from causing massive memory allocation through forging -// header length. -func discardInput(r io.Reader, n uint32) { - maxSize := uint32(10 * 1024) // 10k at a time - numReads := n / maxSize - bytesRemaining := n % maxSize - if n > 0 { - buf := make([]byte, maxSize) - for i := uint32(0); i < numReads; i++ { - io.ReadFull(r, buf) - } - } - if bytesRemaining > 0 { - buf := make([]byte, bytesRemaining) - io.ReadFull(r, buf) - } -} - // WriteMessageN writes a Decred Message to w including the necessary header // information and returns the number of bytes written. This function is the // same as WriteMessage except it also returns the number of bytes written. func WriteMessageN(w io.Writer, msg Message, pver uint32, dcrnet CurrencyNet) (int, error) { const op = "WriteMessage" - totalBytes := 0 - var elems struct { - dcrnet CurrencyNet + var ( command [CommandSize]byte lenp uint32 checksum [4]byte - } - elems.dcrnet = dcrnet + ) // Enforce max command size. cmd := msg.Command() if len(cmd) > CommandSize { msg := fmt.Sprintf("command [%s] is too long [max %v]", cmd, CommandSize) - return totalBytes, messageError(op, ErrCmdTooLong, msg) + return 0, messageError(op, ErrCmdTooLong, msg) + } + copy(command[:], []byte(cmd)) + + // Allocate enough buffer space for the entire message size if it is + // known. When it is not known, use an extra size hint of 64 bytes, + // which matches the default small allocation size of a bytes.Buffer + // as of Go 1.25. + extraCap := 64 + switch msg := msg.(type) { + case interface{ SerializeSize() int }: + extraCap = msg.SerializeSize() } - copy(elems.command[:], []byte(cmd)) + + // Initialize buffer with zeroed bytes for the message header (to be + // filled in, with checksum, after appending the payload + // serialization), plus additional capacity for writing the payload. + buf := bytes.NewBuffer(make([]byte, MessageHeaderSize, MessageHeaderSize+extraCap)) // Encode the message payload. - var bw bytes.Buffer - err := msg.BtcEncode(&bw, pver) + err := msg.BtcEncode(buf, pver) if err != nil { - return totalBytes, err + return 0, err } - payload := bw.Bytes() + bufBytes := buf.Bytes() + payload := bufBytes[MessageHeaderSize:] // Enforce maximum overall message payload. if len(payload) > MaxMessagePayload { msg := fmt.Sprintf("message payload is too large - encoded "+ "%d bytes, but maximum message payload is %d bytes", len(payload), MaxMessagePayload) - return totalBytes, messageError(op, ErrPayloadTooLarge, msg) + return 0, messageError(op, ErrPayloadTooLarge, msg) } - elems.lenp = uint32(len(payload)) + lenp = uint32(len(payload)) // Enforce maximum message payload based on the message type. mpl := msg.MaxPayloadLength(pver) - if elems.lenp > mpl { + if lenp > mpl { str := fmt.Sprintf("message payload is too large - encoded "+ "%d bytes, but maximum message payload size for "+ - "messages of type [%s] is %d.", elems.lenp, cmd, mpl) - return totalBytes, messageError(op, ErrPayloadTooLarge, str) + "messages of type [%s] is %d.", lenp, cmd, mpl) + return 0, messageError(op, ErrPayloadTooLarge, str) } - // Encode the header for the message. This is done to a buffer - // rather than directly to the writer since writeElements doesn't - // return the number of bytes written. + // Encode the message header. cksumHash := chainhash.HashH(payload) - copy(elems.checksum[:], cksumHash[0:4]) - var buf [MessageHeaderSize]byte - hw := bytes.NewBuffer(buf[:0]) - writeElements(hw, &elems.dcrnet, &elems.command, &elems.lenp, &elems.checksum) - - // Write header. - n, err := w.Write(hw.Bytes()) - totalBytes += n - if err != nil { - return totalBytes, err + copy(checksum[:], cksumHash[0:4]) + buf.Reset() + writeUint32LE(buf, uint32(dcrnet)) + buf.Write(command[:]) + writeUint32LE(buf, lenp) + buf.Write(checksum[:]) + if buf.Len() != MessageHeaderSize { + // The length of data written for the header is always + // constant, is not dependent on the message being serialized, + // and any implementation errors that cause an incorrect + // length to be written would be discovered by tests. + str := fmt.Sprintf("wrote unexpected message header length - "+ + "encoded %d bytes, but message header size is %d.", + buf.Len(), MessageHeaderSize) + panic(str) } - // Write payload. - n, err = w.Write(payload) - totalBytes += n - return totalBytes, err + // Write header + payload. + return w.Write(bufBytes) } // WriteMessage writes a Decred Message to w including the necessary header @@ -365,6 +326,36 @@ func WriteMessage(w io.Writer, msg Message, pver uint32, dcrnet CurrencyNet) err return err } +// wireBuffer is a bytes.Buffer uniquely used by ReadMessageN. The distinct +// type is used to optimize reads of transaction scripts by avoiding the +// scriptPool freelist when the buffer is known to not be clobbered by the +// caller. +type wireBuffer bytes.Buffer + +func (b *wireBuffer) Bytes() []byte { + return (*bytes.Buffer)(b).Bytes() +} + +func (b *wireBuffer) Grow(n int) { + (*bytes.Buffer)(b).Grow(n) +} + +func (b *wireBuffer) Len() int { + return (*bytes.Buffer)(b).Len() +} + +func (b *wireBuffer) Next(n int) []byte { + return (*bytes.Buffer)(b).Next(n) +} + +func (b *wireBuffer) Read(p []byte) (int, error) { + return (*bytes.Buffer)(b).Read(p) +} + +func (b *wireBuffer) ReadFrom(r io.Reader) (int64, error) { + return (*bytes.Buffer)(b).ReadFrom(r) +} + // ReadMessageN reads, validates, and parses the next Decred Message from r for // the provided protocol version and Decred network. It returns the number of // bytes read in addition to the parsed Message and raw bytes which comprise the @@ -373,39 +364,75 @@ func WriteMessage(w io.Writer, msg Message, pver uint32, dcrnet CurrencyNet) err func ReadMessageN(r io.Reader, pver uint32, dcrnet CurrencyNet) (int, Message, []byte, error) { const op = "ReadMessage" totalBytes := 0 - n, hdr, err := readMessageHeader(r) - totalBytes += n + + lr := &io.LimitedReader{R: r} + + // Read the bytes of the message header to the unread portion of a + // buffer (with some additional extra capacity to read short payloads + // without a realloc). + buf := (*wireBuffer)(bytes.NewBuffer(make([]byte, 0, bytes.MinRead*2))) + lr.N = MessageHeaderSize + read, err := buf.ReadFrom(lr) + totalBytes += int(read) + if lr.N > 0 { + err = io.EOF + } if err != nil { return totalBytes, nil, nil, err } + // Read the message header from the buffer. + // This should consume all of the current buffer's length. + var ( + magic CurrencyNet + command [CommandSize]byte + payloadLen uint32 + checksum [4]byte + ) + readUint32LE(buf, (*uint32)(&magic)) + buf.Read(command[:]) + readUint32LE(buf, &payloadLen) + // Only check the final header field read for error. + _, err = buf.Read(checksum[:]) + // The correct header length has already been read from the input + // reader to the buffer. This length is a constant and would not + // change based on the message or inputs. Any read errors or + // remaining unread bytes in the buffer would be discovered by tests. + if err != nil { + str := fmt.Sprintf("unexpected read error deserializing message "+ + "header: %v", err) + panic(str) + } + if buf.Len() != 0 { + str := fmt.Sprintf("read unexpected message header length - "+ + "%d unread message header bytes remaining", buf.Len()) + panic(str) + } + // Enforce maximum message payload. - if hdr.length > MaxMessagePayload { + if payloadLen > MaxMessagePayload { msg := fmt.Sprintf("message payload is too large - header "+ "indicates %d bytes, but max message payload is %d bytes.", - hdr.length, MaxMessagePayload) + payloadLen, MaxMessagePayload) return totalBytes, nil, nil, messageError(op, ErrPayloadTooLarge, msg) } // Check for messages from the wrong Decred network. - if hdr.magic != dcrnet { - discardInput(r, hdr.length) - msg := fmt.Sprintf("message from other network [%v]", hdr.magic) + if magic != dcrnet { + msg := fmt.Sprintf("message from other network [%v]", magic) return totalBytes, nil, nil, messageError(op, ErrWrongNetwork, msg) } // Check for malformed commands. - command := hdr.command - if !isStrictAscii(command) { - discardInput(r, hdr.length) - msg := fmt.Sprintf("invalid command %v", []byte(command)) + trimmedCommand := bytes.TrimRight(command[:], string(rune(0))) + if !isStrictAscii(string(trimmedCommand)) { + msg := fmt.Sprintf("invalid command %q", string(trimmedCommand)) return totalBytes, nil, nil, messageError(op, ErrMalformedCmd, msg) } // Create struct of appropriate message type based on the command. - msg, err := makeEmptyMessage(command) + msg, err := makeEmptyMessage(trimmedCommand) if err != nil { - discardInput(r, hdr.length) return totalBytes, nil, nil, err } @@ -413,38 +440,59 @@ func ReadMessageN(r io.Reader, pver uint32, dcrnet CurrencyNet) (int, Message, [ // could otherwise create a well-formed header and set the length to max // numbers in order to exhaust the machine's memory. mpl := msg.MaxPayloadLength(pver) - if hdr.length > mpl { - discardInput(r, hdr.length) + if payloadLen > mpl { msg := fmt.Sprintf("payload exceeds max length - header "+ "indicates %v bytes, but max payload size for messages of "+ - "type [%v] is %v.", hdr.length, command, mpl) + "type [%v] is %v.", payloadLen, msg.Command(), mpl) return totalBytes, nil, nil, messageError(op, ErrPayloadTooLarge, msg) } - // Read payload. - payload := make([]byte, hdr.length) - n, err = io.ReadFull(r, payload) - totalBytes += n + // Read payload into unread portion of the buffer. + grow := int(payloadLen) + if grow < bytes.MinRead { + grow = bytes.MinRead + } + buf.Grow(grow) + lr.N = int64(payloadLen) + read, err = buf.ReadFrom(lr) + totalBytes += int(read) + if lr.N > 0 { + err = io.EOF + } if err != nil { return totalBytes, nil, nil, err } + // The Buffer.Bytes documentation states that this slice is not valid + // after the next read, however, the buffer is only invalid to access + // after the next write, reset, or truncate. See + // https://github.com/golang/go/commit/5270b57e51b71f2b3410b601a9ba9f0a7a3d8441. + payload := buf.Bytes() // Test checksum. - checksum := chainhash.HashB(payload)[0:4] - if !bytes.Equal(checksum, hdr.checksum[:]) { - msg := fmt.Sprintf("payload checksum failed - header indicates %v, "+ - "but actual checksum is %v.", hdr.checksum, checksum) + payloadHash := chainhash.HashH(payload) + if !bytes.Equal(payloadHash[:4], checksum[:]) { + // Create heap copies to avoid leaking the originals in the + // fmt.Sprintf varargs. + payloadHash := payloadHash + checksum := checksum + msg := fmt.Sprintf("payload checksum failed - header indicates %x, "+ + "but actual checksum is %x.", checksum[:], payloadHash[:4]) return totalBytes, nil, nil, messageError(op, ErrPayloadChecksum, msg) } - // Unmarshal message. NOTE: This must be a *bytes.Buffer since the - // MsgVersion BtcDecode function requires it. - pr := bytes.NewBuffer(payload) - err = msg.BtcDecode(pr, pver) + // Unmarshal message using the unread payload in the buffer. + err = msg.BtcDecode(buf, pver) if err != nil { return totalBytes, nil, nil, err } + // Reject messages that did not consume the full payload. + if buf.Len() > 0 { + msg := fmt.Sprintf("message payload has %d unconsumed trailing "+ + "bytes", buf.Len()) + return totalBytes, nil, nil, messageError(op, ErrTrailingBytes, msg) + } + return totalBytes, msg, payload, nil } diff --git a/wire/message_test.go b/wire/message_test.go index 19db98edc..2901e060f 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -202,11 +202,10 @@ func TestReadMessageWireErrors(t *testing.T) { pver := ProtocolVersion dcrnet := MainNet - // Wire encoded bytes for main and testnet networks magic identifiers. + // Wire encoded bytes for testnet magic identifier. testNetBytes := makeHeader(TestNet3, "", 0, 0) - // Wire encoded bytes for a message that exceeds max overall message - // length. + // Wire encoded bytes for a message that exceeds max overall message length. mpl := uint32(MaxMessagePayload) exceedMaxPayloadBytes := makeHeader(dcrnet, "getaddr", mpl+1, 0) @@ -234,151 +233,122 @@ func TestReadMessageWireErrors(t *testing.T) { // contained in the message. Claim there is two, but don't provide // them. At the same time, forge the header fields so the message is // otherwise accurate. - badMessageBytes := makeHeader(dcrnet, "addr", 1, 0xeaadc31c) + badMessageBytes := makeHeader(dcrnet, "addr", 1, 0xab37af49) badMessageBytes = append(badMessageBytes, 0x2) - // Wire encoded bytes for a message which the header claims has 15k - // bytes of data to discard. - discardBytes := makeHeader(dcrnet, "bogus", 15*1024, 0) + // Wire encoded bytes for a message that is valid, but contains additional + // trailing bytes and header fields that are forged so the message is + // otherwise accurate. + payloadSize := uint32(len(testBlockBytes)) + trailingBytes := makeHeader(dcrnet, "block", payloadSize+1, 0xb5ec24b8) + trailingBytes = append(trailingBytes, testBlockBytes...) + trailingBytes = append(trailingBytes, 0x01) tests := []struct { - buf []byte // Wire encoding - pver uint32 // Protocol version for wire encoding - dcrnet CurrencyNet // Decred network for wire encoding - max int // Max size of fixed buffer to induce errors - readErr error // Expected read error - bytes int // Expected num bytes read - }{ - // Latest protocol version with intentional read errors. - - // Short header. [0] - { - []byte{}, - pver, - dcrnet, - 0, - io.EOF, - 0, - }, - - // Wrong network. Want MainNet, but giving TestNet. [1] - { - testNetBytes, - pver, - dcrnet, - len(testNetBytes), - &MessageError{}, - 24, - }, - - // Exceed max overall message payload length. [2] - { - exceedMaxPayloadBytes, - pver, - dcrnet, - len(exceedMaxPayloadBytes), - &MessageError{}, - 24, - }, - - // Invalid UTF-8 command. [3] - { - badCommandBytes, - pver, - dcrnet, - len(badCommandBytes), - &MessageError{}, - 24, - }, - - // Valid, but unsupported command. [4] - { - unsupportedCommandBytes, - pver, - dcrnet, - len(unsupportedCommandBytes), - &MessageError{}, - 24, - }, - - // Exceed max allowed payload for a message of a specific type. [5] - { - exceedTypePayloadBytes, - pver, - dcrnet, - len(exceedTypePayloadBytes), - &MessageError{}, - 24, - }, - - // Message with a payload shorter than the header indicates. [6] - { - shortPayloadBytes, - pver, - dcrnet, - len(shortPayloadBytes), - io.EOF, - 24, - }, - - // Message with a bad checksum. [7] - { - badChecksumBytes, - pver, - dcrnet, - len(badChecksumBytes), - &MessageError{}, - 26, - }, - - // Message with a valid header, but wrong format. [8] - { - badMessageBytes, - pver, - dcrnet, - len(badMessageBytes), - &MessageError{}, - 25, - }, - - // 15k bytes of data to discard. [9] - { - discardBytes, - pver, - dcrnet, - len(discardBytes), - &MessageError{}, - 24, - }, - } + name string // Test description + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + dcrnet CurrencyNet // Decred network for wire encoding + max int // Max size of fixed buffer to induce errors + err error // Expected read error + bytes int // Expected num bytes read + }{{ + name: "short header", + buf: nil, + pver: pver, + dcrnet: dcrnet, + max: 0, + err: io.EOF, + bytes: 0, + }, { + name: "wrong network, want mainnet, giving testnet", + buf: testNetBytes, + pver: pver, + dcrnet: dcrnet, + max: len(testNetBytes), + err: ErrWrongNetwork, + bytes: len(testNetBytes), + }, { + name: "exceed max overall message payload length", + buf: exceedMaxPayloadBytes, + pver: pver, + dcrnet: dcrnet, + max: len(exceedMaxPayloadBytes), + err: ErrPayloadTooLarge, + bytes: len(exceedMaxPayloadBytes), + }, { + name: "invalid utf-8 command", + buf: badCommandBytes, + pver: pver, + dcrnet: dcrnet, + max: len(badCommandBytes), + err: ErrMalformedCmd, + bytes: len(badCommandBytes), + }, { + name: "valid, but unsupported command", + buf: unsupportedCommandBytes, + pver: pver, + dcrnet: dcrnet, + max: len(unsupportedCommandBytes), + err: ErrUnknownCmd, + bytes: len(unsupportedCommandBytes), + }, { + name: "exceed max allowed payload for a message of a specific type", + buf: exceedTypePayloadBytes, + pver: pver, + dcrnet: dcrnet, + max: len(exceedTypePayloadBytes), + err: ErrPayloadTooLarge, + bytes: len(exceedTypePayloadBytes), + }, { + name: "payload shorter than the header indicates", + buf: shortPayloadBytes, + pver: pver, + dcrnet: dcrnet, + max: len(shortPayloadBytes), + err: io.EOF, + bytes: len(shortPayloadBytes), + }, { + name: "bad checksum", + buf: badChecksumBytes, + pver: pver, + dcrnet: dcrnet, + max: len(badChecksumBytes), + err: ErrPayloadChecksum, + bytes: len(badChecksumBytes), + }, { + name: "valid header, but wrong message body format", + buf: badMessageBytes, + pver: BatchedCFiltersV2Version - 1, + dcrnet: dcrnet, + max: len(badMessageBytes), + err: io.EOF, + bytes: len(badMessageBytes), + }, { + name: "valid header and message with extra trailing bytes", + buf: trailingBytes, + pver: pver, + dcrnet: dcrnet, + max: len(trailingBytes), + err: ErrTrailingBytes, + bytes: len(trailingBytes), + }} t.Logf("Running %d tests", len(tests)) - for i, test := range tests { + for _, test := range tests { // Decode from wire format. r := newFixedReader(test.max, test.buf) nr, _, _, err := ReadMessageN(r, test.pver, test.dcrnet) - if reflect.TypeOf(err) != reflect.TypeOf(test.readErr) { - t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+ - "want: %T", i, err, err, test.readErr) + if !errors.Is(err, test.err) { + t.Errorf("%q: wrong error: got %v <%[2]T>, want: %v <%[3]T>", + test.name, err, test.err) continue } - // Ensure the number of bytes written match the expected value. + // Ensure the number of bytes read matches the expected value. if nr != test.bytes { - t.Errorf("ReadMessage #%d unexpected num bytes read - "+ - "got %d, want %d", i, nr, test.bytes) - } - - // For errors which are not of type MessageError, check them for - // equality. - var merr *MessageError - if !errors.As(err, &merr) { - if !errors.Is(err, test.readErr) { - t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+ - "want: %v <%T>", i, err, err, - test.readErr, test.readErr) - continue - } + t.Errorf("%q: unexpected num bytes read - got %d, want %d", + test.name, nr, test.bytes) } } } @@ -428,7 +398,7 @@ func TestWriteMessageWireErrors(t *testing.T) { // Force error in header write. {bogusMsg, pver, dcrnet, 0, io.ErrShortWrite, 0}, // Force error in payload write. - {bogusMsg, pver, dcrnet, 24, io.ErrShortWrite, 24}, + {bogusMsg, pver, dcrnet, 24, io.ErrShortWrite, 0}, } t.Logf("Running %d tests", len(tests)) diff --git a/wire/msgaddr.go b/wire/msgaddr.go index 0418a03ed..6b9f3585b 100644 --- a/wire/msgaddr.go +++ b/wire/msgaddr.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2015 The btcsuite developers -// Copyright (c) 2015-2020 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. diff --git a/wire/msgaddr_test.go b/wire/msgaddr_test.go index 97654fe5f..6b2f751e0 100644 --- a/wire/msgaddr_test.go +++ b/wire/msgaddr_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2020 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. diff --git a/wire/msgcfheaders.go b/wire/msgcfheaders.go index f7229f05e..24bb31158 100644 --- a/wire/msgcfheaders.go +++ b/wire/msgcfheaders.go @@ -1,6 +1,6 @@ // Copyright (c) 2017 The btcsuite developers // Copyright (c) 2017 The Lightning Network Developers -// Copyright (c) 2018-2024 The Decred developers +// Copyright (c) 2018-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -116,7 +116,7 @@ func (msg *MsgCFHeaders) BtcEncode(w io.Writer, pver uint32) error { } // Write filter type - err = binarySerializer.PutUint8(w, uint8(msg.FilterType)) + err = writeUint8(w, uint8(msg.FilterType)) if err != nil { return err } diff --git a/wire/msgcfilter.go b/wire/msgcfilter.go index 50f1e7b08..841d36383 100644 --- a/wire/msgcfilter.go +++ b/wire/msgcfilter.go @@ -80,7 +80,7 @@ func (msg *MsgCFilter) BtcEncode(w io.Writer, pver uint32) error { return err } - err = binarySerializer.PutUint8(w, uint8(msg.FilterType)) + err = writeUint8(w, uint8(msg.FilterType)) if err != nil { return err } diff --git a/wire/msgcfilter_test.go b/wire/msgcfilter_test.go index 7343531d5..ac482eba5 100644 --- a/wire/msgcfilter_test.go +++ b/wire/msgcfilter_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2020 The Decred developers +// Copyright (c) 2019-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -65,9 +65,6 @@ func TestCFilter(t *testing.T) { // Ensure encoding with max CF data per message returns no error. data = make([]byte, MaxCFilterDataSize) msg = NewMsgCFilter(blockHash, GCSFilterExtended, data) - if err != nil { - t.Fatalf("NewMsgCFilter: %v", err) - } var buf bytes.Buffer err = msg.BtcEncode(&buf, pver) if err != nil { diff --git a/wire/msgcftypes.go b/wire/msgcftypes.go index 1692f8c66..71dbaab95 100644 --- a/wire/msgcftypes.go +++ b/wire/msgcftypes.go @@ -92,7 +92,7 @@ func (msg *MsgCFTypes) BtcEncode(w io.Writer, pver uint32) error { } for i := range msg.SupportedFilters { - err = binarySerializer.PutUint8(w, uint8(msg.SupportedFilters[i])) + err = writeUint8(w, uint8(msg.SupportedFilters[i])) if err != nil { return err } diff --git a/wire/msggetcfheaders.go b/wire/msggetcfheaders.go index fc59ba327..a1d572687 100644 --- a/wire/msggetcfheaders.go +++ b/wire/msggetcfheaders.go @@ -1,6 +1,6 @@ // Copyright (c) 2017 The btcsuite developers // Copyright (c) 2017 The Lightning Network Developers -// Copyright (c) 2018-2024 The Decred developers +// Copyright (c) 2018-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -115,7 +115,7 @@ func (msg *MsgGetCFHeaders) BtcEncode(w io.Writer, pver uint32) error { return err } - return binarySerializer.PutUint8(w, uint8(msg.FilterType)) + return writeUint8(w, uint8(msg.FilterType)) } // Command returns the protocol command string for the message. This is part diff --git a/wire/msggetcfilter.go b/wire/msggetcfilter.go index 281c18593..e3cd2a03c 100644 --- a/wire/msggetcfilter.go +++ b/wire/msggetcfilter.go @@ -1,6 +1,6 @@ // Copyright (c) 2017 The btcsuite developers // Copyright (c) 2017 The Lightning Network Developers -// Copyright (c) 2018-2024 The Decred developers +// Copyright (c) 2018-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -54,7 +54,7 @@ func (msg *MsgGetCFilter) BtcEncode(w io.Writer, pver uint32) error { if err != nil { return err } - return binarySerializer.PutUint8(w, uint8(msg.FilterType)) + return writeUint8(w, uint8(msg.FilterType)) } // Command returns the protocol command string for the message. This is part diff --git a/wire/msginitstate_test.go b/wire/msginitstate_test.go index 7d8a3e478..93158c471 100644 --- a/wire/msginitstate_test.go +++ b/wire/msginitstate_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2024 The Decred developers +// Copyright (c) 2020-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. diff --git a/wire/msgtx.go b/wire/msgtx.go index 1ada0b422..194420ed5 100644 --- a/wire/msgtx.go +++ b/wire/msgtx.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2023 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -12,6 +12,7 @@ import ( "strconv" "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/crypto/blake256" ) const ( @@ -238,13 +239,33 @@ func readScript(r io.Reader, pver uint32, maxAllowed uint32, fieldName string) ( return nil, messageError(op, ErrVarBytesTooLong, msg) } - b := scriptPool.Borrow(count) - _, err = io.ReadFull(r, b) - if err != nil { - scriptPool.Return(b) - return nil, err + switch r := r.(type) { + // Read the script bytes from the underlying buffer when called by + // ReadMessageN. This requires that the buffer will not be reset or + // truncated in the future (which is not a guarantee that can be made + // for any bytes.Buffer passed to BtcDecode). + case *wireBuffer: + b := r.Next(int(count)) + if count != 0 && len(b) == 0 { + return nil, io.EOF + } + if len(b) < int(count) { + return nil, io.ErrUnexpectedEOF + } + return b, nil + + // In all other cases, borrow a temporary buffer from the scriptPool + // freelist (before copying all scripts to a single contiguous + // allocation and returning buffers back to the pool). + default: + b := scriptPool.Borrow(count) + _, err = io.ReadFull(r, b) + if err != nil { + scriptPool.Return(b) + return nil, err + } + return b, nil } - return b, nil } // OutPoint defines a Decred data type that is used to track previous @@ -391,24 +412,28 @@ func (msg *MsgTx) serialize(serType TxSerializeType) ([]byte, error) { return buf.Bytes(), nil } -// mustSerialize returns the serialization of the transaction for the provided -// serialization type without modifying the original transaction. It will panic -// if any errors occur. -func (msg *MsgTx) mustSerialize(serType TxSerializeType) []byte { - serialized, err := msg.serialize(serType) +// mustHash returns the hash of the transaction for the provided +// serialization type without modifying the original transaction. +// It will panic if serialization fails. +func (msg *MsgTx) mustHash(hasher *blake256.Hasher256, serType TxSerializeType) chainhash.Hash { + // Shallow copy so the serialization type can be changed without + // modifying the original transaction. + mtxCopy := *msg + mtxCopy.SerType = serType + err := mtxCopy.Serialize(hasher) if err != nil { panic(fmt.Sprintf("MsgTx failed serializing for type %v", serType)) } - return serialized + return hasher.Sum256() } -// TxHash generates the hash for the transaction prefix. Since it does not -// contain any witness data, it is not malleable and therefore is stable for -// use in unconfirmed transaction chains. +// TxHash generates the BLAKE-256 hash for the transaction prefix. Since it +// does not contain any witness data, it is not malleable and therefore is +// stable for use in unconfirmed transaction chains. func (msg *MsgTx) TxHash() chainhash.Hash { // TxHash should always calculate a non-witnessed hash. - return chainhash.HashH(msg.mustSerialize(TxSerializeNoWitness)) + return msg.mustHash(blake256.NewHasher256(), TxSerializeNoWitness) } // CachedTxHash is equivalent to calling TxHash, however it caches the result so @@ -433,29 +458,29 @@ func (msg *MsgTx) RecacheTxHash() *chainhash.Hash { return msg.CachedHash } -// TxHashWitness generates the hash for the transaction witness. +// TxHashWitness generates the BLAKE-256 hash for the transaction witness. func (msg *MsgTx) TxHashWitness() chainhash.Hash { // TxHashWitness should always calculate a witnessed hash. - return chainhash.HashH(msg.mustSerialize(TxSerializeOnlyWitness)) + return msg.mustHash(blake256.NewHasher256(), TxSerializeOnlyWitness) } -// TxHashFull generates the hash for the transaction prefix || witness. It first -// obtains the hashes for both the transaction prefix and witness, then -// concatenates them and hashes the result. +// TxHashFull generates the hash for the transaction prefix || witness. This +// is the BLAKE-256 hash of the concatenation of the individual prefix and +// witness hashes (and not the hash of the full serialization). func (msg *MsgTx) TxHashFull() chainhash.Hash { - // Note that the inputs to the hashes, the serialized prefix and - // witness, have different serialized versions because the serialized - // encoding of the version includes the real transaction version in the - // lower 16 bits and the transaction serialization type in the upper 16 - // bits. The real transaction version (lower 16 bits) will be the same - // in both serializations. - concat := make([]byte, chainhash.HashSize*2) - prefixHash := msg.TxHash() - witnessHash := msg.TxHashWitness() - copy(concat[0:], prefixHash[:]) - copy(concat[chainhash.HashSize:], witnessHash[:]) - - return chainhash.HashH(concat) + // Even for a transaction that has neither prefix nor witness (and + // would otherwise hash to the same result), the prefix and witness + // hashes will still differ due to the serialization type being + // encoded into the upper 16 bits of the transaction version. + hasher := blake256.NewHasher256() + prefixHash := msg.mustHash(hasher, TxSerializeNoWitness) + hasher.Reset() + witnessHash := msg.mustHash(hasher, TxSerializeOnlyWitness) + hasher.Reset() + + hasher.WriteBytes(prefixHash[:]) + hasher.WriteBytes(witnessHash[:]) + return hasher.Sum256() } // Copy creates a deep copy of a transaction so that the original does not get @@ -664,12 +689,12 @@ func (msg *MsgTx) decodePrefix(r io.Reader, pver uint32) (uint64, error) { } // Locktime and expiry. - msg.LockTime, err = binarySerializer.Uint32(r, littleEndian) + err = readUint32LE(r, &msg.LockTime) if err != nil { return 0, err } - msg.Expiry, err = binarySerializer.Uint32(r, littleEndian) + err = readUint32LE(r, &msg.Expiry) if err != nil { return 0, err } @@ -770,19 +795,36 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32) error { // The serialized encoding of the version includes the real transaction // version in the lower 16 bits and the transaction serialization type // in the upper 16 bits. - version, err := binarySerializer.Uint32(r, littleEndian) + var version uint32 + err := readUint32LE(r, &version) if err != nil { return err } msg.Version = uint16(version & 0xffff) msg.SerType = TxSerializeType(version >> 16) + var copyScripts bool + switch r.(type) { + case *wireBuffer: + copyScripts = false + default: + copyScripts = true + + // Prevent caller-provided script slices from being returned to the free + // list. + msg.TxIn = nil + msg.TxOut = nil + } + // returnScriptBuffers is a closure that returns any script buffers that // were borrowed from the pool when there are any deserialization // errors. This is only valid to call before the final step which // replaces the scripts with the location in a contiguous buffer and // returns them. returnScriptBuffers := func() { + if !copyScripts { + return + } for _, txIn := range msg.TxIn { if txIn == nil || txIn.SignatureScript == nil { continue @@ -810,7 +852,9 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32) error { returnScriptBuffers() return err } - writeTxScriptsToMsgTx(msg, totalScriptSize, txSerType) + if copyScripts { + writeTxScriptsToMsgTx(msg, totalScriptSize, txSerType) + } case TxSerializeOnlyWitness: totalScriptSize, err := msg.decodeWitness(r, pver, false) @@ -818,7 +862,9 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32) error { returnScriptBuffers() return err } - writeTxScriptsToMsgTx(msg, totalScriptSize, txSerType) + if copyScripts { + writeTxScriptsToMsgTx(msg, totalScriptSize, txSerType) + } case TxSerializeFull: totalScriptSizeIns, err := msg.decodePrefix(r, pver) @@ -831,8 +877,10 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32) error { returnScriptBuffers() return err } - writeTxScriptsToMsgTx(msg, totalScriptSizeIns+ - totalScriptSizeOuts, txSerType) + if copyScripts { + writeTxScriptsToMsgTx(msg, totalScriptSizeIns+ + totalScriptSizeOuts, txSerType) + } default: return messageError(op, ErrUnknownTxType, "unsupported transaction type") @@ -892,12 +940,11 @@ func (msg *MsgTx) encodePrefix(w io.Writer, pver uint32) error { } } - err = binarySerializer.PutUint32(w, littleEndian, msg.LockTime) + err = writeUint32LE(w, msg.LockTime) if err != nil { return err } - - return binarySerializer.PutUint32(w, littleEndian, msg.Expiry) + return writeUint32LE(w, msg.Expiry) } // encodeWitness encodes a transaction witness into a writer. @@ -927,7 +974,7 @@ func (msg *MsgTx) BtcEncode(w io.Writer, pver uint32) error { // version in the lower 16 bits and the transaction serialization type // in the upper 16 bits. serializedVersion := uint32(msg.Version) | uint32(msg.SerType)<<16 - err := binarySerializer.PutUint32(w, littleEndian, serializedVersion) + err := writeUint32LE(w, serializedVersion) if err != nil { return err } @@ -1130,12 +1177,13 @@ func ReadOutPoint(r io.Reader, pver uint32, version uint16, op *OutPoint) error return err } - op.Index, err = binarySerializer.Uint32(r, littleEndian) + err = readUint32LE(r, &op.Index) if err != nil { return err } - tree, err := binarySerializer.Uint8(r) + var tree uint8 + err = readUint8(r, &tree) if err != nil { return err } @@ -1152,12 +1200,12 @@ func WriteOutPoint(w io.Writer, pver uint32, version uint16, op *OutPoint) error return err } - err = binarySerializer.PutUint32(w, littleEndian, op.Index) + err = writeUint32LE(w, op.Index) if err != nil { return err } - return binarySerializer.PutUint8(w, uint8(op.Tree)) + return writeUint8(w, uint8(op.Tree)) } // readTxInPrefix reads the next sequence of bytes from r as a transaction input @@ -1175,28 +1223,28 @@ func readTxInPrefix(r io.Reader, pver uint32, serType TxSerializeType, version u } // Sequence. - ti.Sequence, err = binarySerializer.Uint32(r, littleEndian) - return err + return readUint32LE(r, &ti.Sequence) } // readTxInWitness reads the next sequence of bytes from r as a transaction input // (TxIn) in the transaction witness. func readTxInWitness(r io.Reader, pver uint32, version uint16, ti *TxIn) error { // ValueIn. - valueIn, err := binarySerializer.Uint64(r, littleEndian) + var valueIn uint64 + err := readUint64LE(r, &valueIn) if err != nil { return err } ti.ValueIn = int64(valueIn) // BlockHeight. - ti.BlockHeight, err = binarySerializer.Uint32(r, littleEndian) + err = readUint32LE(r, &ti.BlockHeight) if err != nil { return err } // BlockIndex. - ti.BlockIndex, err = binarySerializer.Uint32(r, littleEndian) + err = readUint32LE(r, &ti.BlockIndex) if err != nil { return err } @@ -1215,26 +1263,26 @@ func writeTxInPrefix(w io.Writer, pver uint32, version uint16, ti *TxIn) error { return err } - return binarySerializer.PutUint32(w, littleEndian, ti.Sequence) + return writeUint32LE(w, ti.Sequence) } // writeTxInWitness encodes ti to the Decred protocol encoding for a transaction // input (TxIn) witness to w. func writeTxInWitness(w io.Writer, pver uint32, version uint16, ti *TxIn) error { // ValueIn. - err := binarySerializer.PutUint64(w, littleEndian, uint64(ti.ValueIn)) + err := writeUint64LE(w, uint64(ti.ValueIn)) if err != nil { return err } // BlockHeight. - err = binarySerializer.PutUint32(w, littleEndian, ti.BlockHeight) + err = writeUint32LE(w, ti.BlockHeight) if err != nil { return err } // BlockIndex. - err = binarySerializer.PutUint32(w, littleEndian, ti.BlockIndex) + err = writeUint32LE(w, ti.BlockIndex) if err != nil { return err } @@ -1246,13 +1294,14 @@ func writeTxInWitness(w io.Writer, pver uint32, version uint16, ti *TxIn) error // readTxOut reads the next sequence of bytes from r as a transaction output // (TxOut). func readTxOut(r io.Reader, pver uint32, version uint16, to *TxOut) error { - value, err := binarySerializer.Uint64(r, littleEndian) + var value uint64 + err := readUint64LE(r, &value) if err != nil { return err } to.Value = int64(value) - to.Version, err = binarySerializer.Uint16(r, littleEndian) + err = readUint16LE(r, &to.Version) if err != nil { return err } @@ -1265,12 +1314,12 @@ func readTxOut(r io.Reader, pver uint32, version uint16, to *TxOut) error { // writeTxOut encodes to into the Decred protocol encoding for a transaction // output (TxOut) to w. func writeTxOut(w io.Writer, pver uint32, version uint16, to *TxOut) error { - err := binarySerializer.PutUint64(w, littleEndian, uint64(to.Value)) + err := writeUint64LE(w, uint64(to.Value)) if err != nil { return err } - err = binarySerializer.PutUint16(w, littleEndian, to.Version) + err = writeUint16LE(w, to.Version) if err != nil { return err } diff --git a/wire/msgversion.go b/wire/msgversion.go index 9773634ea..1fbef37a5 100644 --- a/wire/msgversion.go +++ b/wire/msgversion.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2024 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -80,8 +80,15 @@ func (msg *MsgVersion) AddService(service ServiceFlag) { // This is part of the Message interface implementation. func (msg *MsgVersion) BtcDecode(r io.Reader, pver uint32) error { const op = "MsgVersion.BtcDecode" - buf, ok := r.(*bytes.Buffer) - if !ok { + // In addition to a *bytes.Buffer as described by the public + // documentation, the internal *wireBuffer type is also allowed. + var buf *bytes.Buffer + switch r := r.(type) { + case *wireBuffer: + buf = (*bytes.Buffer)(r) + case *bytes.Buffer: + buf = r + default: return messageError(op, ErrInvalidMsg, "reader is not a *bytes.Buffer") } @@ -157,13 +164,8 @@ func (msg *MsgVersion) BtcEncode(w io.Writer, pver uint32) error { return err } - var elems struct { - ts int64 - relayTx bool - } - elems.ts = msg.Timestamp.Unix() - - err = writeElements(w, &msg.ProtocolVersion, &msg.Services, &elems.ts) + err = writeElements(w, &msg.ProtocolVersion, &msg.Services, + (*int64Time)(&msg.Timestamp)) if err != nil { return err } @@ -193,8 +195,8 @@ func (msg *MsgVersion) BtcEncode(w io.Writer, pver uint32) error { return err } - elems.relayTx = !msg.DisableRelayTx - return writeElement(w, &elems.relayTx) + var relayTx = !msg.DisableRelayTx + return writeElement(w, &relayTx) } // Command returns the protocol command string for the message. This is part diff --git a/wire/netaddress.go b/wire/netaddress.go index 92ce4bae5..22aceb52c 100644 --- a/wire/netaddress.go +++ b/wire/netaddress.go @@ -1,12 +1,11 @@ // Copyright (c) 2013-2015 The btcsuite developers -// Copyright (c) 2015-2023 The Decred developers +// Copyright (c) 2015-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. package wire import ( - "encoding/binary" "errors" "io" "net" @@ -118,7 +117,8 @@ func readNetAddress(r io.Reader, pver uint32, na *NetAddress, ts bool) error { return err } // Sigh. Decred protocol mixes little and big endian. - port, err := binarySerializer.Uint16(r, bigEndian) + var port uint16 + err = readUint16BE(r, &port) if err != nil { return err } @@ -162,5 +162,5 @@ func writeNetAddress(w io.Writer, pver uint32, na *NetAddress, ts bool) error { } // Sigh. Decred protocol mixes little and big endian. - return binary.Write(w, bigEndian, &na.Port) + return writeUint16BE(w, na.Port) }