Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
razzie committed Jan 22, 2025
1 parent d84c746 commit 01e7290
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,10 @@ const (
)

var (
mu sync.RWMutex
cache map[reflect.Type]cachedMockType = map[reflect.Type]cachedMockType{}
mu sync.Mutex
cache map[reflect.Type]func() (any, *mock.Mock)
)

type cachedMockType struct {
mockType reflect.Type
contexts unsafe.Pointer
}

// Mock takes a T interface type and returns a runtime generated mock implementation for it
// plus a testify mock.Mock object to control its behavior
func Mock[T any]() (T, *mock.Mock) {
Expand All @@ -34,15 +29,15 @@ func Mock[T any]() (T, *mock.Mock) {
panic("not an interface type: " + typ.Name())
}

mu.RLock()
mu.Lock()
defer mu.Unlock()
if cache == nil {
cache = make(map[reflect.Type]func() (any, *mock.Mock))
}
cached, ok := cache[typ]
mu.RUnlock()
if ok {
m := new(mock.Mock)
obj := reflect.New(cached.mockType).Elem()
obj.Field(0).Set(reflect.ValueOf(cached.contexts))
obj.Field(1).Set(reflect.ValueOf(m))
return obj.Interface().(T), m
obj, m := cached()
return obj.(T), m
}

mockType := reflect.StructOf([]reflect.StructField{
Expand All @@ -64,23 +59,22 @@ func Mock[T any]() (T, *mock.Mock) {
abiMethods := getAbiMethods(mockType)
for i := range abiMethods {
fn := mockMethod(typ.Method(i))
vptr := (*(*[2]unsafe.Pointer)(unsafe.Pointer(&fn)))[1]
contexts[i] = vptr
contexts[i] = unpackIfaceData(fn)
abiMethods[i].setFn(mockMethodWrappers[i])
}
contextsPtr := unsafe.Pointer(unsafe.SliceData(contexts))
mu.Lock()
cache[typ] = cachedMockType{
mockType: mockType,
contexts: contextsPtr,

newMock := func() (any, *mock.Mock) {
contextsPtr := unsafe.Pointer(unsafe.SliceData(contexts))
m := new(mock.Mock)
obj := reflect.New(mockType).Elem()
obj.Field(0).Set(reflect.ValueOf(contextsPtr))
obj.Field(1).Set(reflect.ValueOf(m))
return obj.Interface(), m
}
mu.Unlock()
cache[typ] = newMock

m := new(mock.Mock)
obj := reflect.New(mockType).Elem()
obj.Field(0).Set(reflect.ValueOf(contextsPtr))
obj.Field(1).Set(reflect.ValueOf(m))
return obj.Interface().(T), m
obj, m := newMock()
return obj.(T), m
}

func mockMethod(method reflect.Method) any {
Expand Down Expand Up @@ -138,6 +132,10 @@ func getOutTypes(method reflect.Method) []reflect.Type {
return types
}

func unpackIfaceData(iface any) unsafe.Pointer {
return (*(*[2]unsafe.Pointer)(unsafe.Pointer(&iface)))[1]
}

var mockMethodWrappers = []func(){
m0, m1, m2, m3, m4, m5, m6, m7,
m8, m9, m10, m11, m12, m13, m14, m15,
Expand Down

0 comments on commit 01e7290

Please sign in to comment.