@@ -18,13 +18,15 @@ import (
18
18
"bytes"
19
19
"errors"
20
20
"fmt"
21
+ "sync"
21
22
)
22
23
23
24
// callSet represents a set of expected calls, indexed by receiver and method
24
25
// name.
25
26
type callSet struct {
26
27
// Calls that are still expected.
27
- expected map [callSetKey ][]* Call
28
+ expected map [callSetKey ][]* Call
29
+ expectedMu * sync.Mutex
28
30
// Calls that have been exhausted.
29
31
exhausted map [callSetKey ][]* Call
30
32
}
@@ -36,12 +38,20 @@ type callSetKey struct {
36
38
}
37
39
38
40
func newCallSet () * callSet {
39
- return & callSet {make (map [callSetKey ][]* Call ), make (map [callSetKey ][]* Call )}
41
+ return & callSet {
42
+ expected : make (map [callSetKey ][]* Call ),
43
+ expectedMu : & sync.Mutex {},
44
+ exhausted : make (map [callSetKey ][]* Call ),
45
+ }
40
46
}
41
47
42
48
// Add adds a new expected call.
43
49
func (cs callSet ) Add (call * Call ) {
44
50
key := callSetKey {call .receiver , call .method }
51
+
52
+ cs .expectedMu .Lock ()
53
+ defer cs .expectedMu .Unlock ()
54
+
45
55
m := cs .expected
46
56
if call .exhausted () {
47
57
m = cs .exhausted
@@ -52,6 +62,10 @@ func (cs callSet) Add(call *Call) {
52
62
// Remove removes an expected call.
53
63
func (cs callSet ) Remove (call * Call ) {
54
64
key := callSetKey {call .receiver , call .method }
65
+
66
+ cs .expectedMu .Lock ()
67
+ defer cs .expectedMu .Unlock ()
68
+
55
69
calls := cs .expected [key ]
56
70
for i , c := range calls {
57
71
if c == call {
@@ -67,6 +81,9 @@ func (cs callSet) Remove(call *Call) {
67
81
func (cs callSet ) FindMatch (receiver interface {}, method string , args []interface {}) (* Call , error ) {
68
82
key := callSetKey {receiver , method }
69
83
84
+ cs .expectedMu .Lock ()
85
+ defer cs .expectedMu .Unlock ()
86
+
70
87
// Search through the expected calls.
71
88
expected := cs .expected [key ]
72
89
var callsErrors bytes.Buffer
@@ -101,6 +118,9 @@ func (cs callSet) FindMatch(receiver interface{}, method string, args []interfac
101
118
102
119
// Failures returns the calls that are not satisfied.
103
120
func (cs callSet ) Failures () []* Call {
121
+ cs .expectedMu .Lock ()
122
+ defer cs .expectedMu .Unlock ()
123
+
104
124
failures := make ([]* Call , 0 , len (cs .expected ))
105
125
for _ , calls := range cs .expected {
106
126
for _ , call := range calls {
@@ -111,3 +131,19 @@ func (cs callSet) Failures() []*Call {
111
131
}
112
132
return failures
113
133
}
134
+
135
+ // Satisfied returns true in case all expected calls in this callSet are satisfied.
136
+ func (cs callSet ) Satisfied () bool {
137
+ cs .expectedMu .Lock ()
138
+ defer cs .expectedMu .Unlock ()
139
+
140
+ for _ , calls := range cs .expected {
141
+ for _ , call := range calls {
142
+ if ! call .satisfied () {
143
+ return false
144
+ }
145
+ }
146
+ }
147
+
148
+ return true
149
+ }
0 commit comments