Skip to content

Commit 94fb0c2

Browse files
committed
feat(codec): Unknown Service Handler (cloudwego#1321)
1 parent 1256b7d commit 94fb0c2

File tree

6 files changed

+434
-5
lines changed

6 files changed

+434
-5
lines changed

pkg/remote/option.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
*
88
* http://www.apache.org/licenses/LICENSE-2.0
99
*
10-
* Unless required by applicable law or agreed to in writing, software
1110
* distributed under the License is distributed on an "AS IS" BASIS,
1211
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1312
* See the License for the specific language governing permissions and
@@ -27,6 +26,7 @@ import (
2726
"github.com/cloudwego/kitex/pkg/rpcinfo"
2827
"github.com/cloudwego/kitex/pkg/serviceinfo"
2928
"github.com/cloudwego/kitex/pkg/streaming"
29+
"github.com/cloudwego/kitex/pkg/unknownservice/service"
3030
)
3131

3232
// Option is used to pack the inbound and outbound handlers.
@@ -113,6 +113,8 @@ type ServerOption struct {
113113

114114
GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error
115115

116+
UnknownServiceHandler service.UnknownServiceHandler
117+
116118
Option
117119

118120
// invoking chain with recv/send middlewares for streaming APIs
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2024 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package service
18+
19+
import (
20+
"context"
21+
22+
"github.com/cloudwego/kitex/pkg/serviceinfo"
23+
)
24+
25+
const (
26+
// UnknownService name
27+
UnknownService = "$UnknownService" // private as "$"
28+
// UnknownMethod name
29+
UnknownMethod = "$UnknownMethod"
30+
)
31+
32+
type Args struct {
33+
Request []byte
34+
Method string
35+
ServiceName string
36+
}
37+
38+
type Result struct {
39+
Success []byte
40+
Method string
41+
ServiceName string
42+
}
43+
44+
type UnknownServiceHandler interface {
45+
UnknownServiceHandler(ctx context.Context, serviceName, method string, request []byte) ([]byte, error)
46+
}
47+
48+
// NewServiceInfo creates a new ServiceInfo containing unknown methods
49+
func NewServiceInfo(pcType serviceinfo.PayloadCodec, service, method string) *serviceinfo.ServiceInfo {
50+
methods := map[string]serviceinfo.MethodInfo{
51+
method: serviceinfo.NewMethodInfo(callHandler, newServiceArgs, newServiceResult, false),
52+
}
53+
handlerType := (*UnknownServiceHandler)(nil)
54+
55+
svcInfo := &serviceinfo.ServiceInfo{
56+
ServiceName: service,
57+
HandlerType: handlerType,
58+
Methods: methods,
59+
PayloadCodec: pcType,
60+
Extra: make(map[string]interface{}),
61+
}
62+
63+
return svcInfo
64+
}
65+
66+
func callHandler(ctx context.Context, handler, arg, result interface{}) error {
67+
realArg := arg.(*Args)
68+
realResult := result.(*Result)
69+
realResult.Method = realArg.Method
70+
realResult.ServiceName = realArg.ServiceName
71+
success, err := handler.(UnknownServiceHandler).UnknownServiceHandler(ctx, realArg.ServiceName, realArg.Method, realArg.Request)
72+
if err != nil {
73+
return err
74+
}
75+
realResult.Success = success
76+
return nil
77+
}
78+
79+
func newServiceArgs() interface{} {
80+
return &Args{}
81+
}
82+
83+
func newServiceResult() interface{} {
84+
return &Result{}
85+
}
+197
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Copyright 2024 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package unknownservice
18+
19+
import (
20+
"context"
21+
"encoding/binary"
22+
"errors"
23+
"fmt"
24+
25+
gthrift "github.com/cloudwego/gopkg/protocol/thrift"
26+
"github.com/cloudwego/kitex/pkg/remote"
27+
"github.com/cloudwego/kitex/pkg/remote/codec"
28+
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
29+
"github.com/cloudwego/kitex/pkg/rpcinfo"
30+
"github.com/cloudwego/kitex/pkg/serviceinfo"
31+
unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service"
32+
)
33+
34+
// UnknownCodec implements PayloadCodec
35+
type unknownServiceCodec struct {
36+
Codec remote.PayloadCodec
37+
}
38+
39+
// NewUnknownServiceCodec creates the unknown binary codec.
40+
func NewUnknownServiceCodec(code remote.PayloadCodec) remote.PayloadCodec {
41+
return &unknownServiceCodec{code}
42+
}
43+
44+
// Marshal implements the remote.PayloadCodec interface.
45+
func (c unknownServiceCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
46+
ink := msg.RPCInfo().Invocation()
47+
data := msg.Data()
48+
49+
res, ok := data.(*unknownservice.Result)
50+
if !ok {
51+
return c.Codec.Marshal(ctx, msg, out)
52+
}
53+
if msg.MessageType() == remote.Exception {
54+
return c.Codec.Marshal(ctx, msg, out)
55+
}
56+
if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
57+
ink.SetMethodName(res.Method)
58+
ink.SetServiceName(res.ServiceName)
59+
} else {
60+
return errors.New("the interface Invocation doesn't implement InvocationSetter")
61+
}
62+
63+
if res.Success == nil {
64+
sz := gthrift.Binary.MessageBeginLength(msg.RPCInfo().Invocation().MethodName())
65+
if msg.ProtocolInfo().CodecType == serviceinfo.Thrift {
66+
sz += gthrift.Binary.FieldStopLength()
67+
buf, err := out.Malloc(sz)
68+
if err != nil {
69+
return perrors.NewProtocolError(fmt.Errorf("binary thrift generic marshal, remote.ByteBuffer Malloc err: %w", err))
70+
}
71+
buf = gthrift.Binary.AppendMessageBegin(buf[:0],
72+
msg.RPCInfo().Invocation().MethodName(), gthrift.TMessageType(msg.MessageType()), msg.RPCInfo().Invocation().SeqID())
73+
buf = gthrift.Binary.AppendFieldStop(buf)
74+
_ = buf
75+
}
76+
77+
if msg.ProtocolInfo().CodecType == serviceinfo.Protobuf {
78+
buf, err := out.Malloc(sz)
79+
if err != nil {
80+
return perrors.NewProtocolError(fmt.Errorf("binary thrift generic marshal, remote.ByteBuffer Malloc err: %w", err))
81+
}
82+
binary.BigEndian.PutUint32(buf, codec.ProtobufV1Magic+uint32(msg.MessageType()))
83+
offset := 4
84+
offset += gthrift.Binary.WriteString(buf[offset:], res.Method)
85+
offset += gthrift.Binary.WriteI32(buf[offset:], msg.RPCInfo().Invocation().SeqID())
86+
_ = buf
87+
}
88+
return nil
89+
}
90+
out.WriteBinary(res.Success)
91+
return nil
92+
}
93+
94+
// Unmarshal implements the remote.PayloadCodec interface.
95+
func (c unknownServiceCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
96+
ink := message.RPCInfo().Invocation()
97+
magicAndMsgType, err := codec.PeekUint32(in)
98+
if err != nil {
99+
return err
100+
}
101+
msgType := magicAndMsgType & codec.FrontMask
102+
if msgType == uint32(remote.Exception) {
103+
return c.Codec.Unmarshal(ctx, message, in)
104+
}
105+
if err = codec.UpdateMsgType(msgType, message); err != nil {
106+
return err
107+
}
108+
service, method, err := readDecode(message, in)
109+
if err != nil {
110+
return err
111+
}
112+
err = codec.SetOrCheckMethodName(method, message)
113+
var te *remote.TransError
114+
if errors.As(err, &te) && (te.TypeID() == remote.UnknownMethod || te.TypeID() == remote.UnknownService) {
115+
svcInfo, err := message.SpecifyServiceInfo(unknownservice.UnknownService, unknownservice.UnknownMethod)
116+
if err != nil {
117+
return err
118+
}
119+
120+
if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
121+
ink.SetMethodName(unknownservice.UnknownMethod)
122+
ink.SetPackageName(svcInfo.GetPackageName())
123+
ink.SetServiceName(unknownservice.UnknownService)
124+
} else {
125+
return errors.New("the interface Invocation doesn't implement InvocationSetter")
126+
}
127+
if err = codec.NewDataIfNeeded(unknownservice.UnknownMethod, message); err != nil {
128+
return err
129+
}
130+
131+
data := message.Data()
132+
133+
if data, ok := data.(*unknownservice.Args); ok {
134+
data.Method = method
135+
data.ServiceName = service
136+
buf, err := in.Next(in.ReadableLen())
137+
if err != nil {
138+
return err
139+
}
140+
data.Request = buf
141+
}
142+
return nil
143+
}
144+
145+
return c.Codec.Unmarshal(ctx, message, in)
146+
}
147+
148+
// Name implements the remote.PayloadCodec interface.
149+
func (c unknownServiceCodec) Name() string {
150+
return "unknownServiceCodec"
151+
}
152+
153+
func readDecode(message remote.Message, in remote.ByteBuffer) (string, string, error) {
154+
code := message.ProtocolInfo().CodecType
155+
if code == serviceinfo.Thrift || code == serviceinfo.Protobuf {
156+
method, size, err := peekMethod(in)
157+
if err != nil {
158+
return "", "", err
159+
}
160+
161+
seqID, err := peekSeqID(in, size)
162+
if err != nil {
163+
return "", "", err
164+
}
165+
if err = codec.SetOrCheckSeqID(seqID, message); err != nil {
166+
return "", "", err
167+
}
168+
return message.RPCInfo().Invocation().ServiceName(), method, nil
169+
}
170+
return "", "", nil
171+
}
172+
173+
func peekMethod(in remote.ByteBuffer) (string, int32, error) {
174+
buf, err := in.Peek(8)
175+
if err != nil {
176+
return "", 0, err
177+
}
178+
buf = buf[4:]
179+
size := int32(binary.BigEndian.Uint32(buf))
180+
buf, err = in.Peek(int(size + 8))
181+
if err != nil {
182+
return "", 0, perrors.NewProtocolError(err)
183+
}
184+
buf = buf[8:]
185+
method := string(buf)
186+
return method, size + 8, nil
187+
}
188+
189+
func peekSeqID(in remote.ByteBuffer, size int32) (int32, error) {
190+
buf, err := in.Peek(int(size + 4))
191+
if err != nil {
192+
return 0, perrors.NewProtocolError(err)
193+
}
194+
buf = buf[size:]
195+
seqID := int32(binary.BigEndian.Uint32(buf))
196+
return seqID, nil
197+
}

0 commit comments

Comments
 (0)