diff --git a/mock.go b/mock.go index 38a94bf..11500d0 100644 --- a/mock.go +++ b/mock.go @@ -17,15 +17,10 @@ const ( ) var ( - mu sync.RWMutex - cache map[reflect.Type]cachedMockType = map[reflect.Type]cachedMockType{} + mu sync.Mutex + cache map[reflect.Type]func() (any, *mock.Mock) ) -type cachedMockType struct { - mockType reflect.Type - contexts unsafe.Pointer -} - // Mock takes a T interface type and returns a runtime generated mock implementation for it // plus a testify mock.Mock object to control its behavior func Mock[T any]() (T, *mock.Mock) { @@ -34,15 +29,15 @@ func Mock[T any]() (T, *mock.Mock) { panic("not an interface type: " + typ.Name()) } - mu.RLock() + mu.Lock() + defer mu.Unlock() + if cache == nil { + cache = make(map[reflect.Type]func() (any, *mock.Mock)) + } cached, ok := cache[typ] - mu.RUnlock() if ok { - m := new(mock.Mock) - obj := reflect.New(cached.mockType).Elem() - obj.Field(0).Set(reflect.ValueOf(cached.contexts)) - obj.Field(1).Set(reflect.ValueOf(m)) - return obj.Interface().(T), m + obj, m := cached() + return obj.(T), m } mockType := reflect.StructOf([]reflect.StructField{ @@ -64,23 +59,22 @@ func Mock[T any]() (T, *mock.Mock) { abiMethods := getAbiMethods(mockType) for i := range abiMethods { fn := mockMethod(typ.Method(i)) - vptr := (*(*[2]unsafe.Pointer)(unsafe.Pointer(&fn)))[1] - contexts[i] = vptr + contexts[i] = unpackIfaceData(fn) abiMethods[i].setFn(mockMethodWrappers[i]) } - contextsPtr := unsafe.Pointer(unsafe.SliceData(contexts)) - mu.Lock() - cache[typ] = cachedMockType{ - mockType: mockType, - contexts: contextsPtr, + + newMock := func() (any, *mock.Mock) { + contextsPtr := unsafe.Pointer(unsafe.SliceData(contexts)) + m := new(mock.Mock) + obj := reflect.New(mockType).Elem() + obj.Field(0).Set(reflect.ValueOf(contextsPtr)) + obj.Field(1).Set(reflect.ValueOf(m)) + return obj.Interface(), m } - mu.Unlock() + cache[typ] = newMock - m := new(mock.Mock) - obj := reflect.New(mockType).Elem() - obj.Field(0).Set(reflect.ValueOf(contextsPtr)) - obj.Field(1).Set(reflect.ValueOf(m)) - return obj.Interface().(T), m + obj, m := newMock() + return obj.(T), m } func mockMethod(method reflect.Method) any { @@ -138,6 +132,10 @@ func getOutTypes(method reflect.Method) []reflect.Type { return types } +func unpackIfaceData(iface any) unsafe.Pointer { + return (*(*[2]unsafe.Pointer)(unsafe.Pointer(&iface)))[1] +} + var mockMethodWrappers = []func(){ m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15,