diff --git a/inject.go b/inject.go index dc7342f..04b777d 100644 --- a/inject.go +++ b/inject.go @@ -49,12 +49,13 @@ type TypeMapper interface { Set(reflect.Type, reflect.Value) TypeMapper // Returns the Value that is mapped to the current type. Returns a zeroed Value if // the Type has not been mapped. - Get(reflect.Type) reflect.Value + Get(reflect.Type, ...bool) reflect.Value } type injector struct { - values map[reflect.Type]reflect.Value - parent Injector + factories map[reflect.Type]reflect.Value + values map[reflect.Type]reflect.Value + parent Injector } // InterfaceOf dereferences a pointer to an Interface type. @@ -76,7 +77,8 @@ func InterfaceOf(value interface{}) reflect.Type { // New returns a new Injector. func New() Injector { return &injector{ - values: make(map[reflect.Type]reflect.Value), + values: make(map[reflect.Type]reflect.Value), + factories: make(map[reflect.Type]reflect.Value), } } @@ -85,27 +87,31 @@ func New() Injector { // Returns a slice of reflect.Value representing the returned values of the function. // Returns an error if the injection fails. // It panics if f is not a function -func (inj *injector) Invoke(f interface{}) ([]reflect.Value, error) { - t := reflect.TypeOf(f) +func (inj *injector) Invoke(f interface{}) (ret []reflect.Value, err error) { + defer recoverResolvePanic(&err) - var in = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func - for i := 0; i < t.NumIn(); i++ { + t := reflect.TypeOf(f) + var args = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func + for i, arg := range args { argType := t.In(i) - val := inj.Get(argType) - if !val.IsValid() { + arg = inj.Get(argType) + if !arg.IsValid() { return nil, fmt.Errorf("Value not found for type %v", argType) } - in[i] = val + args[i] = arg } - return reflect.ValueOf(f).Call(in), nil + ret = reflect.ValueOf(f).Call(args) + return } // Maps dependencies in the Type map to each field in the struct // that is tagged with 'inject'. // Returns an error if the injection fails. -func (inj *injector) Apply(val interface{}) error { +func (inj *injector) Apply(val interface{}) (err error) { + defer recoverResolvePanic(&err) + v := reflect.ValueOf(val) for v.Kind() == reflect.Ptr { @@ -133,19 +139,17 @@ func (inj *injector) Apply(val interface{}) error { } - return nil + return } // Maps the concrete value of val to its dynamic type using reflect.TypeOf, // It returns the TypeMapper registered in. func (i *injector) Map(val interface{}) TypeMapper { - i.values[reflect.TypeOf(val)] = reflect.ValueOf(val) - return i + return i.mapping(reflect.TypeOf(val), reflect.ValueOf(val)) } func (i *injector) MapTo(val interface{}, ifacePtr interface{}) TypeMapper { - i.values[InterfaceOf(ifacePtr)] = reflect.ValueOf(val) - return i + return i.mapping(InterfaceOf(ifacePtr), reflect.ValueOf(val)) } // Maps the given reflect.Type to the given reflect.Value and returns @@ -155,14 +159,72 @@ func (i *injector) Set(typ reflect.Type, val reflect.Value) TypeMapper { return i } -func (i *injector) Get(t reflect.Type) reflect.Value { - val := i.values[t] - if !val.IsValid() && i.parent != nil { - val = i.parent.Get(t) +func (inj *injector) mapping(typ reflect.Type, val reflect.Value) TypeMapper { + kind := val.Kind() + if kind == reflect.Func { + if typ.Kind() == reflect.Func { + typ = val.Type().Out(0) + } + inj.factories[typ] = val + } else { + inj.values[typ] = val } - return val + return inj +} + +func (inj *injector) Get(want reflect.Type, root ...bool) (ret reflect.Value) { + ret = inj.values[want] + + if !ret.IsValid() { + ret = inj.factories[want] + } + + if !ret.IsValid() && inj.parent != nil { + ret = inj.parent.Get(want, false) + } + + if ret.IsValid() && ret.Kind() == reflect.Func && (len(root) == 0 || root[0]) { + ret = inj.resolve(ret, []reflect.Type{}) + } + + return +} + +func (inj *injector) resolve(fac reflect.Value, chain []reflect.Type) reflect.Value { + chainLen := len(chain) + if chainLen > 1 { + want := chain[chainLen-1] + for _, t := range chain[:chainLen-1] { + if want == t { + panic(resolveError{chain: chain, message: "dependency loop"}) + } + } + } + + facType := fac.Type() + args := make([]reflect.Value, facType.NumIn()) + for i, _ := range args { + argType := facType.In(i) + + if cachedVal, ok := inj.values[argType]; ok { + args[i] = cachedVal + continue + } + + args[i] = inj.Get(argType, false) + if !args[i].IsValid() { + panic(resolveError{chain: chain, fac: fac}) + } + + if args[i].Kind() == reflect.Func { + args[i] = inj.resolve(args[i], append(chain, facType, argType)) + } + inj.values[argType] = args[i] + } + + return fac.Call(args)[0] } -func (i *injector) SetParent(parent Injector) { - i.parent = parent +func (inj *injector) SetParent(parent Injector) { + inj.parent = parent } diff --git a/inject_err.go b/inject_err.go new file mode 100644 index 0000000..fa75968 --- /dev/null +++ b/inject_err.go @@ -0,0 +1,42 @@ +package inject + +import ( + "fmt" + "reflect" + "strings" +) + +type resolveError struct { + chain []reflect.Type + fac reflect.Value + message string +} + +func recoverResolvePanic(err *error) { + if r := recover(); r != nil { + switch x := r.(type) { + case resolveError: + chain := make([]string, len(x.chain)) + for i, t := range x.chain { + chain[i] = fmt.Sprintf("%q", t) + } + + if x.message == "" && x.fac.IsValid() { + m := "factory 'func(" + facType := x.fac.Type() + for i := 0; i < facType.NumIn(); { + m += fmt.Sprintf("%v", facType.In(i)) + i += 1 + } + m += ") " + fmt.Sprintf("%v", facType.Out(0)) + "'" + x.message = m + } + + *err = fmt.Errorf("Value not found for type %v (%v): %v", + chain[len(chain)-1], x.message, strings.Join(chain, " -> ")) + println(fmt.Sprintf("%v", *err)) + default: + *err = fmt.Errorf("%v", x) + } + } +} diff --git a/inject_test.go b/inject_test.go index acca573..a0f22de 100644 --- a/inject_test.go +++ b/inject_test.go @@ -1,9 +1,12 @@ package inject_test import ( + "fmt" "github.com/codegangsta/inject" + "math/rand" "reflect" "testing" + "time" ) type SpecialString interface { @@ -15,6 +18,10 @@ type TestStruct struct { Dep3 string } +func init() { + rand.Seed(time.Now().Unix()) +} + /* Test Helpers */ func expect(t *testing.T, a interface{}, b interface{}) { if a != b { @@ -42,7 +49,7 @@ func Test_InjectorInvoke(t *testing.T) { typSend := reflect.ChanOf(reflect.SendDir, reflect.TypeOf(dep4).Elem()) injector.Set(typRecv, reflect.ValueOf(dep3)) injector.Set(typSend, reflect.ValueOf(dep4)) - + _, err := injector.Invoke(func(d1 string, d2 SpecialString, d3 <-chan *SpecialString, d4 chan<- *SpecialString) { expect(t, d1, dep) expect(t, d2, dep2) @@ -104,15 +111,15 @@ func Test_InterfaceOf(t *testing.T) { func Test_InjectorSet(t *testing.T) { injector := inject.New() - typ := reflect.TypeOf("string") - typSend := reflect.ChanOf(reflect.SendDir, typ) - typRecv := reflect.ChanOf(reflect.RecvDir, typ) - + typ := reflect.TypeOf("string") + typSend := reflect.ChanOf(reflect.SendDir, typ) + typRecv := reflect.ChanOf(reflect.RecvDir, typ) + // instantiating unidirectional channels is not possible using reflect // http://golang.org/src/pkg/reflect/value.go?s=60463:60504#L2064 chanRecv := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, typ), 0) chanSend := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, typ), 0) - + injector.Set(typSend, chanSend) injector.Set(typRecv, chanRecv) @@ -121,7 +128,6 @@ func Test_InjectorSet(t *testing.T) { expect(t, injector.Get(chanSend.Type()).IsValid(), false) } - func Test_InjectorGet(t *testing.T) { injector := inject.New() @@ -140,3 +146,104 @@ func Test_InjectorSetParent(t *testing.T) { expect(t, injector2.Get(inject.InterfaceOf((*SpecialString)(nil))).IsValid(), true) } + +func Test_InjectorInvokeFactory(t *testing.T) { + injector := inject.New() + + dep := "some dependency" + injector.Map(func() string { + return dep + }) + dep2 := "another dep" + injector.MapTo(func() string { + return dep2 + }, (*SpecialString)(nil)) + + res, err := injector.Invoke(func(d1 string, d2 SpecialString) string { + expect(t, d1, dep) + expect(t, d2, dep2) + return dep + }) + + expect(t, err, nil) + expect(t, res[0].String(), dep) +} + +func Test_InjectorInvokeCascadingFactory(t *testing.T) { + injector := inject.New() + + answer := 42 + injector.Map(func() int { + return answer + }) + question := "What do you get if you multiply six by nine?" + injector.Map(func(answer int) string { + return fmt.Sprintf("%v %v", question, answer) + }) + + sentence := fmt.Sprintf("%v %v", question, answer) + res, err := injector.Invoke(func(d1 string) string { + expect(t, d1, sentence) + return sentence + }) + + expect(t, err, nil) + expect(t, res[0].String(), sentence) +} + +func Test_InjectorInvokeFactoryDependencyLoop(t *testing.T) { + injector := inject.New() + + dep := "some dependency" + injector.Map(func(d2 string) string { + return dep + }) + + _, err := injector.Invoke(func(d string) { + t.Errorf("expected an error, not %v", d) + }) + + if err == nil { + t.Errorf("expected an error") + } +} + +func Test_InjectorInvokeFactoryWithParentDependency(t *testing.T) { + injector := inject.New() + dep := "some dependency" + injector.Map(func(d2 int) string { + return dep + }) + + injector2 := inject.New() + injector2.Map(42) + injector2.SetParent(injector) + + res, err := injector2.Invoke(func(d1 string) string { + expect(t, d1, dep) + return dep + }) + + expect(t, err, nil) + expect(t, res[0].String(), dep) +} + +func Test_InjectorInvokeFactoryCaching(t *testing.T) { + injector := inject.New() + + injector.Map(func() int { + return rand.Intn(1000000) + }) + injector.MapTo(func() string { + return "!" + }, (*SpecialString)(nil)) + injector.Map(func(c SpecialString, n int) string { + return fmt.Sprintf("%v%v", n, c) + }) + + _, err := injector.Invoke(func(s string, c SpecialString, n int) { + expect(t, s, fmt.Sprintf("%v%v", n, c)) + }) + + expect(t, err, nil) +}