Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xdr2/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"testing"
"unsafe"

"github.com/davecgh/go-xdr/xdr2"
xdr "github.com/bluearchive/go-xdr/xdr2"
)

// BenchmarkUnmarshal benchmarks the Unmarshal function by using a dummy
Expand Down
122 changes: 77 additions & 45 deletions xdr2/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"io"
"math"
"reflect"
"strconv"
"strings"
"time"
)

Expand Down Expand Up @@ -426,7 +428,7 @@ func (d *Decoder) decodeFixedArray(v reflect.Value, ignoreOpaque bool) (int, err
// Decode each array element.
var n int
for i := 0; i < v.Len(); i++ {
n2, err := d.decode(v.Index(i))
n2, _, err := d.decode(v.Index(i))
n += n2
if err != nil {
return n, err
Expand Down Expand Up @@ -485,7 +487,7 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool) (int, error) {

// Decode each slice element.
for i := 0; i < sliceLen; i++ {
n2, err := d.decode(v.Index(i))
n2, _, err := d.decode(v.Index(i))
n += n2
if err != nil {
return n, err
Expand All @@ -505,10 +507,17 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool) (int, error) {
// Reference:
// RFC Section 4.14 - Structure
// XDR encoded elements in the order of their declaration in the struct

// RFC Section 4.15 - Discriminated Union
// XDR encoded elements in the order of their declaration in the struct
func (d *Decoder) decodeStruct(v reflect.Value) (int, error) {
var n int
var n2 int
vt := v.Type()
discrval := -1 // value for scalar to interpret when discru is true
discrtarg := -1 // "xdr" tag value "unioncase=<num>" <num> converted from string
for i := 0; i < v.NumField(); i++ {

// Skip unexported fields.
vtf := vt.Field(i)
if vtf.PkgPath != "" {
Expand Down Expand Up @@ -553,12 +562,29 @@ func (d *Decoder) decodeStruct(v reflect.Value) (int, error) {
}
}

xdrtag := vtf.Tag.Get("xdr")
discru := false // "xdr" tag value "union"
if "union" == xdrtag {
discru = true
} else if strings.HasPrefix(xdrtag, "unioncase") {
vals := strings.Split(xdrtag, "=")
discrtarg, _ = strconv.Atoi(vals[1])
}

if discrval != -1 && discrval != discrtarg {
continue
}

// Decode each struct field.
n2, err := d.decode(vf)
v := 0
n2, v, err = d.decode(vf)
n += n2
if err != nil {
return n, err
}
if discru {
discrval = v
}
}

return n, nil
Expand Down Expand Up @@ -598,14 +624,14 @@ func (d *Decoder) decodeMap(v reflect.Value) (int, error) {
elemType := vt.Elem()
for i := uint32(0); i < dataLen; i++ {
key := reflect.New(keyType).Elem()
n2, err := d.decode(key)
n2, _, err := d.decode(key)
n += n2
if err != nil {
return n, err
}

val := reflect.New(elemType).Elem()
n2, err = d.decode(val)
n2, _, err = d.decode(val)
n += n2
if err != nil {
return n, err
Expand Down Expand Up @@ -645,25 +671,28 @@ func (d *Decoder) decodeInterface(v reflect.Value) (int, error) {
nil, nil)
return 0, err
}
return d.decode(ve)
n, _, err := d.decode(ve)
return n, err
}

// decode is the main workhorse for unmarshalling via reflection. It uses
// the passed reflection value to choose the XDR primitives to decode from
// the encapsulated reader. It is a recursive function,
// so cyclic data structures are not supported and will result in an infinite
// loop. It returns the the number of bytes actually read.
func (d *Decoder) decode(v reflect.Value) (int, error) {
// loop. It returns the the number of bytes actually read, a possible
// discriminator value in case the previous field had the "union" value for
// the "xdr" tag, and the error status
func (d *Decoder) decode(v reflect.Value) (int, int, error) {
if !v.IsValid() {
msg := fmt.Sprintf("type '%s' is not valid", v.Kind().String())
err := unmarshalError("decode", ErrUnsupportedType, msg, nil, nil)
return 0, err
return 0, -1, err
}

// Indirect through pointers allocating them as needed.
ve, err := d.indirect(v)
if err != nil {
return 0, err
return 0, -1, err
}

// Handle time.Time values by decoding them as an RFC3339 formatted
Expand All @@ -674,138 +703,141 @@ func (d *Decoder) decode(v reflect.Value) (int, error) {
// Read the value as a string and parse it.
timeString, n, err := d.DecodeString()
if err != nil {
return n, err
return n, -1, err
}
ttv, err := time.Parse(time.RFC3339, timeString)
if err != nil {
err := unmarshalError("decode", ErrParseTime,
err.Error(), timeString, err)
return n, err
return n, -1, err
}
ve.Set(reflect.ValueOf(ttv))
return n, nil
return n, -1, nil
}

// Handle native Go types.
switch ve.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int:
i, n, err := d.DecodeInt()
if err != nil {
return n, err
return n, int(i), err
}
if ve.OverflowInt(int64(i)) {
msg := fmt.Sprintf("signed integer too large to fit '%s'",
ve.Kind().String())
err = unmarshalError("decode", ErrOverflow, msg, i, nil)
return n, err
return n, -1, err
}
ve.SetInt(int64(i))
return n, nil
return n, int(i), nil

case reflect.Int64:
i, n, err := d.DecodeHyper()
if err != nil {
return n, err
return n, -1, err
}
ve.SetInt(i)
return n, nil
return n, -1, nil

case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint:
ui, n, err := d.DecodeUint()
if err != nil {
return n, err
return n, int(n), err
}
if ve.OverflowUint(uint64(ui)) {
msg := fmt.Sprintf("unsigned integer too large to fit '%s'",
ve.Kind().String())
err = unmarshalError("decode", ErrOverflow, msg, ui, nil)
return n, err
return n, -1, err
}
ve.SetUint(uint64(ui))
return n, nil
return n, int(ui), nil

case reflect.Uint64:
ui, n, err := d.DecodeUhyper()
if err != nil {
return n, err
return n, -1, err
}
ve.SetUint(ui)
return n, nil

return n, -1, nil
case reflect.Bool:
b, n, err := d.DecodeBool()
if err != nil {
return n, err
return n, -1, err
}
v := 0
if b {
v = 1
}
ve.SetBool(b)
return n, nil
return n, v, nil

case reflect.Float32:
f, n, err := d.DecodeFloat()
if err != nil {
return n, err
return n, -1, err
}
ve.SetFloat(float64(f))
return n, nil
return n, -1, nil

case reflect.Float64:
f, n, err := d.DecodeDouble()
if err != nil {
return n, err
return n, -1, err
}
ve.SetFloat(f)
return n, nil
return n, -1, nil

case reflect.String:
s, n, err := d.DecodeString()
if err != nil {
return n, err
return n, -1, err
}
ve.SetString(s)
return n, nil
return n, -1, nil

case reflect.Array:
n, err := d.decodeFixedArray(ve, false)
if err != nil {
return n, err
return n, -1, err
}
return n, nil
return n, -1, nil

case reflect.Slice:
n, err := d.decodeArray(ve, false)
if err != nil {
return n, err
return n, -1, err
}
return n, nil
return n, -1, nil

case reflect.Struct:
n, err := d.decodeStruct(ve)
if err != nil {
return n, err
return n, -1, err
}
return n, nil
return n, -1, nil

case reflect.Map:
n, err := d.decodeMap(ve)
if err != nil {
return n, err
return n, -1, err
}
return n, nil
return n, -1, nil

case reflect.Interface:
n, err := d.decodeInterface(ve)
if err != nil {
return n, err
return n, -1, err
}
return n, nil
return n, -1, nil
}

// The only unhandled types left are unsupported. At the time of this
// writing the only remaining unsupported types that exist are
// reflect.Uintptr and reflect.UnsafePointer.
msg := fmt.Sprintf("unsupported Go type '%s'", ve.Kind().String())
err = unmarshalError("decode", ErrUnsupportedType, msg, nil, nil)
return 0, err
return 0, -1, err
}

// indirect dereferences pointers allocating them as needed until it reaches
Expand Down Expand Up @@ -855,8 +887,8 @@ func (d *Decoder) Decode(v interface{}) (int, error) {
err := unmarshalError("Unmarshal", ErrNotSettable, msg, nil, nil)
return 0, err
}

return d.decode(vv)
n, _, err := d.decode(vv)
return n, err
}

// NewDecoder returns a Decoder that can be used to manually decode XDR data
Expand Down
2 changes: 1 addition & 1 deletion xdr2/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"bytes"
"fmt"

"github.com/davecgh/go-xdr/xdr2"
xdr "github.com/bluearchive/go-xdr/xdr2"
)

// This example demonstrates how to use Marshal to automatically XDR encode
Expand Down