From 5f09652e8bb83de71c304a580277ce8577c13138 Mon Sep 17 00:00:00 2001 From: Tom Titchener Date: Thu, 26 Jul 2018 08:36:40 -0400 Subject: [PATCH] decode discriminated unions --- xdr2/bench_test.go | 2 +- xdr2/decode.go | 122 +++++++++++++++++++++++++++---------------- xdr2/example_test.go | 2 +- 3 files changed, 79 insertions(+), 47 deletions(-) diff --git a/xdr2/bench_test.go b/xdr2/bench_test.go index 0fea4db..406d2c2 100644 --- a/xdr2/bench_test.go +++ b/xdr2/bench_test.go @@ -21,7 +21,7 @@ import ( "testing" "unsafe" - "github.com/davecgh/go-xdr/xdr2" + xdr "github.com/bluearchive/go-xdr/xdr2" ) // BenchmarkUnmarshal benchmarks the Unmarshal function by using a dummy diff --git a/xdr2/decode.go b/xdr2/decode.go index 494dae6..b5ae2ef 100644 --- a/xdr2/decode.go +++ b/xdr2/decode.go @@ -21,6 +21,8 @@ import ( "io" "math" "reflect" + "strconv" + "strings" "time" ) @@ -426,7 +428,7 @@ func (d *Decoder) decodeFixedArray(v reflect.Value, ignoreOpaque bool) (int, err // Decode each array element. var n int for i := 0; i < v.Len(); i++ { - n2, err := d.decode(v.Index(i)) + n2, _, err := d.decode(v.Index(i)) n += n2 if err != nil { return n, err @@ -485,7 +487,7 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool) (int, error) { // Decode each slice element. for i := 0; i < sliceLen; i++ { - n2, err := d.decode(v.Index(i)) + n2, _, err := d.decode(v.Index(i)) n += n2 if err != nil { return n, err @@ -505,10 +507,17 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool) (int, error) { // Reference: // RFC Section 4.14 - Structure // XDR encoded elements in the order of their declaration in the struct + +// RFC Section 4.15 - Discriminated Union +// XDR encoded elements in the order of their declaration in the struct func (d *Decoder) decodeStruct(v reflect.Value) (int, error) { var n int + var n2 int vt := v.Type() + discrval := -1 // value for scalar to interpret when discru is true + discrtarg := -1 // "xdr" tag value "unioncase=" converted from string for i := 0; i < v.NumField(); i++ { + // Skip unexported fields. vtf := vt.Field(i) if vtf.PkgPath != "" { @@ -553,12 +562,29 @@ func (d *Decoder) decodeStruct(v reflect.Value) (int, error) { } } + xdrtag := vtf.Tag.Get("xdr") + discru := false // "xdr" tag value "union" + if "union" == xdrtag { + discru = true + } else if strings.HasPrefix(xdrtag, "unioncase") { + vals := strings.Split(xdrtag, "=") + discrtarg, _ = strconv.Atoi(vals[1]) + } + + if discrval != -1 && discrval != discrtarg { + continue + } + // Decode each struct field. - n2, err := d.decode(vf) + v := 0 + n2, v, err = d.decode(vf) n += n2 if err != nil { return n, err } + if discru { + discrval = v + } } return n, nil @@ -598,14 +624,14 @@ func (d *Decoder) decodeMap(v reflect.Value) (int, error) { elemType := vt.Elem() for i := uint32(0); i < dataLen; i++ { key := reflect.New(keyType).Elem() - n2, err := d.decode(key) + n2, _, err := d.decode(key) n += n2 if err != nil { return n, err } val := reflect.New(elemType).Elem() - n2, err = d.decode(val) + n2, _, err = d.decode(val) n += n2 if err != nil { return n, err @@ -645,25 +671,28 @@ func (d *Decoder) decodeInterface(v reflect.Value) (int, error) { nil, nil) return 0, err } - return d.decode(ve) + n, _, err := d.decode(ve) + return n, err } // decode is the main workhorse for unmarshalling via reflection. It uses // the passed reflection value to choose the XDR primitives to decode from // the encapsulated reader. It is a recursive function, // so cyclic data structures are not supported and will result in an infinite -// loop. It returns the the number of bytes actually read. -func (d *Decoder) decode(v reflect.Value) (int, error) { +// loop. It returns the the number of bytes actually read, a possible +// discriminator value in case the previous field had the "union" value for +// the "xdr" tag, and the error status +func (d *Decoder) decode(v reflect.Value) (int, int, error) { if !v.IsValid() { msg := fmt.Sprintf("type '%s' is not valid", v.Kind().String()) err := unmarshalError("decode", ErrUnsupportedType, msg, nil, nil) - return 0, err + return 0, -1, err } // Indirect through pointers allocating them as needed. ve, err := d.indirect(v) if err != nil { - return 0, err + return 0, -1, err } // Handle time.Time values by decoding them as an RFC3339 formatted @@ -674,16 +703,16 @@ func (d *Decoder) decode(v reflect.Value) (int, error) { // Read the value as a string and parse it. timeString, n, err := d.DecodeString() if err != nil { - return n, err + return n, -1, err } ttv, err := time.Parse(time.RFC3339, timeString) if err != nil { err := unmarshalError("decode", ErrParseTime, err.Error(), timeString, err) - return n, err + return n, -1, err } ve.Set(reflect.ValueOf(ttv)) - return n, nil + return n, -1, nil } // Handle native Go types. @@ -691,113 +720,116 @@ func (d *Decoder) decode(v reflect.Value) (int, error) { case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: i, n, err := d.DecodeInt() if err != nil { - return n, err + return n, int(i), err } if ve.OverflowInt(int64(i)) { msg := fmt.Sprintf("signed integer too large to fit '%s'", ve.Kind().String()) err = unmarshalError("decode", ErrOverflow, msg, i, nil) - return n, err + return n, -1, err } ve.SetInt(int64(i)) - return n, nil + return n, int(i), nil case reflect.Int64: i, n, err := d.DecodeHyper() if err != nil { - return n, err + return n, -1, err } ve.SetInt(i) - return n, nil + return n, -1, nil case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint: ui, n, err := d.DecodeUint() if err != nil { - return n, err + return n, int(n), err } if ve.OverflowUint(uint64(ui)) { msg := fmt.Sprintf("unsigned integer too large to fit '%s'", ve.Kind().String()) err = unmarshalError("decode", ErrOverflow, msg, ui, nil) - return n, err + return n, -1, err } ve.SetUint(uint64(ui)) - return n, nil + return n, int(ui), nil case reflect.Uint64: ui, n, err := d.DecodeUhyper() if err != nil { - return n, err + return n, -1, err } ve.SetUint(ui) - return n, nil - + return n, -1, nil case reflect.Bool: b, n, err := d.DecodeBool() if err != nil { - return n, err + return n, -1, err + } + v := 0 + if b { + v = 1 } ve.SetBool(b) - return n, nil + return n, v, nil case reflect.Float32: f, n, err := d.DecodeFloat() if err != nil { - return n, err + return n, -1, err } ve.SetFloat(float64(f)) - return n, nil + return n, -1, nil case reflect.Float64: f, n, err := d.DecodeDouble() if err != nil { - return n, err + return n, -1, err } ve.SetFloat(f) - return n, nil + return n, -1, nil case reflect.String: s, n, err := d.DecodeString() if err != nil { - return n, err + return n, -1, err } ve.SetString(s) - return n, nil + return n, -1, nil case reflect.Array: n, err := d.decodeFixedArray(ve, false) if err != nil { - return n, err + return n, -1, err } - return n, nil + return n, -1, nil case reflect.Slice: n, err := d.decodeArray(ve, false) if err != nil { - return n, err + return n, -1, err } - return n, nil + return n, -1, nil case reflect.Struct: n, err := d.decodeStruct(ve) if err != nil { - return n, err + return n, -1, err } - return n, nil + return n, -1, nil case reflect.Map: n, err := d.decodeMap(ve) if err != nil { - return n, err + return n, -1, err } - return n, nil + return n, -1, nil case reflect.Interface: n, err := d.decodeInterface(ve) if err != nil { - return n, err + return n, -1, err } - return n, nil + return n, -1, nil } // The only unhandled types left are unsupported. At the time of this @@ -805,7 +837,7 @@ func (d *Decoder) decode(v reflect.Value) (int, error) { // reflect.Uintptr and reflect.UnsafePointer. msg := fmt.Sprintf("unsupported Go type '%s'", ve.Kind().String()) err = unmarshalError("decode", ErrUnsupportedType, msg, nil, nil) - return 0, err + return 0, -1, err } // indirect dereferences pointers allocating them as needed until it reaches @@ -855,8 +887,8 @@ func (d *Decoder) Decode(v interface{}) (int, error) { err := unmarshalError("Unmarshal", ErrNotSettable, msg, nil, nil) return 0, err } - - return d.decode(vv) + n, _, err := d.decode(vv) + return n, err } // NewDecoder returns a Decoder that can be used to manually decode XDR data diff --git a/xdr2/example_test.go b/xdr2/example_test.go index 1272862..f60039f 100644 --- a/xdr2/example_test.go +++ b/xdr2/example_test.go @@ -20,7 +20,7 @@ import ( "bytes" "fmt" - "github.com/davecgh/go-xdr/xdr2" + xdr "github.com/bluearchive/go-xdr/xdr2" ) // This example demonstrates how to use Marshal to automatically XDR encode