From fe8a4e99ff7757c2b7a34e1c9613d13fa010708a Mon Sep 17 00:00:00 2001 From: zhanglei25 Date: Fri, 2 Sep 2022 18:24:17 +0800 Subject: [PATCH] feat(tars): Support rpc async callback call --- tars/adapter.go | 2 +- tars/application.go | 1 + tars/config.go | 4 +- tars/message.go | 61 +++++++++++++++++ tars/model/servant.go | 13 ++++ tars/servant.go | 155 ++++++++++++++++++++++++++---------------- tars/setting.go | 4 +- 7 files changed, 177 insertions(+), 63 deletions(-) diff --git a/tars/adapter.go b/tars/adapter.go index 843a4ca3..8046652c 100755 --- a/tars/adapter.go +++ b/tars/adapter.go @@ -256,7 +256,7 @@ func (c *AdapterProxy) doKeepAlive() { IRequestId: c.servantProxy.genRequestID(), SServantName: c.servantProxy.name, SFuncName: "tars_ping", - ITimeout: int32(c.servantProxy.timeout), + ITimeout: int32(c.servantProxy.asyncTimeout), } msg := &Message{Req: &req, Ser: c.servantProxy} msg.Init() diff --git a/tars/application.go b/tars/application.go index dffd0085..f2be1d3c 100755 --- a/tars/application.go +++ b/tars/application.go @@ -228,6 +228,7 @@ func (a *application) initConfig() { a.cltCfg.Stat = cMap["stat"] a.cltCfg.Property = cMap["property"] a.cltCfg.ModuleName = cMap["modulename"] + a.cltCfg.SyncInvokeTimeout = c.GetIntWithDef("/tars/application/client", SyncInvokeTimeout) a.cltCfg.AsyncInvokeTimeout = c.GetIntWithDef("/tars/application/client", AsyncInvokeTimeout) a.cltCfg.RefreshEndpointInterval = c.GetIntWithDef("/tars/application/client", refreshEndpointInterval) a.cltCfg.ReportInterval = c.GetIntWithDef("/tars/application/client", reportInterval) diff --git a/tars/config.go b/tars/config.go index 39dba6ce..1cc68dd3 100755 --- a/tars/config.go +++ b/tars/config.go @@ -80,8 +80,9 @@ type clientConfig struct { ReportInterval int CheckStatusInterval int KeepAliveInterval int - AsyncInvokeTimeout int // add client timeout + SyncInvokeTimeout int + AsyncInvokeTimeout int ClientQueueLen int ClientIdleTimeout time.Duration ClientReadTimeout time.Duration @@ -152,6 +153,7 @@ func newClientConfig() *clientConfig { ReportInterval: reportInterval, CheckStatusInterval: checkStatusInterval, KeepAliveInterval: keepAliveInterval, + SyncInvokeTimeout: SyncInvokeTimeout, AsyncInvokeTimeout: AsyncInvokeTimeout, ClientQueueLen: ClientQueueLen, ClientIdleTimeout: tools.ParseTimeOut(ClientIdleTimeout), diff --git a/tars/message.go b/tars/message.go index a0a2b11b..9e001ba2 100644 --- a/tars/message.go +++ b/tars/message.go @@ -1,10 +1,15 @@ package tars import ( + "context" "time" + "github.com/TarsCloud/TarsGo/tars/model" + "github.com/TarsCloud/TarsGo/tars/protocol/res/basef" "github.com/TarsCloud/TarsGo/tars/protocol/res/requestf" "github.com/TarsCloud/TarsGo/tars/selector" + "github.com/TarsCloud/TarsGo/tars/util/current" + "github.com/TarsCloud/TarsGo/tars/util/tools" ) // HashType is the hash type @@ -31,6 +36,8 @@ type Message struct { hashCode uint32 hashType HashType isHash bool + Async bool + Callback model.Callback } // Init define the beginTime @@ -66,3 +73,57 @@ func (m *Message) HashType() selector.HashType { func (m *Message) IsHash() bool { return m.isHash } + +func buildMessage(ctx context.Context, cType byte, + sFuncName string, + buf []byte, + status map[string]string, + reqContext map[string]string, + resp *requestf.ResponsePacket, + s *ServantProxy) *Message { + + // 将ctx中的dyeing信息传入到request中 + var msgType int32 + if dyeingKey, ok := current.GetDyeingKey(ctx); ok { + TLOG.Debug("dyeing debug: find dyeing key:", dyeingKey) + if status == nil { + status = make(map[string]string) + } + status[current.StatusDyedKey] = dyeingKey + msgType |= basef.TARSMESSAGETYPEDYED + } + + // 将ctx中的trace信息传入到request中 + if trace, ok := current.GetTarsTrace(ctx); ok && trace.Call() { + traceKey := trace.GetTraceFullKey(false) + TLOG.Debug("trace debug: find trace key:", traceKey) + if status == nil { + status = make(map[string]string) + } + status[current.StatusTraceKey] = traceKey + msgType |= basef.TARSMESSAGETYPETRACE + } + + req := requestf.RequestPacket{ + IVersion: s.version, + CPacketType: int8(cType), + IMessageType: msgType, + IRequestId: s.genRequestID(), + SServantName: s.name, + SFuncName: sFuncName, + ITimeout: int32(s.syncTimeout), + SBuffer: tools.ByteToInt8(buf), + Context: reqContext, + Status: status, + } + msg := &Message{Req: &req, Ser: s, Resp: resp} + msg.Init() + + if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok { + msg.isHash = isHash + msg.hashType = HashType(hashType) + msg.hashCode = hashCode + } + + return msg +} diff --git a/tars/model/servant.go b/tars/model/servant.go index 22e43f68..696cd789 100755 --- a/tars/model/servant.go +++ b/tars/model/servant.go @@ -8,6 +8,10 @@ import ( "github.com/TarsCloud/TarsGo/tars/protocol/res/requestf" ) +type Callback interface { + Dispatch(context.Context, *requestf.RequestPacket, *requestf.ResponsePacket, error) (int32, error) +} + // Servant is interface for call the remote server. type Servant interface { Name() string @@ -17,6 +21,15 @@ type Servant interface { status map[string]string, context map[string]string, resp *requestf.ResponsePacket) error + + TarsInvokeAsync(ctx context.Context, cType byte, + sFuncName string, + buf []byte, + status map[string]string, + context map[string]string, + resp *requestf.ResponsePacket, + callback Callback) error + TarsSetTimeout(t int) TarsSetProtocol(Protocol) Endpoints() []*endpoint.Endpoint diff --git a/tars/servant.go b/tars/servant.go index 006acf90..6e4551d8 100755 --- a/tars/servant.go +++ b/tars/servant.go @@ -15,7 +15,6 @@ import ( "github.com/TarsCloud/TarsGo/tars/util/current" "github.com/TarsCloud/TarsGo/tars/util/endpoint" "github.com/TarsCloud/TarsGo/tars/util/rtimer" - "github.com/TarsCloud/TarsGo/tars/util/tools" ) var ( @@ -31,13 +30,14 @@ const ( // ServantProxy tars servant proxy instance type ServantProxy struct { - name string - comm *Communicator - manager EndpointManager - timeout int - version int16 - proto model.Protocol - queueLen int32 + name string + comm *Communicator + manager EndpointManager + syncTimeout int + asyncTimeout int + version int16 + proto model.Protocol + queueLen int32 pushCallback func([]byte) } @@ -51,7 +51,6 @@ func newServantProxy(comm *Communicator, objName string, opts ...EndpointManager s := &ServantProxy{ comm: comm, proto: &protocol.TarsProtocol{}, - timeout: comm.Client.AsyncInvokeTimeout, version: basef.TARSVERSION, } pos := strings.Index(objName, "@") @@ -67,6 +66,12 @@ func newServantProxy(comm *Communicator, objName string, opts ...EndpointManager // init manager s.manager = GetManager(comm, objName, opts...) + + s.comm = comm + s.proto = &protocol.TarsProtocol{} + s.syncTimeout = s.comm.Client.SyncInvokeTimeout + s.asyncTimeout = s.comm.Client.AsyncInvokeTimeout + s.version = basef.TARSVERSION return s } @@ -77,7 +82,7 @@ func (s *ServantProxy) Name() string { // TarsSetTimeout sets the timeout for client calling the server , which is in ms. func (s *ServantProxy) TarsSetTimeout(t int) { - s.timeout = t + s.syncTimeout = t } // TarsSetVersion set tars version @@ -122,53 +127,44 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte, resp *requestf.ResponsePacket) error { defer CheckPanic() - // 将ctx中的dyeing信息传入到request中 - var msgType int32 - if dyeingKey, ok := current.GetDyeingKey(ctx); ok { - TLOG.Debug("dyeing debug: find dyeing key:", dyeingKey) - if status == nil { - status = make(map[string]string) - } - status[current.StatusDyedKey] = dyeingKey - msgType |= basef.TARSMESSAGETYPEDYED - } + msg := buildMessage(ctx, cType, sFuncName, buf, status, reqContext, resp, s) + timeout := time.Duration(s.syncTimeout) * time.Millisecond + err := s.invokeFilters(ctx, msg, timeout) - // 将ctx中的trace信息传入到request中 - if trace, ok := current.GetTarsTrace(ctx); ok && trace.Call() { - traceKey := trace.GetTraceFullKey(false) - TLOG.Debug("trace debug: find trace key:", traceKey) - if status == nil { - status = make(map[string]string) - } - status[current.StatusTraceKey] = traceKey - msgType |= basef.TARSMESSAGETYPETRACE + if err != nil { + return err } + *resp = *msg.Resp + return nil +} - req := requestf.RequestPacket{ - IVersion: s.version, - CPacketType: int8(cType), - IRequestId: s.genRequestID(), - SServantName: s.name, - SFuncName: sFuncName, - SBuffer: tools.ByteToInt8(buf), - ITimeout: int32(s.timeout), - Context: reqContext, - Status: status, - IMessageType: msgType, - } - msg := &Message{Req: &req, Ser: s, Resp: resp} - msg.Init() - - timeout := time.Duration(s.timeout) * time.Millisecond - if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok { - msg.isHash = isHash - msg.hashType = HashType(hashType) - msg.hashCode = hashCode +// TarsInvokeAsync is used for client invoking server. +func (s *ServantProxy) TarsInvokeAsync(ctx context.Context, cType byte, + sFuncName string, + buf []byte, + status map[string]string, + reqContext map[string]string, + resp *requestf.ResponsePacket, + callback model.Callback) error { + defer CheckPanic() + + msg := buildMessage(ctx, cType, sFuncName, buf, status, reqContext, resp, s) + msg.Req.ITimeout = int32(s.asyncTimeout) + if callback == nil { + msg.Req.CPacketType = basef.TARSONEWAY + } else { + msg.Async = true + msg.Callback = callback } + timeout := time.Duration(s.asyncTimeout) * time.Millisecond + return s.invokeFilters(ctx, msg, timeout) +} + +func (s *ServantProxy) invokeFilters(ctx context.Context, msg *Message, timeout time.Duration) error { if ok, to, isTimeout := current.GetClientTimeout(ctx); ok && isTimeout { timeout = time.Duration(to) * time.Millisecond - req.ITimeout = int32(to) + msg.Req.ITimeout = int32(to) } var err error @@ -196,11 +192,19 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte, } } } - s.manager.postInvoke() + // no async rpc call + if !msg.Async { + s.manager.postInvoke() + msg.End() + s.reportStat(msg, err) + } + return err +} + +func (s *ServantProxy) reportStat(msg *Message, err error) { if err != nil { - msg.End() - TLOG.Errorf("Invoke error: %s, %s, %v, cost:%d", s.name, sFuncName, err.Error(), msg.Cost()) + TLOG.Errorf("Invoke error: %s, %s, %v, cost:%d", s.name, msg.Req.SFuncName, err.Error(), msg.Cost()) if msg.Resp == nil { ReportStat(msg, StatSuccess, StatSuccess, StatFailed) } else if msg.Status == basef.TARSINVOKETIMEOUT { @@ -208,15 +212,12 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte, } else { ReportStat(msg, StatSuccess, StatSuccess, StatFailed) } - return err + return } - msg.End() - *resp = *msg.Resp ReportStat(msg, StatFailed, StatSuccess, StatSuccess) - return err } -func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.Duration) error { +func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.Duration) (err error) { adp, needCheck := s.manager.SelectAdapterProxy(msg) if adp == nil { return errors.New("no adapter Proxy selected:" + msg.Req.SServantName) @@ -239,19 +240,53 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time. atomic.AddInt32(&s.queueLen, 1) readCh := make(chan *requestf.ResponsePacket) adp.resp.Store(msg.Req.IRequestId, readCh) - defer func() { + var releaseFunc = func() { CheckPanic() atomic.AddInt32(&s.queueLen, -1) adp.resp.Delete(msg.Req.IRequestId) + } + defer func() { + if !msg.Async || err != nil { + releaseFunc() + } }() - if err := adp.Send(msg.Req); err != nil { + + if err = adp.Send(msg.Req); err != nil { adp.failAdd() return err } + if msg.Req.CPacketType == basef.TARSONEWAY { adp.successAdd() return nil } + + // async call rpc + if msg.Async { + go func() { + defer releaseFunc() + err := s.waitInvoke(msg, adp, timeout, needCheck) + s.manager.postInvoke() + msg.End() + s.reportStat(msg, err) + if msg.Status != basef.TARSINVOKETIMEOUT { + current.SetResponseContext(ctx, msg.Resp.Context) + current.SetResponseStatus(ctx, msg.Resp.Status) + } + if _, err := msg.Callback.Dispatch(ctx, msg.Req, msg.Resp, err); err != nil { + TLOG.Errorf("Callback error: %s, %s, %+v", s.name, msg.Req.SFuncName, err) + } + }() + return nil + } + + return s.waitInvoke(msg, adp, timeout, needCheck) +} + +func (s *ServantProxy) waitInvoke(msg *Message, adp *AdapterProxy, timeout time.Duration, needCheck bool) error { + ch, _ := adp.resp.Load(msg.Req.IRequestId) + readCh := ch.(chan *requestf.ResponsePacket) + select { case <-rtimer.After(timeout): msg.Status = basef.TARSINVOKETIMEOUT diff --git a/tars/setting.go b/tars/setting.go index 0eda0a64..28822799 100755 --- a/tars/setting.go +++ b/tars/setting.go @@ -85,8 +85,10 @@ const ( // communicator default ,update from remote config refreshEndpointInterval int = 60000 reportInterval int = 5000 + // SyncInvokeTimeout sync invoke timeout + SyncInvokeTimeout int = 3000 // AsyncInvokeTimeout async invoke timeout - AsyncInvokeTimeout int = 3000 + AsyncInvokeTimeout int = 5000 // check endpoint status every 1000 ms checkStatusInterval int = 1000