Skip to content

Commit 6465c16

Browse files
authored
Simplify chain interceptors (grpc-ecosystem#421)
The four (client/server, unary/stream) interceptors have to wrap a slice of interceptors in functions which satisfy the handler interface, but which are closures over the other parameters an interceptor is expected to have. The previous approach accomplished this goal with recursion. This had two drawbacks: first, the code was difficult to understand, as most recursion attempts to fully encode state in function parameters, but here state was necessarily also encoded in the closure; and second, the recursive base-case meant that even the innermost interceptor was not calling the bare handler, it was calling a wrapped handler. This new approach instead iteratively constructs wrappers from the inside out. It results in fewer lines of code, with fewer variables held in each closure. Hopefully this results in higher readability.
1 parent 1778d41 commit 6465c16

File tree

1 file changed

+78
-40
lines changed

1 file changed

+78
-40
lines changed

chain.go

Lines changed: 78 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,36 @@ import (
2121
func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
2222
n := len(interceptors)
2323

24-
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
25-
chainer := func(currentInter grpc.UnaryServerInterceptor, currentHandler grpc.UnaryHandler) grpc.UnaryHandler {
26-
return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
27-
return currentInter(currentCtx, currentReq, info, currentHandler)
28-
}
24+
// Dummy interceptor maintained for backward compatibility to avoid returning nil.
25+
if n == 0 {
26+
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
27+
return handler(ctx, req)
2928
}
29+
}
3030

31-
chainedHandler := handler
32-
for i := n - 1; i >= 0; i-- {
33-
chainedHandler = chainer(interceptors[i], chainedHandler)
34-
}
31+
// The degenerate case, just return the single wrapped interceptor directly.
32+
if n == 1 {
33+
return interceptors[0]
34+
}
3535

36-
return chainedHandler(ctx, req)
36+
// Return a function which satisfies the interceptor interface, and which is
37+
// a closure over the given list of interceptors to be chained.
38+
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
39+
currHandler := handler
40+
// Iterate backwards through all interceptors except the first (outermost).
41+
// Wrap each one in a function which satisfies the handler interface, but
42+
// is also a closure over the `info` and `handler` parameters. Then pass
43+
// each pseudo-handler to the next outer interceptor as the handler to be called.
44+
for i := n - 1; i > 0; i-- {
45+
// Rebind to loop-local vars so they can be closed over.
46+
innerHandler, i := currHandler, i
47+
currHandler = func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
48+
return interceptors[i](currentCtx, currentReq, info, innerHandler)
49+
}
50+
}
51+
// Finally return the result of calling the outermost interceptor with the
52+
// outermost pseudo-handler created above as its handler.
53+
return interceptors[0](ctx, req, info, currHandler)
3754
}
3855
}
3956

@@ -47,19 +64,26 @@ func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnarySer
4764
func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
4865
n := len(interceptors)
4966

50-
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
51-
chainer := func(currentInter grpc.StreamServerInterceptor, currentHandler grpc.StreamHandler) grpc.StreamHandler {
52-
return func(currentSrv interface{}, currentStream grpc.ServerStream) error {
53-
return currentInter(currentSrv, currentStream, info, currentHandler)
54-
}
67+
// Dummy interceptor maintained for backward compatibility to avoid returning nil.
68+
if n == 0 {
69+
return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
70+
return handler(srv, stream)
5571
}
72+
}
5673

57-
chainedHandler := handler
58-
for i := n - 1; i >= 0; i-- {
59-
chainedHandler = chainer(interceptors[i], chainedHandler)
60-
}
74+
if n == 1 {
75+
return interceptors[0]
76+
}
6177

62-
return chainedHandler(srv, ss)
78+
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
79+
currHandler := handler
80+
for i := n - 1; i > 0; i-- {
81+
innerHandler, i := currHandler, i
82+
currHandler = func(currentSrv interface{}, currentStream grpc.ServerStream) error {
83+
return interceptors[i](currentSrv, currentStream, info, innerHandler)
84+
}
85+
}
86+
return interceptors[0](srv, stream, info, currHandler)
6387
}
6488
}
6589

@@ -70,19 +94,26 @@ func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.Stream
7094
func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
7195
n := len(interceptors)
7296

73-
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
74-
chainer := func(currentInter grpc.UnaryClientInterceptor, currentInvoker grpc.UnaryInvoker) grpc.UnaryInvoker {
75-
return func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {
76-
return currentInter(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentInvoker, currentOpts...)
77-
}
97+
// Dummy interceptor maintained for backward compatibility to avoid returning nil.
98+
if n == 0 {
99+
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
100+
return invoker(ctx, method, req, reply, cc, opts...)
78101
}
102+
}
79103

80-
chainedInvoker := invoker
81-
for i := n - 1; i >= 0; i-- {
82-
chainedInvoker = chainer(interceptors[i], chainedInvoker)
83-
}
104+
if n == 1 {
105+
return interceptors[0]
106+
}
84107

85-
return chainedInvoker(ctx, method, req, reply, cc, opts...)
108+
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
109+
currInvoker := invoker
110+
for i := n - 1; i > 0; i-- {
111+
innerInvoker, i := currInvoker, i
112+
currInvoker = func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {
113+
return interceptors[i](currentCtx, currentMethod, currentReq, currentRepl, currentConn, innerInvoker, currentOpts...)
114+
}
115+
}
116+
return interceptors[0](ctx, method, req, reply, cc, currInvoker, opts...)
86117
}
87118
}
88119

@@ -93,19 +124,26 @@ func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryCli
93124
func ChainStreamClient(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
94125
n := len(interceptors)
95126

96-
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
97-
chainer := func(currentInter grpc.StreamClientInterceptor, currentStreamer grpc.Streamer) grpc.Streamer {
98-
return func(currentCtx context.Context, currentDesc *grpc.StreamDesc, currentConn *grpc.ClientConn, currentMethod string, currentOpts ...grpc.CallOption) (grpc.ClientStream, error) {
99-
return currentInter(currentCtx, currentDesc, currentConn, currentMethod, currentStreamer, currentOpts...)
100-
}
127+
// Dummy interceptor maintained for backward compatibility to avoid returning nil.
128+
if n == 0 {
129+
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
130+
return streamer(ctx, desc, cc, method, opts...)
101131
}
132+
}
102133

103-
chainedStreamer := streamer
104-
for i := n - 1; i >= 0; i-- {
105-
chainedStreamer = chainer(interceptors[i], chainedStreamer)
106-
}
134+
if n == 1 {
135+
return interceptors[0]
136+
}
107137

108-
return chainedStreamer(ctx, desc, cc, method, opts...)
138+
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
139+
currStreamer := streamer
140+
for i := n - 1; i > 0; i-- {
141+
innerStreamer, i := currStreamer, i
142+
currStreamer = func(currentCtx context.Context, currentDesc *grpc.StreamDesc, currentConn *grpc.ClientConn, currentMethod string, currentOpts ...grpc.CallOption) (grpc.ClientStream, error) {
143+
return interceptors[i](currentCtx, currentDesc, currentConn, currentMethod, innerStreamer, currentOpts...)
144+
}
145+
}
146+
return interceptors[0](ctx, desc, cc, method, currStreamer, opts...)
109147
}
110148
}
111149

0 commit comments

Comments
 (0)