diff --git a/wire/common_test.go b/wire/common_test.go index ede134b27..e91391592 100644 --- a/wire/common_test.go +++ b/wire/common_test.go @@ -7,11 +7,15 @@ package wire import ( "bytes" + "encoding/binary" + "encoding/hex" "errors" "io" "reflect" "strings" "testing" + "testing/iotest" + "time" "github.com/davecgh/go-spew/spew" "github.com/decred/dcrd/chaincfg/chainhash" @@ -52,6 +56,18 @@ func (r *fakeRandReader) Read(p []byte) (int, error) { return n, r.err } +// hexToBytes converts the passed hex string into bytes and will panic if there +// is an error. This is only provided for the hard-coded constants so errors in +// the source code can be detected. It will only (and must only) be called with +// hard-coded values. +func hexToBytes(s string) []byte { + b, err := hex.DecodeString(s) + if err != nil { + panic("invalid hex in source file: " + s) + } + return b +} + func newInt32(v int32) *int32 { return &v } @@ -89,38 +105,31 @@ func newCurrencyNet(v CurrencyNet) *CurrencyNet { // type assertions to avoid reflection when possible. func TestElementWire(t *testing.T) { type writeElementReflect int32 + newUint8 := func(v uint8) *uint8 { return &v } + newUint16 := func(v uint16) *uint16 { return &v } + newInt64Time := func(v time.Time) *int64Time { return (*int64Time)(&v) } + newUint64Time := func(v time.Time) *uint64Time { return (*uint64Time)(&v) } tests := []struct { in interface{} // Value to encode buf []byte // Wire encoding }{ - {newInt32(1), []byte{0x01, 0x00, 0x00, 0x00}}, - {newUint32(256), []byte{0x00, 0x01, 0x00, 0x00}}, - { - newInt64(65536), - []byte{0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}, - }, - { - newUint64(4294967296), - []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00}, - }, - { - newBool(true), - []byte{0x01}, - }, - { - newBool(false), - []byte{0x00}, - }, + {newUint8(240), hexToBytes("f0")}, + {newUint16(61423), hexToBytes("efef")}, + {newInt32(1), hexToBytes("01000000")}, + {newUint32(256), hexToBytes("00010000")}, + {newInt64(65536), hexToBytes("0000010000000000")}, + {newUint64(4294967296), hexToBytes("0000000001000000")}, + {newBool(true), hexToBytes("01")}, + {newBool(false), hexToBytes("00")}, + {newInt64Time(time.Unix(1772075804, 0)), hexToBytes("1cbb9f6900000000")}, + {newUint64Time(time.Unix(1772075804, 0)), hexToBytes("1cbb9f6900000000")}, { &[16]byte{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, }, - []byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, - }, + hexToBytes("0102030405060708090a0b0c0d0e0f10"), }, { (*chainhash.Hash)(&[chainhash.HashSize]byte{ // Make go vet happy. @@ -129,30 +138,14 @@ func TestElementWire(t *testing.T) { 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, }), - []byte{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, - 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, - 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, - }, - }, - { - newServiceFlag(SFNodeNetwork), - []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - }, - { - newInvType(InvTypeTx), - []byte{0x01, 0x00, 0x00, 0x00}, - }, - { - newCurrencyNet(MainNet), - []byte{0xf9, 0x00, 0xb4, 0xd9}, + hexToBytes("0102030405060708090a0b0c0d0e0f101112131415161718191a" + + "1b1c1d1e1f20"), }, + {newServiceFlag(SFNodeNetwork), hexToBytes("0100000000000000")}, + {newInvType(InvTypeTx), hexToBytes("01000000")}, + {newCurrencyNet(MainNet), hexToBytes("f900b4d9")}, // Type not supported by the "fast" path and requires reflection. - { - writeElementReflect(1), - []byte{0x01, 0x00, 0x00, 0x00}, - }, + {writeElementReflect(1), hexToBytes("01000000")}, } t.Logf("Running %d tests", len(tests)) @@ -190,6 +183,27 @@ func TestElementWire(t *testing.T) { spew.Sdump(ival), spew.Sdump(test.in)) continue } + + // Read from wire format again, but this time with a one byte reader. + obr := iotest.OneByteReader(bytes.NewReader(test.buf)) + val = test.in + if reflect.ValueOf(test.in).Kind() != reflect.Ptr { + val = reflect.New(reflect.TypeOf(test.in)).Interface() + } + err = readElement(obr, val) + if err != nil { + t.Errorf("readElement #%d error %v", i, err) + continue + } + ival = val + if reflect.ValueOf(test.in).Kind() != reflect.Ptr { + ival = reflect.Indirect(reflect.ValueOf(val)).Interface() + } + if !reflect.DeepEqual(ival, test.in) { + t.Errorf("readElement #%d\n got: %s want: %s", i, spew.Sdump(ival), + spew.Sdump(test.in)) + continue + } } } @@ -261,6 +275,79 @@ func TestElementWireErrors(t *testing.T) { } } +// TestShortReads ensures that all short reads work as expected with the various +// supported readers and a couple of readers for the default path including a +// one byte reader. +func TestShortReads(t *testing.T) { + var buf bytes.Buffer + buf.WriteByte(0x05) + binary.Write(&buf, binary.LittleEndian, uint16(61355)) + binary.Write(&buf, binary.BigEndian, uint16(61355)) + binary.Write(&buf, binary.LittleEndian, uint32(16777216)) + binary.Write(&buf, binary.LittleEndian, uint64(8589934592)) + testWithReader := func(r io.Reader) { + t.Helper() + + // Ensure readUint8 produces the expected value with no errors. + var u8 uint8 + err := readUint8(r, &u8) + if err != nil { + t.Fatalf("%T: readUint8 err: %v", r, err) + } + if u8 != 5 { + t.Fatalf("%T: readUint8 val: got %v, want %v", r, u8, 5) + } + + // Ensure readUint16LE produces the expected value with no errors. + var u16 uint16 + err = readUint16LE(r, &u16) + if err != nil { + t.Fatalf("%T: readUint16LE err: %v", r, err) + } + if u16 != 61355 { + t.Fatalf("%T: readUint16LE val: got %v, want %v", r, u16, + 61355) + } + + // Ensure readUint16BE produces the expected value with no errors. + err = readUint16BE(r, &u16) + if err != nil { + t.Fatalf("%T: readUint16BE err: %v", r, err) + } + if u16 != 61355 { + t.Fatalf("%T: readUint16BE val: got %v, want %v", r, u16, + 61355) + } + + // Ensure readUint32LE produces the expected value with no errors. + var u32 uint32 + err = readUint32LE(r, &u32) + if err != nil { + t.Fatalf("%T: readUint32LE err: %v", r, err) + } + if u32 != 16777216 { + t.Fatalf("%T: readUint32LE val: got %v, want %v", r, u32, + 16777216) + } + + // Ensure readUint64LE produces the expected value with no errors. + var u64 uint64 + err = readUint64LE(r, &u64) + if err != nil { + t.Fatalf("%T: readUint64LE err: %v", r, err) + } + if u64 != 8589934592 { + t.Fatalf("%T: readUint64LE val: got %v, want %v", r, u64, + 8589934592) + } + } + testWithReader(bytes.NewBuffer(buf.Bytes())) + testWithReader((*wireBuffer)(bytes.NewBuffer(buf.Bytes()))) + testWithReader(bytes.NewReader(buf.Bytes())) + testWithReader(io.LimitReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))) + testWithReader(iotest.OneByteReader(&buf)) +} + // TestVarIntWire tests wire encode and decode for variable length integers. func TestVarIntWire(t *testing.T) { pver := ProtocolVersion