Skip to content

added simple dependency provider #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 84 additions & 22 deletions inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ type TypeMapper interface {
Set(reflect.Type, reflect.Value) TypeMapper
// Returns the Value that is mapped to the current type. Returns a zeroed Value if
// the Type has not been mapped.
Get(reflect.Type) reflect.Value
Get(reflect.Type, ...bool) reflect.Value
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make more sense to have another method instead of passing a bool to the Get function?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, off the top of my head: yes, it might. Maybe GetRoot(reflect.Type) reflect.Value.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this a subtle hint for me to make this change? I can't handle subtle ;)

}

type injector struct {
factories map[reflect.Type]reflect.Value
values map[reflect.Type]reflect.Value
parent Injector
}
Expand All @@ -77,6 +78,7 @@ func InterfaceOf(value interface{}) reflect.Type {
func New() Injector {
return &injector{
values: make(map[reflect.Type]reflect.Value),
factories: make(map[reflect.Type]reflect.Value),
}
}

Expand All @@ -85,27 +87,31 @@ func New() Injector {
// Returns a slice of reflect.Value representing the returned values of the function.
// Returns an error if the injection fails.
// It panics if f is not a function
func (inj *injector) Invoke(f interface{}) ([]reflect.Value, error) {
t := reflect.TypeOf(f)
func (inj *injector) Invoke(f interface{}) (ret []reflect.Value, err error) {
defer recoverResolvePanic(&err)

var in = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func
for i := 0; i < t.NumIn(); i++ {
t := reflect.TypeOf(f)
var args = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func
for i, arg := range args {
argType := t.In(i)
val := inj.Get(argType)
if !val.IsValid() {
arg = inj.Get(argType)
if !arg.IsValid() {
return nil, fmt.Errorf("Value not found for type %v", argType)
}

in[i] = val
args[i] = arg
}

return reflect.ValueOf(f).Call(in), nil
ret = reflect.ValueOf(f).Call(args)
return
}

// Maps dependencies in the Type map to each field in the struct
// that is tagged with 'inject'.
// Returns an error if the injection fails.
func (inj *injector) Apply(val interface{}) error {
func (inj *injector) Apply(val interface{}) (err error) {
defer recoverResolvePanic(&err)

v := reflect.ValueOf(val)

for v.Kind() == reflect.Ptr {
Expand Down Expand Up @@ -133,19 +139,17 @@ func (inj *injector) Apply(val interface{}) error {

}

return nil
return
}

// Maps the concrete value of val to its dynamic type using reflect.TypeOf,
// It returns the TypeMapper registered in.
func (i *injector) Map(val interface{}) TypeMapper {
i.values[reflect.TypeOf(val)] = reflect.ValueOf(val)
return i
return i.mapping(reflect.TypeOf(val), reflect.ValueOf(val))
}

func (i *injector) MapTo(val interface{}, ifacePtr interface{}) TypeMapper {
i.values[InterfaceOf(ifacePtr)] = reflect.ValueOf(val)
return i
return i.mapping(InterfaceOf(ifacePtr), reflect.ValueOf(val))
}

// Maps the given reflect.Type to the given reflect.Value and returns
Expand All @@ -155,14 +159,72 @@ func (i *injector) Set(typ reflect.Type, val reflect.Value) TypeMapper {
return i
}

func (i *injector) Get(t reflect.Type) reflect.Value {
val := i.values[t]
if !val.IsValid() && i.parent != nil {
val = i.parent.Get(t)
func (inj *injector) mapping(typ reflect.Type, val reflect.Value) TypeMapper {
kind := val.Kind()
if kind == reflect.Func {
if typ.Kind() == reflect.Func {
typ = val.Type().Out(0)
}
inj.factories[typ] = val
} else {
inj.values[typ] = val
}
return val
return inj
}

func (inj *injector) Get(want reflect.Type, root ...bool) (ret reflect.Value) {
ret = inj.values[want]

if !ret.IsValid() {
ret = inj.factories[want]
}

if !ret.IsValid() && inj.parent != nil {
ret = inj.parent.Get(want, false)
}

if ret.IsValid() && ret.Kind() == reflect.Func && (len(root) == 0 || root[0]) {
ret = inj.resolve(ret, []reflect.Type{})
}

return
}

func (inj *injector) resolve(fac reflect.Value, chain []reflect.Type) reflect.Value {
chainLen := len(chain)
if chainLen > 1 {
want := chain[chainLen - 1]
for _, t := range chain[:chainLen - 1] {
if want == t {
panic(resolveError{chain: chain, message: "dependency loop"})
}
}
}

facType := fac.Type()
args := make([]reflect.Value, facType.NumIn())
for i, _ := range args {
argType := facType.In(i)

if cachedVal, ok := inj.values[argType]; ok {
args[i] = cachedVal
continue
}

args[i] = inj.Get(argType, false)
if !args[i].IsValid() {
panic(resolveError{chain: chain, fac: fac})
}

if args[i].Kind() == reflect.Func {
args[i] = inj.resolve(args[i], append(chain, facType, argType))
}
inj.values[argType] = args[i]
}

return fac.Call(args)[0]
}

func (i *injector) SetParent(parent Injector) {
i.parent = parent
func (inj *injector) SetParent(parent Injector) {
inj.parent = parent
}
42 changes: 42 additions & 0 deletions inject_err.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package inject

import (
"fmt"
"reflect"
"strings"
)

type resolveError struct {
chain []reflect.Type
fac reflect.Value
message string
}

func recoverResolvePanic(err *error) {
if r := recover(); r != nil {
switch x := r.(type) {
case resolveError:
chain := make([]string, len(x.chain))
for i, t := range x.chain {
chain[i] = fmt.Sprintf("%q", t)
}

if x.message == "" && x.fac.IsValid() {
m := "factory 'func("
facType := x.fac.Type()
for i := 0; i < facType.NumIn(); {
m += fmt.Sprintf("%v", facType.In(i))
i += 1
}
m += ") " + fmt.Sprintf("%v", facType.Out(0)) + "'"
x.message = m
}

*err = fmt.Errorf("Value not found for type %v (%v): %v",
chain[len(chain) - 1], x.message, strings.Join(chain, " -> "))
println(fmt.Sprintf("%v", *err))
default:
*err = fmt.Errorf("%v", x)
}
}
}
108 changes: 108 additions & 0 deletions inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package inject_test

import (
"github.com/codegangsta/inject"
"fmt"
"time"
"reflect"
"testing"
"math/rand"
)

type SpecialString interface {
Expand All @@ -15,6 +18,10 @@ type TestStruct struct {
Dep3 string
}

func init() {
rand.Seed(time.Now().Unix())
}

/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
Expand Down Expand Up @@ -140,3 +147,104 @@ func Test_InjectorSetParent(t *testing.T) {

expect(t, injector2.Get(inject.InterfaceOf((*SpecialString)(nil))).IsValid(), true)
}

func Test_InjectorInvokeFactory(t *testing.T) {
injector := inject.New()

dep := "some dependency"
injector.Map(func() string {
return dep
})
dep2 := "another dep"
injector.MapTo(func() string {
return dep2
}, (*SpecialString)(nil))

res, err := injector.Invoke(func(d1 string, d2 SpecialString) string {
expect(t, d1, dep)
expect(t, d2, dep2)
return dep
})

expect(t, err, nil)
expect(t, res[0].String(), dep)
}

func Test_InjectorInvokeCascadingFactory(t *testing.T) {
injector := inject.New()

answer := 42
injector.Map(func() int {
return answer
})
question := "What do you get if you multiply six by nine?"
injector.Map(func(answer int) string {
return fmt.Sprintf("%v %v", question, answer)
})

sentence := fmt.Sprintf("%v %v", question, answer)
res, err := injector.Invoke(func(d1 string) string {
expect(t, d1, sentence)
return sentence
})

expect(t, err, nil)
expect(t, res[0].String(), sentence)
}

func Test_InjectorInvokeFactoryDependencyLoop(t *testing.T) {
injector := inject.New()

dep := "some dependency"
injector.Map(func(d2 string) string {
return dep
})

_, err := injector.Invoke(func(d string) {
t.Errorf("expected an error, not %v", d)
})

if err == nil {
t.Errorf("expected an error")
}
}

func Test_InjectorInvokeFactoryWithParentDependency(t *testing.T) {
injector := inject.New()
dep := "some dependency"
injector.Map(func(d2 int) string {
return dep
})

injector2 := inject.New()
injector2.Map(42)
injector2.SetParent(injector)

res, err := injector2.Invoke(func(d1 string) string {
expect(t, d1, dep)
return dep
})

expect(t, err, nil)
expect(t, res[0].String(), dep)
}

func Test_InjectorInvokeFactoryCaching(t *testing.T) {
injector := inject.New()

injector.Map(func() int {
return rand.Intn(1000000)
})
injector.MapTo(func() string {
return "!"
}, (*SpecialString)(nil))
injector.Map(func(c SpecialString, n int) string {
return fmt.Sprintf("%v%v", n, c)
})

_, err := injector.Invoke(func(s string, c SpecialString, n int) {
expect(t, s, fmt.Sprintf("%v%v", n, c))
})

expect(t, err, nil)
}