Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e13869d

Browse files
committedAug 19, 2024·
feat(codec): Unknown Service Handler (cloudwego#1321)
1 parent 0824d3c commit e13869d

File tree

6 files changed

+469
-0
lines changed

6 files changed

+469
-0
lines changed
 

‎pkg/remote/option.go

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121
"net"
2222
"time"
2323

24+
"github.com/cloudwego/kitex/pkg/unknownservice/service"
25+
2426
"github.com/cloudwego/kitex/pkg/endpoint"
2527
"github.com/cloudwego/kitex/pkg/profiler"
2628
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
@@ -113,6 +115,8 @@ type ServerOption struct {
113115

114116
GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error
115117

118+
UnknownServiceHandler service.UnknownServiceHandler
119+
116120
Option
117121

118122
// invoking chain with recv/send middlewares for streaming APIs
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2021 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package service
18+
19+
import (
20+
"context"
21+
22+
"github.com/cloudwego/kitex/pkg/serviceinfo"
23+
)
24+
25+
const (
26+
// UnknownService name
27+
UnknownService = "$UnknownService" // private as "$"
28+
// UnknownMethod name
29+
UnknownMethod = "$UnknownMethod"
30+
)
31+
32+
type Args struct {
33+
Request []byte
34+
Method string
35+
ServiceName string
36+
}
37+
38+
type Result struct {
39+
Success []byte
40+
Method string
41+
ServiceName string
42+
}
43+
44+
type UnknownServiceHandler interface {
45+
UnknownServiceHandler(ctx context.Context, serviceName, method string, request []byte) ([]byte, error)
46+
}
47+
48+
// NewServiceInfo create serviceInfo
49+
func NewServiceInfo(pcType serviceinfo.PayloadCodec, service, method string) *serviceinfo.ServiceInfo {
50+
methods := map[string]serviceinfo.MethodInfo{
51+
method: serviceinfo.NewMethodInfo(callHandler, newServiceArgs, newServiceResult, false),
52+
}
53+
handlerType := (*UnknownServiceHandler)(nil)
54+
55+
svcInfo := &serviceinfo.ServiceInfo{
56+
ServiceName: service,
57+
HandlerType: handlerType,
58+
Methods: methods,
59+
PayloadCodec: pcType,
60+
Extra: make(map[string]interface{}),
61+
}
62+
63+
return svcInfo
64+
}
65+
66+
func callHandler(ctx context.Context, handler, arg, result interface{}) error {
67+
realArg := arg.(*Args)
68+
realResult := result.(*Result)
69+
realResult.Method = realArg.Method
70+
realResult.ServiceName = realArg.ServiceName
71+
success, err := handler.(UnknownServiceHandler).UnknownServiceHandler(ctx, realArg.ServiceName, realArg.Method, realArg.Request)
72+
if err != nil {
73+
return err
74+
}
75+
realResult.Success = success
76+
return nil
77+
}
78+
79+
func newServiceArgs() interface{} {
80+
return &Args{}
81+
}
82+
83+
func newServiceResult() interface{} {
84+
return &Result{}
85+
}

‎pkg/unknownservice/unknown.go

+236
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/*
2+
* Copyright 2021 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package unknownservice
18+
19+
import (
20+
"context"
21+
"encoding/binary"
22+
"errors"
23+
"fmt"
24+
25+
"github.com/cloudwego/kitex/pkg/protocol/bthrift"
26+
thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache"
27+
28+
"github.com/cloudwego/kitex/pkg/remote"
29+
"github.com/cloudwego/kitex/pkg/remote/codec"
30+
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
31+
"github.com/cloudwego/kitex/pkg/rpcinfo"
32+
"github.com/cloudwego/kitex/pkg/serviceinfo"
33+
unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service"
34+
)
35+
36+
// UnknownCodec implements PayloadCodec
37+
type unknownCodec struct {
38+
Codec remote.PayloadCodec
39+
}
40+
41+
// NewUnknownServiceCodec creates the unknown binary codec.
42+
func NewUnknownServiceCodec(code remote.PayloadCodec) remote.PayloadCodec {
43+
return &unknownCodec{code}
44+
}
45+
46+
// Marshal implements the remote.PayloadCodec interface.
47+
func (c unknownCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
48+
ink := msg.RPCInfo().Invocation()
49+
data := msg.Data()
50+
51+
res, ok := data.(*unknownservice.Result)
52+
if !ok {
53+
return c.Codec.Marshal(ctx, msg, out)
54+
}
55+
if len(res.Success) == 0 {
56+
return errors.New("unknown messages cannot be empty")
57+
}
58+
if msg.MessageType() == remote.Exception {
59+
return c.Codec.Marshal(ctx, msg, out)
60+
}
61+
if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
62+
ink.SetMethodName(res.Method)
63+
ink.SetServiceName(res.ServiceName)
64+
} else {
65+
return errors.New("the interface Invocation doesn't implement InvocationSetter")
66+
}
67+
if err := encode(res, msg, out); err != nil {
68+
return c.Codec.Marshal(ctx, msg, out)
69+
}
70+
return nil
71+
}
72+
73+
// Unmarshal implements the remote.PayloadCodec interface.
74+
func (c unknownCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
75+
ink := message.RPCInfo().Invocation()
76+
magicAndMsgType, err := codec.PeekUint32(in)
77+
if err != nil {
78+
return err
79+
}
80+
msgType := magicAndMsgType & codec.FrontMask
81+
if msgType == uint32(remote.Exception) {
82+
return c.Codec.Unmarshal(ctx, message, in)
83+
}
84+
if err = codec.UpdateMsgType(msgType, message); err != nil {
85+
return err
86+
}
87+
service, method, err := readDecode(message, in)
88+
if err != nil {
89+
return err
90+
}
91+
err = codec.SetOrCheckMethodName(method, message)
92+
var te *remote.TransError
93+
if errors.As(err, &te) && (te.TypeID() == remote.UnknownMethod || te.TypeID() == remote.UnknownService) {
94+
svcInfo, err := message.SpecifyServiceInfo(unknownservice.UnknownService, unknownservice.UnknownMethod)
95+
if err != nil {
96+
return err
97+
}
98+
99+
if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
100+
ink.SetMethodName(unknownservice.UnknownMethod)
101+
ink.SetPackageName(svcInfo.GetPackageName())
102+
ink.SetServiceName(unknownservice.UnknownService)
103+
} else {
104+
return errors.New("the interface Invocation doesn't implement InvocationSetter")
105+
}
106+
if err = codec.NewDataIfNeeded(unknownservice.UnknownMethod, message); err != nil {
107+
return err
108+
}
109+
110+
data := message.Data()
111+
112+
if data, ok := data.(*unknownservice.Args); ok {
113+
data.Method = method
114+
data.ServiceName = service
115+
buf, err := in.Next(in.ReadableLen())
116+
if err != nil {
117+
return err
118+
}
119+
data.Request = buf
120+
}
121+
return nil
122+
}
123+
124+
return c.Codec.Unmarshal(ctx, message, in)
125+
}
126+
127+
// Name implements the remote.PayloadCodec interface.
128+
func (c unknownCodec) Name() string {
129+
return "unknownMethodCodec"
130+
}
131+
132+
func write(dst, src []byte) {
133+
copy(dst, src)
134+
}
135+
136+
func readDecode(message remote.Message, in remote.ByteBuffer) (string, string, error) {
137+
code := message.ProtocolInfo().CodecType
138+
if code == serviceinfo.Thrift || code == serviceinfo.Protobuf {
139+
method, size, err := peekMethod(in)
140+
if err != nil {
141+
return "", "", err
142+
}
143+
144+
seqID, err := peekSeqID(in, size)
145+
if err != nil {
146+
return "", "", err
147+
}
148+
if err = codec.SetOrCheckSeqID(seqID, message); err != nil {
149+
return "", "", err
150+
}
151+
return message.RPCInfo().Invocation().ServiceName(), method, nil
152+
}
153+
return "", "", nil
154+
}
155+
156+
func peekMethod(in remote.ByteBuffer) (string, int32, error) {
157+
buf, err := in.Peek(8)
158+
if err != nil {
159+
return "", 0, err
160+
}
161+
buf = buf[4:]
162+
size := int32(binary.BigEndian.Uint32(buf))
163+
buf, err = in.Peek(int(size + 8))
164+
if err != nil {
165+
return "", 0, perrors.NewProtocolError(err)
166+
}
167+
buf = buf[8:]
168+
method := string(buf)
169+
return method, size + 8, nil
170+
}
171+
172+
func peekSeqID(in remote.ByteBuffer, size int32) (int32, error) {
173+
buf, err := in.Peek(int(size + 4))
174+
if err != nil {
175+
return 0, perrors.NewProtocolError(err)
176+
}
177+
buf = buf[size:]
178+
seqID := int32(binary.BigEndian.Uint32(buf))
179+
return seqID, nil
180+
}
181+
182+
func encode(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error {
183+
if msg.ProtocolInfo().CodecType == serviceinfo.Thrift {
184+
return encodeThrift(res, msg, out)
185+
}
186+
if msg.ProtocolInfo().CodecType == serviceinfo.Protobuf {
187+
return encodeKitexProtobuf(res, msg, out)
188+
}
189+
return nil
190+
}
191+
192+
// encodeThrift Thrift encoder
193+
func encodeThrift(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error {
194+
nw, _ := out.(remote.NocopyWrite)
195+
msgType := msg.MessageType()
196+
ink := msg.RPCInfo().Invocation()
197+
msgBeginLen := bthrift.Binary.MessageBeginLength(res.Method, thrift.TMessageType(msgType), ink.SeqID())
198+
msgEndLen := bthrift.Binary.MessageEndLength()
199+
200+
buf, err := out.Malloc(msgBeginLen + len(res.Success) + msgEndLen)
201+
if err != nil {
202+
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error()))
203+
}
204+
offset := bthrift.Binary.WriteMessageBegin(buf, res.Method, thrift.TMessageType(msgType), ink.SeqID())
205+
write(buf[offset:], res.Success)
206+
bthrift.Binary.WriteMessageEnd(buf[offset:])
207+
if nw == nil {
208+
// if nw is nil, FastWrite will act in Copy mode.
209+
return nil
210+
}
211+
return nw.MallocAck(out.MallocLen())
212+
}
213+
214+
// encodeKitexProtobuf Kitex Protobuf encoder
215+
func encodeKitexProtobuf(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error {
216+
ink := msg.RPCInfo().Invocation()
217+
// 3.1 magic && msgType
218+
if err := codec.WriteUint32(codec.ProtobufV1Magic+uint32(msg.MessageType()), out); err != nil {
219+
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write meta info failed: %s", err.Error()))
220+
}
221+
// 3.2 methodName
222+
if _, err := codec.WriteString(res.Method, out); err != nil {
223+
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write method name failed: %s", err.Error()))
224+
}
225+
// 3.3 seqID
226+
if err := codec.WriteUint32(uint32(ink.SeqID()), out); err != nil {
227+
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write seqID failed: %s", err.Error()))
228+
}
229+
dataLen := len(res.Success)
230+
buf, err := out.Malloc(dataLen)
231+
if err != nil {
232+
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf malloc size %d failed: %s", dataLen, err.Error()))
233+
}
234+
write(buf, res.Success)
235+
return nil
236+
}

‎pkg/unknownservice/unknown_test.go

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright 2021 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package unknownservice
18+
19+
import (
20+
"context"
21+
"testing"
22+
23+
"github.com/cloudwego/netpoll"
24+
25+
"github.com/cloudwego/kitex/internal/mocks"
26+
mt "github.com/cloudwego/kitex/internal/mocks/thrift"
27+
"github.com/cloudwego/kitex/internal/test"
28+
"github.com/cloudwego/kitex/pkg/remote"
29+
"github.com/cloudwego/kitex/pkg/remote/codec/thrift"
30+
netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
31+
"github.com/cloudwego/kitex/pkg/rpcinfo"
32+
"github.com/cloudwego/kitex/pkg/unknownservice/service"
33+
"github.com/cloudwego/kitex/transport"
34+
)
35+
36+
var (
37+
thr = thrift.NewThriftCodec()
38+
payloadCodec = unknownCodec{thr}
39+
svcInfo = mocks.ServiceInfo()
40+
)
41+
42+
func TestNormal(t *testing.T) {
43+
sendMsg := initSendMsg(transport.TTHeader)
44+
buf := netpolltrans.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024))
45+
ctx := context.Background()
46+
err := payloadCodec.Marshal(ctx, sendMsg, buf)
47+
test.Assert(t, err == nil, err)
48+
err = buf.Flush()
49+
test.Assert(t, err == nil, err)
50+
recvMsg := initRecvMsg()
51+
recvMsg.SetPayloadLen(buf.ReadableLen())
52+
_, size, err := peekMethod(buf)
53+
test.Assert(t, err == nil, err)
54+
err = payloadCodec.Unmarshal(ctx, recvMsg, buf)
55+
test.Assert(t, err == nil, err)
56+
57+
req := (sendMsg.Data()).(*service.Result).Success
58+
resp := (recvMsg.Data()).(*service.Args).Request
59+
resp = resp[size+4:]
60+
for i, item := range req {
61+
test.Assert(t, item == resp[i])
62+
}
63+
var _req mt.MockTestArgs
64+
var _resp mt.MockTestArgs
65+
reqMsg, err := _req.FastRead(req)
66+
test.Assert(t, err == nil, err)
67+
respMsg, err := _resp.FastRead(resp)
68+
test.Assert(t, err == nil && reqMsg == respMsg, err)
69+
test.Assert(t, len(_req.Req.StrList) == len(_resp.Req.StrList))
70+
test.Assert(t, len(_req.Req.StrMap) == len(_resp.Req.StrList))
71+
for i, item := range _req.Req.StrList {
72+
test.Assert(t, item == _resp.Req.StrList[i])
73+
}
74+
for k := range _resp.Req.StrMap {
75+
test.Assert(t, _req.Req.StrMap[k] == _resp.Req.StrMap[k])
76+
}
77+
}
78+
79+
func initSendMsg(tp transport.Protocol) remote.Message {
80+
var _args mt.MockTestArgs
81+
_args.Req = prepareReq()
82+
length := _args.BLength()
83+
bytes := make([]byte, length)
84+
_args.FastWriteNocopy(bytes, nil)
85+
arg := service.Result{Success: bytes, Method: "mock", ServiceName: ""}
86+
ink := rpcinfo.NewInvocation("", service.UnknownMethod)
87+
ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
88+
89+
msg := remote.NewMessage(&arg, svcInfo, ri, remote.Call, remote.Client)
90+
91+
msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec))
92+
return msg
93+
}
94+
95+
func initRecvMsg() remote.Message {
96+
arg := service.Args{Request: make([]byte, 0), Method: "mock", ServiceName: ""}
97+
ink := rpcinfo.NewInvocation("", service.UnknownMethod)
98+
ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
99+
svc := service.NewServiceInfo(svcInfo.PayloadCodec, service.UnknownService, service.UnknownMethod)
100+
msg := remote.NewMessage(&arg, svc, ri, remote.Call, remote.Server)
101+
return msg
102+
}
103+
104+
func prepareReq() *mt.MockReq {
105+
strMap := make(map[string]string)
106+
strMap["key1"] = "val1"
107+
strMap["key2"] = "val2"
108+
strList := []string{"str1", "str2"}
109+
req := &mt.MockReq{
110+
Msg: "MockReq",
111+
StrMap: strMap,
112+
StrList: strList,
113+
}
114+
return req
115+
}

‎server/option.go

+14
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ import (
2222
"net"
2323
"time"
2424

25+
"github.com/cloudwego/kitex/pkg/unknownservice"
26+
unknown "github.com/cloudwego/kitex/pkg/unknownservice/service"
27+
2528
"github.com/cloudwego/localsession/backup"
2629

2730
internal_server "github.com/cloudwego/kitex/internal/server"
@@ -348,6 +351,17 @@ func WithGRPCUnknownServiceHandler(f func(ctx context.Context, methodName string
348351
}}
349352
}
350353

354+
// WithUnknownServiceHandler Inject an implementation of a method for handling unknown requests
355+
// supporting only Thrift and Kitex protobuf protocols
356+
func WithUnknownServiceHandler(f unknown.UnknownServiceHandler) Option {
357+
return Option{F: func(o *internal_server.Options, di *utils.Slice) {
358+
di.Push(fmt.Sprintf("WithUnknownMethodHandler(%+v)", utils.GetFuncName(f)))
359+
o.RemoteOpt.UnknownServiceHandler = f
360+
remote.PutPayloadCode(serviceinfo.Thrift, unknownservice.NewUnknownServiceCodec(thrift.NewThriftCodec()))
361+
remote.PutPayloadCode(serviceinfo.Protobuf, unknownservice.NewUnknownServiceCodec(protobuf.NewProtobufCodec()))
362+
}}
363+
}
364+
351365
// Deprecated: Use WithConnectionLimiter instead.
352366
func WithConcurrencyLimiter(conLimit limiter.ConcurrencyLimiter) Option {
353367
return Option{F: func(o *internal_server.Options, di *utils.Slice) {

‎server/server.go

+15
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"sync"
2828
"time"
2929

30+
unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service"
3031
"github.com/cloudwego/localsession/backup"
3132

3233
internal_server "github.com/cloudwego/kitex/internal/server"
@@ -102,6 +103,7 @@ func (s *server) init() {
102103
backup.Init(s.opt.BackupOpt)
103104
s.buildInvokeChain()
104105
s.buildStreamInvokeChain()
106+
s.registerUnknownServiceHandler()
105107
}
106108

107109
func fillContext(opt *internal_server.Options) context.Context {
@@ -546,6 +548,19 @@ func (s *server) waitExit(errCh chan error) error {
546548
}
547549
}
548550

551+
func (s *server) registerUnknownServiceHandler() {
552+
if s.opt.RemoteOpt.UnknownServiceHandler != nil {
553+
if len(s.svcs.svcMap) == 1 && s.svcs.svcMap[serviceinfo.GenericService] != nil {
554+
panic(errors.New("generic services do not support handling of unknown methods"))
555+
} else {
556+
serviceInfo := unknownservice.NewServiceInfo(serviceinfo.Thrift, unknownservice.UnknownService, unknownservice.UnknownMethod)
557+
if err := s.RegisterService(serviceInfo, s.opt.RemoteOpt.UnknownServiceHandler); err != nil {
558+
panic(err)
559+
}
560+
}
561+
}
562+
}
563+
549564
func (s *server) findAndSetDefaultService() {
550565
if len(s.svcs.svcMap) == 1 {
551566
s.targetSvcInfo = getDefaultSvcInfo(s.svcs)

0 commit comments

Comments
 (0)
Please sign in to comment.