diff --git a/gomock/call.go b/gomock/call.go index 98881596..1b9ab749 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -30,6 +30,7 @@ type Call struct { methodType reflect.Type // the type of the method args []Matcher // the args origin string // file and line number of call setup + isDefault bool // true if this is a default call preReqs []*Call // prerequisite calls @@ -79,8 +80,29 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} } +// ByDefault defines this expectation as a default that is tried to match if no other expectation matches. +// To use this as catch-all you may use Any() matcher for the method parameters. +// ByDefault expects to be called 0 or more times. +func (c *Call) ByDefault() *Call { + if c.minCalls != 1 || c.maxCalls != 1 { + c.t.Fatalf("MinTimes(), MaxTimes(), Times() or AnyTimes() is not allowed when using ByDefault()") + } + + if len(c.preReqs) != 0 { + c.t.Fatalf("After() is not allowed when using ByDefault()") + } + + c.isDefault = true + c.minCalls, c.maxCalls = 0, 1e8 + return c +} + // AnyTimes allows the expectation to be called 0 or more times func (c *Call) AnyTimes() *Call { + if c.isDefault { + c.t.Fatalf("AnyTimes() is not allowed when using ByDefault()") + } + c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity return c } @@ -88,6 +110,10 @@ func (c *Call) AnyTimes() *Call { // MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called or if MaxTimes // was previously called with 1, MinTimes also sets the maximum number of calls to infinity. func (c *Call) MinTimes(n int) *Call { + if c.isDefault { + c.t.Fatalf("MinTimes() is not allowed when using ByDefault()") + } + c.minCalls = n if c.maxCalls == 1 { c.maxCalls = 1e8 @@ -98,6 +124,10 @@ func (c *Call) MinTimes(n int) *Call { // MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called or if MinTimes was // previously called with 1, MaxTimes also sets the minimum number of calls to 0. func (c *Call) MaxTimes(n int) *Call { + if c.isDefault { + c.t.Fatalf("MaxTimes() is not allowed when using ByDefault()") + } + c.maxCalls = n if c.minCalls == 1 { c.minCalls = 0 @@ -224,6 +254,10 @@ func (c *Call) Return(rets ...interface{}) *Call { // Times declares the exact number of times a function call is expected to be executed. func (c *Call) Times(n int) *Call { + if c.isDefault { + c.t.Fatalf("Times() is not allowed when using ByDefault()") + } + c.minCalls, c.maxCalls = n, n return c } @@ -291,6 +325,15 @@ func (c *Call) isPreReq(other *Call) bool { func (c *Call) After(preReq *Call) *Call { c.t.Helper() + // this is more or less a hint, since you could add a call as prerequisite + // and afterwards use ByDefault() + if preReq.isDefault { + c.t.Fatalf("Default isn't allowed to be a prerequisite") + } + if c.isDefault { + c.t.Fatalf("ByDefault() isn't allowed to have prerequisites") + } + if c == preReq { c.t.Fatalf("A call isn't allowed to be its own prerequisite") } diff --git a/gomock/callset.go b/gomock/callset.go index 49dba787..0ec79380 100644 --- a/gomock/callset.go +++ b/gomock/callset.go @@ -70,15 +70,25 @@ func (cs callSet) FindMatch(receiver interface{}, method string, args []interfac // Search through the expected calls. expected := cs.expected[key] var callsErrors bytes.Buffer + var defaultCall *Call for _, call := range expected { err := call.matches(args) if err != nil { _, _ = fmt.Fprintf(&callsErrors, "\n%v", err) } else { + if call.isDefault { + defaultCall = call + continue + } return call, nil } } + // Nothing found check if we at least found some default + if defaultCall != nil { + return defaultCall, nil + } + // If we haven't found a match then search through the exhausted calls so we // get useful error messages. exhausted := cs.exhausted[key] diff --git a/gomock/controller_test.go b/gomock/controller_test.go index 921a51f8..c8dadaca 100644 --- a/gomock/controller_test.go +++ b/gomock/controller_test.go @@ -826,6 +826,266 @@ func TestVariadicArgumentsGotFormatterTooManyArgsFailure(t *testing.T) { ctrl.Finish() } +// Test ByDefault call that is used to define a default behavior + +func TestByDefaultOnly(t *testing.T) { + // no call + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", "something").ByDefault().Return(5) + ctrl.Finish() + + reporter.assertPass("not calling a function with defined default is ok") + + // multiple arbitrary calls + reporter, ctrl = createFixtures(t) + subject = new(Subject) + + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().Return(5) + + rets := ctrl.Call(subject, "FooMethod", "123") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 5 { + t.Errorf("Return value is %v want 5", ret) + } + + rets = ctrl.Call(subject, "FooMethod", "") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 5 { + t.Errorf("Return value is %v want 5", ret) + } + + ctrl.Finish() + reporter.assertPass("calling a function with defined default n times with arbitrary parameters is ok") +} + +func TestByDefaultAndExpectationWithMissingCall(t *testing.T) { + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().Return(5) + ctrl.RecordCall(subject, "FooMethod", "123").Return(123) + + // does call default and not match expectation + ctrl.Call(subject, "FooMethod", "should default") + + // does call default and not match expectation + ctrl.Call(subject, "FooMethod", "also default") + + reporter.assertFatal(func() { + ctrl.Finish() + }, "aborting test due to missing call(s)") +} + +func TestByDefaultAndExpectationWithAllExpectationsMet(t *testing.T) { + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + // every expectation should have precedence over default call + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().Return(5) + ctrl.RecordCall(subject, "FooMethod", "123").Return(123) + ctrl.RecordCall(subject, "FooMethod", "345").Return(345) + + // should not use default but match expectation + rets := ctrl.Call(subject, "FooMethod", "123") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 123 { + t.Errorf("Return value is %v want 123", ret) + } + + // should call default since expectation is consumed + rets = ctrl.Call(subject, "FooMethod", "123") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 5 { + t.Errorf("Return value is %v want 5", ret) + } + + // should not call default but match expectation + rets = ctrl.Call(subject, "FooMethod", "345") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 345 { + t.Errorf("Return value is %v want 345", ret) + } + + // should call default since expectation is consumed + rets = ctrl.Call(subject, "FooMethod", "345") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 5 { + t.Errorf("Return value is %v want 5", ret) + } + + ctrl.Finish() + reporter.assertPass("expectations should have precedence over default call") +} + +func TestOverwriteByDefault(t *testing.T) { + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + // first defaultCall + ctrl.RecordCall(subject, "FooMethod", "defaultCall").ByDefault().Return(123) + + // uses current default + rets := ctrl.Call(subject, "FooMethod", "defaultCall") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 123 { + t.Errorf("Return value is %v want 123", ret) + } + + // overwrite default (when second one matches at least all parameters that first one matched) + ctrl.RecordCall(subject, "FooMethod", "defaultCall").ByDefault().Return(456) + + // matches new default + rets = ctrl.Call(subject, "FooMethod", "defaultCall") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 456 { + t.Errorf("Return value is %v want 456", ret) + } + + // overwrite default with more loose one + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().Return(789) + + // matches new default + rets = ctrl.Call(subject, "FooMethod", "defaultCall") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 789 { + t.Errorf("Return value is %v want 789", ret) + } + + ctrl.Finish() + reporter.assertPass("should always take the latest default definition") +} + +func TestByDefaultWithMissingReturn(t *testing.T) { + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault() + + rets := ctrl.Call(subject, "FooMethod", "something") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 0 { + t.Errorf("Return value is %v want 0", ret) + } + + ctrl.Finish() + reporter.assertPass("should return zero value on missing return definition") +} + +func TestByDefaultMinMaxTimesNotAllowed(t *testing.T) { + // test MinTimes not allowed + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + reporter.assertFatal(func() { + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().MinTimes(2).Return(5) + }, "MinTimes() is not allowed when using ByDefault()") + + // test MaxTimes not allowed + reporter, ctrl = createFixtures(t) + subject = new(Subject) + + reporter.assertFatal(func() { + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().MaxTimes(2).Return(5) + }, "MaxTimes() is not allowed when using ByDefault()") + + // test AnyTimes not allowed + reporter, ctrl = createFixtures(t) + subject = new(Subject) + + reporter.assertFatal(func() { + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().AnyTimes().Return(5) + }, "AnyTimes() is not allowed when using ByDefault()") + + // test Times not allowed + reporter, ctrl = createFixtures(t) + subject = new(Subject) + + reporter.assertFatal(func() { + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().Times(4).Return(5) + }, "Times() is not allowed when using ByDefault()") +} + +func TestByDefaultAfterNotAllowed(t *testing.T) { + // test After not allowed + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + someCall := ctrl.RecordCall(subject, "FooMethod", "123").Return(123) + + reporter.assertFatal(func() { + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().After(someCall) + }, "ByDefault() isn't allowed to have prerequisites") + + // test default call used as prerequisite not allowed + reporter, ctrl = createFixtures(t) + subject = new(Subject) + + defaultCall := ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault() + + reporter.assertFatal(func() { + ctrl.RecordCall(subject, "FooMethod", "123").Return(123).After(defaultCall) + }, "Default isn't allowed to be a prerequisite") +} + +func TestByDefaultCallsDoFunc(t *testing.T) { + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + str := "" + + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().Do(func(s string) { + str = s + }) + + ctrl.Call(subject, "FooMethod", "something") + + if str != "something" { + t.Errorf("value is %v want 'something'", str) + } + + ctrl.Finish() + reporter.assertPass("defaultCall should ignore After()") +} + +func TestByDefaultCallsDoAndReturnFunc(t *testing.T) { + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + str := "" + + ctrl.RecordCall(subject, "FooMethod", gomock.Any()).ByDefault().DoAndReturn(func(s string) int { + str = s + return 5 + }) + + rets := ctrl.Call(subject, "FooMethod", "something") + if ret, ok := rets[0].(int); !ok { + t.Fatalf("Return value is not an int") + } else if ret != 5 { + t.Errorf("Return value is %v want 5", ret) + } + + if str != "something" { + t.Errorf("value is %v want 'something'", str) + } + + ctrl.Finish() + reporter.assertPass("defaultCall work with DoAndReturn") + +} + func TestNoHelper(t *testing.T) { ctrlNoHelper := gomock.NewController(NewErrorReporter(t))