Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 131 additions & 44 deletions wire/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ package wire

import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"io"
"reflect"
"strings"
"testing"
"testing/iotest"
"time"

"github.com/davecgh/go-spew/spew"
"github.com/decred/dcrd/chaincfg/chainhash"
Expand Down Expand Up @@ -52,6 +56,18 @@ func (r *fakeRandReader) Read(p []byte) (int, error) {
return n, r.err
}

// hexToBytes converts the passed hex string into bytes and will panic if there
// is an error. This is only provided for the hard-coded constants so errors in
// the source code can be detected. It will only (and must only) be called with
// hard-coded values.
func hexToBytes(s string) []byte {
b, err := hex.DecodeString(s)
if err != nil {
panic("invalid hex in source file: " + s)
}
return b
}

func newInt32(v int32) *int32 {
return &v
}
Expand Down Expand Up @@ -89,38 +105,31 @@ func newCurrencyNet(v CurrencyNet) *CurrencyNet {
// type assertions to avoid reflection when possible.
func TestElementWire(t *testing.T) {
type writeElementReflect int32
newUint8 := func(v uint8) *uint8 { return &v }
newUint16 := func(v uint16) *uint16 { return &v }
newInt64Time := func(v time.Time) *int64Time { return (*int64Time)(&v) }
newUint64Time := func(v time.Time) *uint64Time { return (*uint64Time)(&v) }

tests := []struct {
in interface{} // Value to encode
buf []byte // Wire encoding
}{
{newInt32(1), []byte{0x01, 0x00, 0x00, 0x00}},
{newUint32(256), []byte{0x00, 0x01, 0x00, 0x00}},
{
newInt64(65536),
[]byte{0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00},
},
{
newUint64(4294967296),
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
},
{
newBool(true),
[]byte{0x01},
},
{
newBool(false),
[]byte{0x00},
},
{newUint8(240), hexToBytes("f0")},
{newUint16(61423), hexToBytes("efef")},
{newInt32(1), hexToBytes("01000000")},
{newUint32(256), hexToBytes("00010000")},
{newInt64(65536), hexToBytes("0000010000000000")},
{newUint64(4294967296), hexToBytes("0000000001000000")},
{newBool(true), hexToBytes("01")},
{newBool(false), hexToBytes("00")},
{newInt64Time(time.Unix(1772075804, 0)), hexToBytes("1cbb9f6900000000")},
{newUint64Time(time.Unix(1772075804, 0)), hexToBytes("1cbb9f6900000000")},
{
&[16]byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
},
[]byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
},
hexToBytes("0102030405060708090a0b0c0d0e0f10"),
},
{
(*chainhash.Hash)(&[chainhash.HashSize]byte{ // Make go vet happy.
Expand All @@ -129,30 +138,14 @@ func TestElementWire(t *testing.T) {
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
}),
[]byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
},
},
{
newServiceFlag(SFNodeNetwork),
[]byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
{
newInvType(InvTypeTx),
[]byte{0x01, 0x00, 0x00, 0x00},
},
{
newCurrencyNet(MainNet),
[]byte{0xf9, 0x00, 0xb4, 0xd9},
hexToBytes("0102030405060708090a0b0c0d0e0f101112131415161718191a" +
"1b1c1d1e1f20"),
},
{newServiceFlag(SFNodeNetwork), hexToBytes("0100000000000000")},
{newInvType(InvTypeTx), hexToBytes("01000000")},
{newCurrencyNet(MainNet), hexToBytes("f900b4d9")},
// Type not supported by the "fast" path and requires reflection.
{
writeElementReflect(1),
[]byte{0x01, 0x00, 0x00, 0x00},
},
{writeElementReflect(1), hexToBytes("01000000")},
}

t.Logf("Running %d tests", len(tests))
Expand Down Expand Up @@ -190,6 +183,27 @@ func TestElementWire(t *testing.T) {
spew.Sdump(ival), spew.Sdump(test.in))
continue
}

// Read from wire format again, but this time with a one byte reader.
obr := iotest.OneByteReader(bytes.NewReader(test.buf))
val = test.in
if reflect.ValueOf(test.in).Kind() != reflect.Ptr {
val = reflect.New(reflect.TypeOf(test.in)).Interface()
}
err = readElement(obr, val)
if err != nil {
t.Errorf("readElement #%d error %v", i, err)
continue
}
ival = val
if reflect.ValueOf(test.in).Kind() != reflect.Ptr {
ival = reflect.Indirect(reflect.ValueOf(val)).Interface()
}
if !reflect.DeepEqual(ival, test.in) {
t.Errorf("readElement #%d\n got: %s want: %s", i, spew.Sdump(ival),
spew.Sdump(test.in))
continue
}
}
}

Expand Down Expand Up @@ -261,6 +275,79 @@ func TestElementWireErrors(t *testing.T) {
}
}

// TestShortReads ensures that all short reads work as expected with the various
// supported readers and a couple of readers for the default path including a
// one byte reader.
func TestShortReads(t *testing.T) {
var buf bytes.Buffer
buf.WriteByte(0x05)
binary.Write(&buf, binary.LittleEndian, uint16(61355))
binary.Write(&buf, binary.BigEndian, uint16(61355))
binary.Write(&buf, binary.LittleEndian, uint32(16777216))
binary.Write(&buf, binary.LittleEndian, uint64(8589934592))
testWithReader := func(r io.Reader) {
t.Helper()

// Ensure readUint8 produces the expected value with no errors.
var u8 uint8
err := readUint8(r, &u8)
if err != nil {
t.Fatalf("%T: readUint8 err: %v", r, err)
}
if u8 != 5 {
t.Fatalf("%T: readUint8 val: got %v, want %v", r, u8, 5)
}

// Ensure readUint16LE produces the expected value with no errors.
var u16 uint16
err = readUint16LE(r, &u16)
if err != nil {
t.Fatalf("%T: readUint16LE err: %v", r, err)
}
if u16 != 61355 {
t.Fatalf("%T: readUint16LE val: got %v, want %v", r, u16,
61355)
}

// Ensure readUint16BE produces the expected value with no errors.
err = readUint16BE(r, &u16)
if err != nil {
t.Fatalf("%T: readUint16BE err: %v", r, err)
}
if u16 != 61355 {
t.Fatalf("%T: readUint16BE val: got %v, want %v", r, u16,
61355)
}

// Ensure readUint32LE produces the expected value with no errors.
var u32 uint32
err = readUint32LE(r, &u32)
if err != nil {
t.Fatalf("%T: readUint32LE err: %v", r, err)
}
if u32 != 16777216 {
t.Fatalf("%T: readUint32LE val: got %v, want %v", r, u32,
16777216)
}

// Ensure readUint64LE produces the expected value with no errors.
var u64 uint64
err = readUint64LE(r, &u64)
if err != nil {
t.Fatalf("%T: readUint64LE err: %v", r, err)
}
if u64 != 8589934592 {
t.Fatalf("%T: readUint64LE val: got %v, want %v", r, u64,
8589934592)
}
}
testWithReader(bytes.NewBuffer(buf.Bytes()))
testWithReader((*wireBuffer)(bytes.NewBuffer(buf.Bytes())))
testWithReader(bytes.NewReader(buf.Bytes()))
testWithReader(io.LimitReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())))
testWithReader(iotest.OneByteReader(&buf))
}

// TestVarIntWire tests wire encode and decode for variable length integers.
func TestVarIntWire(t *testing.T) {
pver := ProtocolVersion
Expand Down