diff --git a/go.mod b/go.mod index d920e36f..067eddf1 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/golang/mock require ( + github.com/google/go-cmp v0.5.8 // indirect golang.org/x/mod v0.5.1 golang.org/x/tools v0.1.8 ) diff --git a/go.sum b/go.sum index 5ae13f0c..b744f39c 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/yuin/goldmark v1.4.1 h1:/vn0k+RBvwlxEmP5E7SZMqNxPhfMVFEJiykr15/0XKM= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/gomock/call.go b/gomock/call.go index 98881596..38bcf264 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -19,6 +19,8 @@ import ( "reflect" "strconv" "strings" + + "github.com/google/go-cmp/cmp" ) // Call represents an expected call to a mock. @@ -42,11 +44,13 @@ type Call struct { // can set the return values by returning a non-nil slice. Actions run in the // order they are created. actions []func([]interface{}) []interface{} + + cmpOpts cmp.Options } // newCall creates a *Call. It requires the method type in order to support // unexported methods. -func newCall(t TestHelper, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call { +func newCall(t TestHelper, receiver interface{}, method string, methodType reflect.Type, cmpOpts cmp.Options, args ...interface{}) *Call { t.Helper() // TODO: check arity, types. @@ -76,7 +80,8 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle return rets }} return &Call{t: t, receiver: receiver, method: method, methodType: methodType, - args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} + args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions, + cmpOpts: cmpOpts} } // AnyTimes allows the expectation to be called 0 or more times @@ -331,10 +336,28 @@ func (c *Call) matches(args []interface{}) error { } for i, m := range c.args { - if !m.Matches(args[i]) { + arg := args[i] + if !m.Matches(arg) { + var sb strings.Builder + sb.WriteString( + fmt.Sprintf("expected call at %s doesn't match the argument at index %d.", c.origin, i), + ) + if g, ok := m.(GotFormatter); ok { + return fmt.Errorf( + "expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v", + c.origin, i, g.Got(arg), m, + ) + } + if d, ok := m.(Differ); ok { + diff := d.Diff(arg, c.cmpOpts...) + return fmt.Errorf( + "expected call at %s doesn't match the argument at index %d.\nDiff (-want +got): %s", + c.origin, i, diff, + ) + } return fmt.Errorf( "expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v", - c.origin, i, formatGottenArg(m, args[i]), m, + c.origin, i, formatGottenArg(m, arg), m, ) } } diff --git a/gomock/callset_test.go b/gomock/callset_test.go index fe053af7..0acdedcf 100644 --- a/gomock/callset_test.go +++ b/gomock/callset_test.go @@ -30,7 +30,7 @@ func TestCallSetAdd(t *testing.T) { numCalls := 10 for i := 0; i < numCalls; i++ { - cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))) + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil)) } call, err := cs.FindMatch(receiver, method, []interface{}{}) @@ -82,7 +82,7 @@ func TestCallSetFindMatch(t *testing.T) { method := "TestMethod" args := []interface{}{} - c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func)) + c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil) cs.exhausted = map[callSetKey][]*Call{ {receiver: receiver, fname: method}: {c1}, } diff --git a/gomock/controller.go b/gomock/controller.go index 5e2def13..6b4ead69 100644 --- a/gomock/controller.go +++ b/gomock/controller.go @@ -20,6 +20,8 @@ import ( "reflect" "runtime" "sync" + + "github.com/google/go-cmp/cmp" ) // A TestReporter is something that can be used to report test failures. It @@ -79,6 +81,18 @@ type Controller struct { mu sync.Mutex expectedCalls *callSet finished bool + cmpOpts cmp.Options +} + +// ControllerOption is a function that configures a Controller. +type ControllerOption func(*Controller) + +// WithCmpOpts is a ControllerOption that configures the options to pass to +// cmp.Diff. +func WithCmpOpts(opts ...cmp.Option) ControllerOption { + return func(c *Controller) { + c.cmpOpts = opts + } } // NewController returns a new Controller. It is the preferred way to create a @@ -86,7 +100,7 @@ type Controller struct { // // New in go1.14+, if you are passing a *testing.T into this function you no // longer need to call ctrl.Finish() in your test methods. -func NewController(t TestReporter) *Controller { +func NewController(t TestReporter, opts ...ControllerOption) *Controller { h, ok := t.(TestHelper) if !ok { h = &nopTestHelper{t} @@ -102,6 +116,10 @@ func NewController(t TestReporter) *Controller { }) } + for _, opt := range opts { + opt(ctrl) + } + return ctrl } @@ -165,7 +183,7 @@ func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ... func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call { ctrl.T.Helper() - call := newCall(ctrl.T, receiver, method, methodType, args...) + call := newCall(ctrl.T, receiver, method, methodType, ctrl.cmpOpts, args...) ctrl.mu.Lock() defer ctrl.mu.Unlock() diff --git a/gomock/controller_test.go b/gomock/controller_test.go index 921a51f8..b781014d 100644 --- a/gomock/controller_test.go +++ b/gomock/controller_test.go @@ -17,11 +17,11 @@ package gomock_test import ( "fmt" "reflect" - "testing" - "strings" + "testing" "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp/cmpopts" ) type ErrorReporter struct { @@ -75,8 +75,11 @@ func (e *ErrorReporter) assertFatal(fn func(), expectedErrMsgs ...string) { // check the last actualErrMsg, because the previous messages come from previous errors actualErrMsg := e.log[len(e.log)-1] for _, expectedErrMsg := range expectedErrMsgs { - if !strings.Contains(actualErrMsg, expectedErrMsg) { + i := strings.Index(actualErrMsg, expectedErrMsg) + if i == -1 { e.t.Errorf("Error message:\ngot: %q\nwant to contain: %q\n", actualErrMsg, expectedErrMsg) + } else { + actualErrMsg = actualErrMsg[i+len(expectedErrMsg):] } } } @@ -150,8 +153,9 @@ func (s *Subject) VariadicMethod(arg int, vararg ...string) {} // A type purely for ActOnTestStructMethod type TestStruct struct { - Number int - Message string + Number int + Message string + secretMessage string } func (s *Subject) ActOnTestStructMethod(arg TestStruct, arg1 int) int { @@ -172,7 +176,9 @@ func createFixtures(t *testing.T) (reporter *ErrorReporter, ctrl *gomock.Control // Controller. We use it to test that the mock considered tests // successful or failed. reporter = NewErrorReporter(t) - ctrl = gomock.NewController(reporter) + ctrl = gomock.NewController( + reporter, gomock.WithCmpOpts(cmpopts.IgnoreUnexported(TestStruct{})), + ) return } @@ -289,13 +295,13 @@ func TestUnexpectedArgValue_FirstArg(t *testing.T) { // the method argument (of TestStruct type) has 1 unexpected value (for the Message field) ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 123, Message: "no message"}, 15) }, "Unexpected call to", "doesn't match the argument at index 0", - "Got: {123 no message} (gomock_test.TestStruct)\nWant: is equal to {123 hello %s} (gomock_test.TestStruct)") + "Diff (-want +got):", "gomock_test.TestStruct{", "Number: 123", "-", "Message: \"hello %s\",", "+", "Message: \"no message\",", "}") reporter.assertFatal(func() { // the method argument (of TestStruct type) has 2 unexpected values (for both fields) ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 11, Message: "no message"}, 15) }, "Unexpected call to", "doesn't match the argument at index 0", - "Got: {11 no message} (gomock_test.TestStruct)\nWant: is equal to {123 hello %s} (gomock_test.TestStruct)") + "Diff (-want +got):", "gomock_test.TestStruct{", "-", "Number: 123,", "+", "Number: 11,", "-", "Message: \"hello %s\",", "+", "Message: \"no message\",", "}") reporter.assertFatal(func() { // The expected call wasn't made. @@ -314,7 +320,7 @@ func TestUnexpectedArgValue_SecondArg(t *testing.T) { reporter.assertFatal(func() { ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 123, Message: "hello"}, 3) }, "Unexpected call to", "doesn't match the argument at index 1", - "Got: 3 (int)\nWant: is equal to 15 (int)") + "Diff (-want +got):", "int(", "-", "15,", "+", "3,", ")") reporter.assertFatal(func() { // The expected call wasn't made. diff --git a/gomock/matchers.go b/gomock/matchers.go index 2822fb2c..6b62d677 100644 --- a/gomock/matchers.go +++ b/gomock/matchers.go @@ -18,6 +18,8 @@ import ( "fmt" "reflect" "strings" + + "github.com/google/go-cmp/cmp" ) // A Matcher is a representation of a class of values. @@ -30,6 +32,11 @@ type Matcher interface { String() string } +type Differ interface { + // Diff shows the difference between the value and x. + Diff(x interface{}, opts ...cmp.Option) string +} + // WantFormatter modifies the given Matcher's String() method to the given // Stringer. This allows for control on how the "Want" is formatted when // printing . @@ -93,6 +100,10 @@ func (anyMatcher) Matches(interface{}) bool { return true } +func (anyMatcher) Diff(interface{}) string { + return "" +} + func (anyMatcher) String() string { return "is anything" } @@ -119,6 +130,10 @@ func (e eqMatcher) Matches(x interface{}) bool { return false } +func (e eqMatcher) Diff(x interface{}, opts ...cmp.Option) string { + return cmp.Diff(e.x, x, opts...) +} + func (e eqMatcher) String() string { return fmt.Sprintf("is equal to %v (%T)", e.x, e.x) } @@ -140,6 +155,10 @@ func (nilMatcher) Matches(x interface{}) bool { return false } +func (nilMatcher) Diff(x interface{}, opts ...cmp.Option) string { + return cmp.Diff(nil, x, opts...) +} + func (nilMatcher) String() string { return "is nil" } @@ -164,6 +183,10 @@ func (m assignableToTypeOfMatcher) Matches(x interface{}) bool { return reflect.TypeOf(x).AssignableTo(m.targetType) } +func (m assignableToTypeOfMatcher) Diff(x interface{}, opts ...cmp.Option) string { + return cmp.Diff(m.targetType, reflect.TypeOf(x), opts...) +} + func (m assignableToTypeOfMatcher) String() string { return "is assignable to " + m.targetType.Name() } @@ -181,6 +204,18 @@ func (am allMatcher) Matches(x interface{}) bool { return true } +func (am allMatcher) Diff(x interface{}, opts ...cmp.Option) string { + ss := make([]string, 0, len(am.matchers)) + for _, matcher := range am.matchers { + if d, ok := matcher.(Differ); ok { + ss = append(ss, d.Diff(x)) + } else { + ss = append(ss, matcher.String()) + } + } + return strings.Join(ss, "; ") +} + func (am allMatcher) String() string { ss := make([]string, 0, len(am.matchers)) for _, matcher := range am.matchers { @@ -203,6 +238,16 @@ func (m lenMatcher) Matches(x interface{}) bool { } } +func (m lenMatcher) Diff(x interface{}, opts ...cmp.Option) string { + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + return cmp.Diff(m.i, v.Len(), opts...) + default: + return cmp.Diff(m.i, fmt.Sprintf("invalid: len(%T)", x), opts...) + } +} + func (m lenMatcher) String() string { return fmt.Sprintf("has length %d", m.i) } @@ -257,6 +302,52 @@ func (m inAnyOrderMatcher) Matches(x interface{}) bool { return extraInGiven == 0 && missingFromWanted == 0 } +func (m inAnyOrderMatcher) Diff(x interface{}, opts ...cmp.Option) string { + given, ok := m.prepareValue(x) + if !ok { + return cmp.Diff(m.x, x, opts...) + } + wanted, ok := m.prepareValue(m.x) + if !ok { + return cmp.Diff(m.x, x, opts...) + } + + if given.Len() != wanted.Len() { + return cmp.Diff(m.x, x, opts...) + } + + usedFromGiven := make([]bool, given.Len()) + foundFromWanted := make([]bool, wanted.Len()) + for i := 0; i < wanted.Len(); i++ { + wantedMatcher := Eq(wanted.Index(i).Interface()) + for j := 0; j < given.Len(); j++ { + if usedFromGiven[j] { + continue + } + if wantedMatcher.Matches(given.Index(j).Interface()) { + foundFromWanted[i] = true + usedFromGiven[j] = true + break + } + } + } + + missingFromWanted := 0 + for _, found := range foundFromWanted { + if !found { + missingFromWanted++ + } + } + extraInGiven := 0 + for _, used := range usedFromGiven { + if !used { + extraInGiven++ + } + } + + return cmp.Diff(m.x, x, opts...) +} + func (m inAnyOrderMatcher) prepareValue(x interface{}) (reflect.Value, bool) { xValue := reflect.ValueOf(x) switch xValue.Kind() {