Skip to content

Commit fbbe4a5

Browse files
committed
feat: better performance for DI router
1 parent e95ca53 commit fbbe4a5

File tree

3 files changed

+160
-115
lines changed

3 files changed

+160
-115
lines changed

pkg/scaffold/dirouter.go

Lines changed: 25 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -4,98 +4,8 @@ import (
44
"fmt"
55
"net/http"
66
"reflect"
7-
8-
"github.com/nicksnyder/go-i18n/v2/i18n"
9-
10-
"github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user"
11-
"github.com/iota-uz/iota-sdk/pkg/application"
12-
"github.com/iota-uz/iota-sdk/pkg/composables"
13-
"github.com/iota-uz/iota-sdk/pkg/types"
147
)
158

16-
var builtinProviders = []Provider{
17-
func(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, bool, error) {
18-
pageCtxType := reflect.TypeOf((*types.PageContext)(nil))
19-
if t == pageCtxType {
20-
return reflect.ValueOf(composables.UsePageCtx(r.Context())), true, nil
21-
}
22-
return reflect.Value{}, false, nil
23-
},
24-
25-
func(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, bool, error) {
26-
writerType := reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()
27-
if t.Implements(writerType) {
28-
return reflect.ValueOf(w), true, nil
29-
}
30-
return reflect.Value{}, false, nil
31-
},
32-
33-
func(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, bool, error) {
34-
requestType := reflect.TypeOf((*http.Request)(nil))
35-
if t == requestType {
36-
return reflect.ValueOf(r), true, nil
37-
}
38-
return reflect.Value{}, false, nil
39-
},
40-
41-
func(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, bool, error) {
42-
localizerType := reflect.TypeOf((*i18n.Localizer)(nil))
43-
if t == localizerType {
44-
localizer, ok := composables.UseLocalizer(r.Context())
45-
if !ok {
46-
return reflect.Value{}, true, fmt.Errorf("localizer not found in request context")
47-
}
48-
return reflect.ValueOf(localizer), true, nil
49-
}
50-
return reflect.Value{}, false, nil
51-
},
52-
53-
func(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, bool, error) {
54-
userType := reflect.TypeOf((*user.User)(nil)).Elem()
55-
if t.Implements(userType) {
56-
u, err := composables.UseUser(r.Context())
57-
if err != nil {
58-
return reflect.Value{}, true, fmt.Errorf("user not found in request context")
59-
}
60-
return reflect.ValueOf(u), true, nil
61-
}
62-
return reflect.Value{}, false, nil
63-
},
64-
65-
func(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, bool, error) {
66-
appType := reflect.TypeOf((*application.Application)(nil)).Elem()
67-
if t.Implements(appType) {
68-
app, err := composables.UseApp(r.Context())
69-
if err != nil {
70-
return reflect.Value{}, true, err
71-
}
72-
return reflect.ValueOf(app), true, nil
73-
}
74-
return reflect.Value{}, false, nil
75-
},
76-
77-
func(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, bool, error) {
78-
app, err := composables.UseApp(r.Context())
79-
if err != nil {
80-
return reflect.Value{}, false, err
81-
}
82-
83-
services := app.Services()
84-
if service, exists := services[t.Elem()]; exists {
85-
return reflect.ValueOf(service), true, nil
86-
}
87-
88-
return reflect.Value{}, false, nil
89-
},
90-
}
91-
92-
// Provider is a function that can provide a value for a given type
93-
// It returns:
94-
// - The value (if it can provide it)
95-
// - A boolean indicating whether this provider can handle the requested type
96-
// - Any error that occurred during resolution
97-
type Provider func(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, bool, error)
98-
999
func NewDIHandler(handler interface{}, customProviders ...Provider) *DIHandler {
10010
return &DIHandler{
10111
value: reflect.ValueOf(handler),
@@ -113,39 +23,44 @@ func (d *DIHandler) Handler() http.HandlerFunc {
11323
typeOf := d.value.Type()
11424
numArgs := typeOf.NumIn()
11525

116-
// Precompute argument types and their resolvers
11726
argTypes := make([]reflect.Type, numArgs)
11827
for i := 0; i < numArgs; i++ {
11928
argTypes[i] = typeOf.In(i)
12029
}
12130

12231
// All providers to try in order (custom first, then built-in)
123-
allProviders := append(d.customProviders, builtinProviders...)
32+
allProviders := append(d.customProviders, BuiltinProviders()...)
33+
34+
matchedProviders := make([]Provider, numArgs)
35+
for i, argType := range argTypes {
36+
for _, provider := range allProviders {
37+
if provider.Ok(argType) {
38+
matchedProviders[i] = provider
39+
break
40+
}
41+
}
42+
43+
if matchedProviders[i] == nil {
44+
// Return a handler that will return an error for this specific type
45+
errorMsg := fmt.Sprintf("No provider found for type: %v", argType)
46+
return func(w http.ResponseWriter, r *http.Request) {
47+
http.Error(w, errorMsg, http.StatusInternalServerError)
48+
}
49+
}
50+
}
12451

12552
return func(w http.ResponseWriter, r *http.Request) {
12653
args := make([]reflect.Value, numArgs)
12754

128-
// Resolve each argument
55+
// Resolve each argument using precomputed matched providers
12956
for i, argType := range argTypes {
130-
var resolved bool
131-
var err error
132-
133-
// Try each provider in order
134-
for _, provider := range allProviders {
135-
args[i], resolved, err = provider(argType, w, r)
136-
if err != nil {
137-
http.Error(w, err.Error(), http.StatusInternalServerError)
138-
return
139-
}
140-
if resolved {
141-
break
142-
}
143-
}
144-
145-
if !resolved {
146-
http.Error(w, fmt.Sprintf("No provider found for type: %v", argType), http.StatusInternalServerError)
57+
provider := matchedProviders[i]
58+
value, err := provider.Provide(argType, w, r)
59+
if err != nil {
60+
http.Error(w, err.Error(), http.StatusInternalServerError)
14761
return
14862
}
63+
args[i] = value
14964
}
15065

15166
d.value.Call(args)

pkg/scaffold/dirouter_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"fmt"
66
"net/http"
77
"net/http/httptest"
8-
"reflect"
98
"testing"
109

1110
"github.com/nicksnyder/go-i18n/v2/i18n"
@@ -119,14 +118,14 @@ func BenchmarkDIRouter(b *testing.B) {
119118
req, _ := http.NewRequest("GET", "/123", nil)
120119
req = req.WithContext(ctx)
121120

122-
handler := &DIHandler{
123-
value: reflect.ValueOf(diTestHandler),
124-
}
121+
// Create a new handler for each run to ensure setup time is not included in the benchmark
122+
handler := NewDIHandler(diTestHandler)
123+
handlerFunc := handler.Handler() // Pre-compute the handler function
125124

126125
b.ResetTimer()
127126
for i := 0; i < b.N; i++ {
128127
rr := httptest.NewRecorder()
129-
handler.Handler()(rr, req)
128+
handlerFunc(rr, req)
130129
}
131130
}
132131

pkg/scaffold/providers.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package scaffold
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"reflect"
7+
8+
"github.com/nicksnyder/go-i18n/v2/i18n"
9+
10+
"github.com/iota-uz/iota-sdk/modules/core/domain/aggregates/user"
11+
"github.com/iota-uz/iota-sdk/pkg/application"
12+
"github.com/iota-uz/iota-sdk/pkg/composables"
13+
"github.com/iota-uz/iota-sdk/pkg/types"
14+
)
15+
16+
// Provider is an interface that can provide a value for a given type
17+
type Provider interface {
18+
// Ok checks if this provider can handle the requested type
19+
Ok(t reflect.Type) bool
20+
21+
// Provide returns the value for the given type
22+
// Should only be called if Ok returns true
23+
Provide(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, error)
24+
}
25+
26+
// Define provider types for each built-in provider
27+
type pageContextProvider struct{}
28+
type httpWriterProvider struct{}
29+
type httpRequestProvider struct{}
30+
type localizerProvider struct{}
31+
type userProvider struct{}
32+
type appProvider struct{}
33+
type serviceProvider struct{}
34+
35+
func (p *pageContextProvider) Ok(t reflect.Type) bool {
36+
pageCtxType := reflect.TypeOf((*types.PageContext)(nil))
37+
return t == pageCtxType
38+
}
39+
40+
func (p *pageContextProvider) Provide(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, error) {
41+
return reflect.ValueOf(composables.UsePageCtx(r.Context())), nil
42+
}
43+
44+
func (p *httpWriterProvider) Ok(t reflect.Type) bool {
45+
writerType := reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()
46+
return t.Implements(writerType)
47+
}
48+
49+
func (p *httpWriterProvider) Provide(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, error) {
50+
return reflect.ValueOf(w), nil
51+
}
52+
53+
func (p *httpRequestProvider) Ok(t reflect.Type) bool {
54+
requestType := reflect.TypeOf((*http.Request)(nil))
55+
return t == requestType
56+
}
57+
58+
func (p *httpRequestProvider) Provide(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, error) {
59+
return reflect.ValueOf(r), nil
60+
}
61+
62+
func (p *localizerProvider) Ok(t reflect.Type) bool {
63+
localizerType := reflect.TypeOf((*i18n.Localizer)(nil))
64+
return t == localizerType
65+
}
66+
67+
func (p *localizerProvider) Provide(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, error) {
68+
localizer, ok := composables.UseLocalizer(r.Context())
69+
if !ok {
70+
return reflect.Value{}, fmt.Errorf("localizer not found in request context")
71+
}
72+
return reflect.ValueOf(localizer), nil
73+
}
74+
75+
func (p *userProvider) Ok(t reflect.Type) bool {
76+
userType := reflect.TypeOf((*user.User)(nil)).Elem()
77+
return t.Implements(userType)
78+
}
79+
80+
func (p *userProvider) Provide(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, error) {
81+
u, err := composables.UseUser(r.Context())
82+
if err != nil {
83+
return reflect.Value{}, fmt.Errorf("user not found in request context")
84+
}
85+
return reflect.ValueOf(u), nil
86+
}
87+
88+
func (p *appProvider) Ok(t reflect.Type) bool {
89+
appType := reflect.TypeOf((*application.Application)(nil)).Elem()
90+
return t.Implements(appType)
91+
}
92+
93+
func (p *appProvider) Provide(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, error) {
94+
app, err := composables.UseApp(r.Context())
95+
if err != nil {
96+
return reflect.Value{}, err
97+
}
98+
return reflect.ValueOf(app), nil
99+
}
100+
101+
func (p *serviceProvider) Ok(t reflect.Type) bool {
102+
// Basic check: must be a pointer type for services
103+
return t.Kind() == reflect.Ptr
104+
}
105+
106+
func (p *serviceProvider) Provide(t reflect.Type, w http.ResponseWriter, r *http.Request) (reflect.Value, error) {
107+
app, err := composables.UseApp(r.Context())
108+
if err != nil {
109+
return reflect.Value{}, err
110+
}
111+
112+
services := app.Services()
113+
if service, exists := services[t.Elem()]; exists {
114+
return reflect.ValueOf(service), nil
115+
}
116+
117+
return reflect.Value{}, fmt.Errorf("service not found for type: %v", t)
118+
}
119+
120+
// BuiltinProviders returns the list of built-in providers
121+
func BuiltinProviders() []Provider {
122+
return []Provider{
123+
&pageContextProvider{},
124+
&httpWriterProvider{},
125+
&httpRequestProvider{},
126+
&localizerProvider{},
127+
&userProvider{},
128+
&appProvider{},
129+
&serviceProvider{},
130+
}
131+
}

0 commit comments

Comments
 (0)