Skip to content

Commit

Permalink
Fix protocol test generation for streaming types
Browse files Browse the repository at this point in the history
Fixes the generation of protocol tests with streaming input and output
members. Refactors the comparison of structure values to use a utility
that flexibly compares the io.Reader values. Allowing the underlying
io.Reader implementations to be different types, but compare as equal if
the contents of the readers are the same.
  • Loading branch information
jasdel committed Jul 20, 2020
1 parent 403ab2e commit 90381e4
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import software.amazon.smithy.model.shapes.ShapeType;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.EnumTrait;
import software.amazon.smithy.model.traits.StreamingTrait;

/**
* Generates a shape type declaration based on the parameters provided.
Expand Down Expand Up @@ -169,27 +170,42 @@ protected void mapDeclShapeValue(GoWriter writer, MapShape shape, Runnable inner
*/
protected void scalarWrapShapeValue(GoWriter writer, Shape shape, Runnable inner) {
boolean withPtrImport = true;
String closing = ")";

switch (shape.getType()) {
case BOOLEAN:
writer.writeInline("ptr.Bool(");
break;

case BLOB:
writer.writeInline("[]byte(");
if (shape.hasTrait(StreamingTrait.class)) {
writer.addUseImports(SmithyGoDependency.SMITHY_IO);
writer.addUseImports(SmithyGoDependency.BYTES);
writer.writeInline("smithyio.ReadSeekNopCloser{ReadSeeker: bytes.NewReader([]byte(");
closing += ")}";
} else {
writer.writeInline("[]byte(");
}
withPtrImport = false;
break;

case STRING:
// Enum are not pointers, but string alias values
if (shape.hasTrait(EnumTrait.class)) {
if (shape.hasTrait(StreamingTrait.class)) {
writer.addUseImports(SmithyGoDependency.SMITHY_IO);
writer.addUseImports(SmithyGoDependency.STRINGS);
writer.writeInline("smithyio.ReadSeekNopCloser{ReadSeeker: strings.NewReader(");
closing += "}";

} else if (shape.hasTrait(EnumTrait.class)) {
Symbol enumSymbol = symbolProvider.toSymbol(shape);
writer.writeInline("$T(", enumSymbol);
withPtrImport = false;
break;

} else {
writer.writeInline("ptr.String(");
}

writer.writeInline("ptr.String(");
break;

case TIMESTAMP:
Expand Down Expand Up @@ -232,8 +248,9 @@ protected void scalarWrapShapeValue(GoWriter writer, Shape shape, Runnable inner
if (withPtrImport) {
writer.addUseImports(SmithyGoDependency.SMITHY_PTR);
}

inner.run();
writer.writeInline(")");
writer.writeInline(closing);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,16 +402,17 @@ protected void writeAssertScalarEqual(GoWriter writer, String expect, String act
protected void writeAssertComplexEqual(
GoWriter writer, String expect, String actual, String[] ignoreTypes
) {
writer.addUseImports(SmithyGoDependency.GO_CMP);
writer.addUseImports(SmithyGoDependency.SMITHY_TESTING);
writer.addUseImports(SmithyGoDependency.GO_CMP_OPTIONS);
writer.writeInline("if diff := cmp.Diff($L, $L, cmpopts.IgnoreUnexported(", expect, actual);

writer.writeInline("if err := smithytesting.CompareValues($L, $L, cmpopts.IgnoreUnexported(", expect, actual);

for (String ignoreType : ignoreTypes) {
writer.write("$L,", ignoreType);
}

writer.writeInline(")); len(diff) != 0 {");
writer.write(" t.Errorf(\"expect $L value match:\\n%s\", diff)", expect);
writer.writeInline(")); err != nil {");
writer.write(" t.Errorf(\"expect $L value match:\\n%v\", err)", expect);
writer.write("}");
}

Expand Down
16 changes: 16 additions & 0 deletions io/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io

import (
"io"
)

// ReadSeekNopCloser wraps an io.ReadSeeker with an additional Close method
// that does nothing.
type ReadSeekNopCloser struct {
io.ReadSeeker
}

// Close does nothing.
func (ReadSeekNopCloser) Close() error {
return nil
}
95 changes: 95 additions & 0 deletions testing/struct.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package testing

import (
"bytes"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"

"github.com/google/go-cmp/cmp"
)

// CompareValues compares two values to determine if they are equal.
func CompareValues(expect, actual interface{}, opts ...cmp.Option) error {
opts = append(make([]cmp.Option, 0, len(opts)+1), opts...)

var skippedReaders filterSkipDifferentIoReader

opts = append(opts,
cmp.Transformer("http.NoBody", transformHTTPNoBodyToNil),
cmp.FilterValues(skippedReaders.filter, cmp.Ignore()),
)

if diff := cmp.Diff(expect, actual, opts...); len(diff) != 0 {
return fmt.Errorf("values do not match\n%s", diff)
}

var errs []error
for _, s := range skippedReaders {
if err := CompareReaders(s.A, s.B); err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return fmt.Errorf("io.Readers have different values\n%v", errs)
}

return nil
}

func transformHTTPNoBodyToNil(v io.Reader) io.Reader {
if v == http.NoBody {
return nil
}
return v
}

type filterSkipDifferentIoReader []skippedReaders

func (f *filterSkipDifferentIoReader) filter(a, b io.Reader) bool {
if a == nil || b == nil {
return false
}
//at, bt := reflect.TypeOf(a), reflect.TypeOf(b)
//for at.Kind() == reflect.Ptr {
// at = at.Elem()
//}
//for bt.Kind() == reflect.Ptr {
// bt = bt.Elem()
//}

//// The underlying reader types are the same they can be compared directly.
//if at == bt {
// return false
//}

*f = append(*f, skippedReaders{A: a, B: b})
return true
}

type skippedReaders struct {
A, B io.Reader
}

// CompareReaders two io.Reader values together to determine if they are equal.
// Will read the contents of the readers until they are empty.
func CompareReaders(expect, actual io.Reader) error {
e, err := ioutil.ReadAll(expect)
if err != nil {
return fmt.Errorf("failed to read expect body, %w", err)
}

a, err := ioutil.ReadAll(actual)
if err != nil {
return fmt.Errorf("failed to read actual body, %w", err)
}

if !bytes.Equal(e, a) {
return fmt.Errorf("bytes do not match\nexpect:\n%s\nactual:\n%s",
hex.Dump(e), hex.Dump(a))
}

return nil
}
102 changes: 102 additions & 0 deletions testing/struct_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package testing

import (
"bytes"
"io"
"io/ioutil"
"strings"
"testing"
)

func TestCompareStructEqual(t *testing.T) {
cases := map[string]struct {
A, B interface{}
ExpectErr string
}{
"simple match": {
A: struct {
Foo string
Bar int
}{
Foo: "abc",
Bar: 123,
},
B: struct {
Foo string
Bar int
}{
Foo: "abc",
Bar: 123,
},
},
"simple diff": {
A: struct {
Foo string
Bar int
}{
Foo: "abc",
Bar: 123,
},
B: struct {
Foo string
Bar int
}{
Foo: "abc",
Bar: 456,
},
ExpectErr: "values do not match",
},
"reader match": {
A: struct {
Foo io.Reader
Bar int
}{
Foo: bytes.NewBuffer([]byte("abc123")),
Bar: 123,
},
B: struct {
Foo io.Reader
Bar int
}{
Foo: ioutil.NopCloser(strings.NewReader("abc123")),
Bar: 123,
},
},
"reader diff": {
A: struct {
Foo io.Reader
Bar int
}{
Foo: bytes.NewBuffer([]byte("abc123")),
Bar: 123,
},
B: struct {
Foo io.Reader
Bar int
}{
Foo: ioutil.NopCloser(strings.NewReader("123abc")),
Bar: 123,
},
ExpectErr: "bytes do not match",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
err := CompareValues(c.A, c.B)

if len(c.ExpectErr) != 0 {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %v, got %v", e, a)
}
return
}
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
})
}
}

0 comments on commit 90381e4

Please sign in to comment.