From 90381e4e4f269d34ef0c42e1435140a06b4e36e0 Mon Sep 17 00:00:00 2001 From: Jason Del Ponte Date: Fri, 17 Jul 2020 17:50:44 -0700 Subject: [PATCH] Fix protocol test generation for streaming types Fixes the generation of protocol tests with streaming input and output members. Refactors the comparison of structure values to use a utility that flexibly compares the io.Reader values. Allowing the underlying io.Reader implementations to be different types, but compare as equal if the contents of the readers are the same. --- .../go/codegen/ShapeValueGenerator.java | 27 ++++- .../HttpProtocolUnitTestGenerator.java | 9 +- io/reader.go | 16 +++ testing/struct.go | 95 ++++++++++++++++ testing/struct_test.go | 102 ++++++++++++++++++ 5 files changed, 240 insertions(+), 9 deletions(-) create mode 100644 io/reader.go create mode 100644 testing/struct.go create mode 100644 testing/struct_test.go diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java index ff4382b04..f4121e54f 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java @@ -37,6 +37,7 @@ import software.amazon.smithy.model.shapes.ShapeType; import software.amazon.smithy.model.shapes.StructureShape; import software.amazon.smithy.model.traits.EnumTrait; +import software.amazon.smithy.model.traits.StreamingTrait; /** * Generates a shape type declaration based on the parameters provided. @@ -169,6 +170,7 @@ protected void mapDeclShapeValue(GoWriter writer, MapShape shape, Runnable inner */ protected void scalarWrapShapeValue(GoWriter writer, Shape shape, Runnable inner) { boolean withPtrImport = true; + String closing = ")"; switch (shape.getType()) { case BOOLEAN: @@ -176,20 +178,34 @@ protected void scalarWrapShapeValue(GoWriter writer, Shape shape, Runnable inner break; case BLOB: - writer.writeInline("[]byte("); + if (shape.hasTrait(StreamingTrait.class)) { + writer.addUseImports(SmithyGoDependency.SMITHY_IO); + writer.addUseImports(SmithyGoDependency.BYTES); + writer.writeInline("smithyio.ReadSeekNopCloser{ReadSeeker: bytes.NewReader([]byte("); + closing += ")}"; + } else { + writer.writeInline("[]byte("); + } withPtrImport = false; break; case STRING: // Enum are not pointers, but string alias values - if (shape.hasTrait(EnumTrait.class)) { + if (shape.hasTrait(StreamingTrait.class)) { + writer.addUseImports(SmithyGoDependency.SMITHY_IO); + writer.addUseImports(SmithyGoDependency.STRINGS); + writer.writeInline("smithyio.ReadSeekNopCloser{ReadSeeker: strings.NewReader("); + closing += "}"; + + } else if (shape.hasTrait(EnumTrait.class)) { Symbol enumSymbol = symbolProvider.toSymbol(shape); writer.writeInline("$T(", enumSymbol); withPtrImport = false; - break; + + } else { + writer.writeInline("ptr.String("); } - writer.writeInline("ptr.String("); break; case TIMESTAMP: @@ -232,8 +248,9 @@ protected void scalarWrapShapeValue(GoWriter writer, Shape shape, Runnable inner if (withPtrImport) { writer.addUseImports(SmithyGoDependency.SMITHY_PTR); } + inner.run(); - writer.writeInline(")"); + writer.writeInline(closing); } /** diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestGenerator.java index 53a2dca33..81bff2746 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestGenerator.java @@ -402,16 +402,17 @@ protected void writeAssertScalarEqual(GoWriter writer, String expect, String act protected void writeAssertComplexEqual( GoWriter writer, String expect, String actual, String[] ignoreTypes ) { - writer.addUseImports(SmithyGoDependency.GO_CMP); + writer.addUseImports(SmithyGoDependency.SMITHY_TESTING); writer.addUseImports(SmithyGoDependency.GO_CMP_OPTIONS); - writer.writeInline("if diff := cmp.Diff($L, $L, cmpopts.IgnoreUnexported(", expect, actual); + + writer.writeInline("if err := smithytesting.CompareValues($L, $L, cmpopts.IgnoreUnexported(", expect, actual); for (String ignoreType : ignoreTypes) { writer.write("$L,", ignoreType); } - writer.writeInline(")); len(diff) != 0 {"); - writer.write(" t.Errorf(\"expect $L value match:\\n%s\", diff)", expect); + writer.writeInline(")); err != nil {"); + writer.write(" t.Errorf(\"expect $L value match:\\n%v\", err)", expect); writer.write("}"); } diff --git a/io/reader.go b/io/reader.go new file mode 100644 index 000000000..07063f296 --- /dev/null +++ b/io/reader.go @@ -0,0 +1,16 @@ +package io + +import ( + "io" +) + +// ReadSeekNopCloser wraps an io.ReadSeeker with an additional Close method +// that does nothing. +type ReadSeekNopCloser struct { + io.ReadSeeker +} + +// Close does nothing. +func (ReadSeekNopCloser) Close() error { + return nil +} diff --git a/testing/struct.go b/testing/struct.go new file mode 100644 index 000000000..51f655269 --- /dev/null +++ b/testing/struct.go @@ -0,0 +1,95 @@ +package testing + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/google/go-cmp/cmp" +) + +// CompareValues compares two values to determine if they are equal. +func CompareValues(expect, actual interface{}, opts ...cmp.Option) error { + opts = append(make([]cmp.Option, 0, len(opts)+1), opts...) + + var skippedReaders filterSkipDifferentIoReader + + opts = append(opts, + cmp.Transformer("http.NoBody", transformHTTPNoBodyToNil), + cmp.FilterValues(skippedReaders.filter, cmp.Ignore()), + ) + + if diff := cmp.Diff(expect, actual, opts...); len(diff) != 0 { + return fmt.Errorf("values do not match\n%s", diff) + } + + var errs []error + for _, s := range skippedReaders { + if err := CompareReaders(s.A, s.B); err != nil { + errs = append(errs, err) + } + } + if len(errs) != 0 { + return fmt.Errorf("io.Readers have different values\n%v", errs) + } + + return nil +} + +func transformHTTPNoBodyToNil(v io.Reader) io.Reader { + if v == http.NoBody { + return nil + } + return v +} + +type filterSkipDifferentIoReader []skippedReaders + +func (f *filterSkipDifferentIoReader) filter(a, b io.Reader) bool { + if a == nil || b == nil { + return false + } + //at, bt := reflect.TypeOf(a), reflect.TypeOf(b) + //for at.Kind() == reflect.Ptr { + // at = at.Elem() + //} + //for bt.Kind() == reflect.Ptr { + // bt = bt.Elem() + //} + + //// The underlying reader types are the same they can be compared directly. + //if at == bt { + // return false + //} + + *f = append(*f, skippedReaders{A: a, B: b}) + return true +} + +type skippedReaders struct { + A, B io.Reader +} + +// CompareReaders two io.Reader values together to determine if they are equal. +// Will read the contents of the readers until they are empty. +func CompareReaders(expect, actual io.Reader) error { + e, err := ioutil.ReadAll(expect) + if err != nil { + return fmt.Errorf("failed to read expect body, %w", err) + } + + a, err := ioutil.ReadAll(actual) + if err != nil { + return fmt.Errorf("failed to read actual body, %w", err) + } + + if !bytes.Equal(e, a) { + return fmt.Errorf("bytes do not match\nexpect:\n%s\nactual:\n%s", + hex.Dump(e), hex.Dump(a)) + } + + return nil +} diff --git a/testing/struct_test.go b/testing/struct_test.go new file mode 100644 index 000000000..be92c76de --- /dev/null +++ b/testing/struct_test.go @@ -0,0 +1,102 @@ +package testing + +import ( + "bytes" + "io" + "io/ioutil" + "strings" + "testing" +) + +func TestCompareStructEqual(t *testing.T) { + cases := map[string]struct { + A, B interface{} + ExpectErr string + }{ + "simple match": { + A: struct { + Foo string + Bar int + }{ + Foo: "abc", + Bar: 123, + }, + B: struct { + Foo string + Bar int + }{ + Foo: "abc", + Bar: 123, + }, + }, + "simple diff": { + A: struct { + Foo string + Bar int + }{ + Foo: "abc", + Bar: 123, + }, + B: struct { + Foo string + Bar int + }{ + Foo: "abc", + Bar: 456, + }, + ExpectErr: "values do not match", + }, + "reader match": { + A: struct { + Foo io.Reader + Bar int + }{ + Foo: bytes.NewBuffer([]byte("abc123")), + Bar: 123, + }, + B: struct { + Foo io.Reader + Bar int + }{ + Foo: ioutil.NopCloser(strings.NewReader("abc123")), + Bar: 123, + }, + }, + "reader diff": { + A: struct { + Foo io.Reader + Bar int + }{ + Foo: bytes.NewBuffer([]byte("abc123")), + Bar: 123, + }, + B: struct { + Foo io.Reader + Bar int + }{ + Foo: ioutil.NopCloser(strings.NewReader("123abc")), + Bar: 123, + }, + ExpectErr: "bytes do not match", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + err := CompareValues(c.A, c.B) + + if len(c.ExpectErr) != 0 { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %v, got %v", e, a) + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +}