diff --git a/examples/OpentelemetryServer/go.mod b/examples/OpentelemetryServer/go.mod index f08418ae..0104882d 100644 --- a/examples/OpentelemetryServer/go.mod +++ b/examples/OpentelemetryServer/go.mod @@ -3,7 +3,7 @@ module OpentelemetryServer go 1.19 require ( - github.com/TarsCloud/TarsGo v1.3.10 + github.com/TarsCloud/TarsGo v1.4.4 github.com/TarsCloud/TarsGo/contrib/middleware/opentelemetry v0.0.0 github.com/prometheus/client_golang v1.14.0 go.opentelemetry.io/otel v1.15.0-rc.1 @@ -52,7 +52,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.14.0 // indirect go.opentelemetry.io/otel/metric v1.15.0-rc.1 // indirect go.opentelemetry.io/proto/otlp v0.19.0 // indirect - go.uber.org/automaxprocs v1.5.1 // indirect + go.uber.org/automaxprocs v1.5.2 // indirect golang.org/x/crypto v0.1.0 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/sys v0.5.0 // indirect diff --git a/examples/OpentelemetryServer/tracer/tracer.go b/examples/OpentelemetryServer/tracer/tracer.go new file mode 100644 index 00000000..b0e67c8e --- /dev/null +++ b/examples/OpentelemetryServer/tracer/tracer.go @@ -0,0 +1,92 @@ +package tracer + +import ( + "context" + "log" + "os" + "sync" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/exporters/jaeger" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" + "go.opentelemetry.io/otel/exporters/zipkin" + "go.opentelemetry.io/otel/propagation" + sdkresource "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" +) + +var resource *sdkresource.Resource +var initResourcesOnce sync.Once + +func initResource(serviceName string) *sdkresource.Resource { + initResourcesOnce.Do(func() { + extraResources, _ := sdkresource.New( + context.Background(), + sdkresource.WithOS(), + sdkresource.WithProcess(), + sdkresource.WithContainer(), + sdkresource.WithHost(), + sdkresource.WithFromEnv(), + sdkresource.WithAttributes(semconv.ServiceName(serviceName)), + ) + resource, _ = sdkresource.Merge( + sdkresource.Default(), + extraResources, + ) + }) + return resource +} + +func newOtlpExporter() (sdktrace.SpanExporter, error) { + ctx := context.Background() + return otlptracegrpc.New(ctx, otlptracegrpc.WithInsecure()) +} + +func newStdoutExporter() (sdktrace.SpanExporter, error) { + return stdouttrace.New( + // Use human-readable output. + stdouttrace.WithPrettyPrint(), + // Do not print timestamps for the demo. + stdouttrace.WithoutTimestamps(), + ) +} + +func newZipkinExporter(serviceNameKey string) (sdktrace.SpanExporter, error) { + url := "http://localhost:19411/api/v2/spans" + return zipkin.New(url, zipkin.WithLogger(log.New(os.Stderr, serviceNameKey, log.Ldate|log.Ltime|log.Llongfile))) +} + +func newJaegerExporter() (sdktrace.SpanExporter, error) { + url := "http://localhost:14268/api/traces" + return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(url))) +} + +func NewTracerProvider(serviceName, exporterTyp string) *sdktrace.TracerProvider { + var ( + exporter sdktrace.SpanExporter + err error + ) + switch exporterTyp { + case "stdout": + exporter, err = newStdoutExporter() + case "zipkin": + exporter, err = newZipkinExporter(serviceName) + case "jaeger": + exporter, err = newJaegerExporter() + default: // otlp + exporter, err = newOtlpExporter() + } + if err != nil { + log.Fatal(err) + } + tp := sdktrace.NewTracerProvider( + //sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(0.5))), // 控制采样 + sdktrace.WithBatcher(exporter), + sdktrace.WithResource(initResource(serviceName)), + ) + otel.SetTracerProvider(tp) + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) + return tp +} diff --git a/examples/TarsPushServer/client/DemoClient.go b/examples/TarsPushServer/client/DemoClient.go new file mode 100644 index 00000000..186bc155 --- /dev/null +++ b/examples/TarsPushServer/client/DemoClient.go @@ -0,0 +1,95 @@ +package main + +import ( + "TarPushServer/demo" + "fmt" + "github.com/TarsCloud/TarsGo/tars" + "time" +) + +var prx *demo.DemoObj + +type Callback struct { + start int64 + cost int64 + count int64 +} + +func (c *Callback) Reg_Callback(ret *demo.Result, rsp *demo.RegRsp, opt ...map[string]string) { + //TODO implement me + panic("implement me") +} + +func (c *Callback) Notify_Callback(ret *demo.Result, opt ...map[string]string) { + //TODO implement me + panic("implement me") +} + +func (c *Callback) Notify_ExceptionCallback(err error) { + //TODO implement me + panic("implement me") +} + +func (c *Callback) Push_Callback(msg *string, opt ...map[string]string) { + /* if c.count == 0 { + c.start = time.Now().UnixMicro() + } + c.count++ + if c.count == 500000 { + c.cost = time.Now().UnixMicro() - c.start + fmt.Printf("cost--->%vus\n", c.cost) + }*/ + + //if c.start == 0 { + // c.start = time.Now().UnixMicro() + //} else { + // c.cost = time.Now().UnixMicro() - c.start + //} + //fmt.Printf("cost--->%vus|Push---->:%s======%v\n", c.cost, *msg, opt) + fmt.Printf("%v|Push---->:%s======%v\n", time.Now().UnixMilli(), *msg, opt) +} + +func (c Callback) Push_ExceptionCallback(err error) { + panic("implement me") +} + +func (c Callback) Reg_ExceptionCallback(err error) { + //TODO implement me + panic("implement me") +} + +func TestReg() { + req := &demo.RegReq{Msg: "reg"} + rsp := &demo.RegRsp{} + prx.Reg(req, rsp) +} + +func main() { + com := tars.GetCommunicator() + obj := "Base.DemoServer.DemoObj@tcp -h 127.0.0.1 -p 8888 -t 60000" + prx = &demo.DemoObj{} + com.StringToProxy(obj, prx) + prx.SetOnConnectCallback(func(s string) { + fmt.Println("<-----------onConnect--------->") + TestReg() + }) + prx.SetOnCloseCallback(func(s string) { + fmt.Println("<-----------onClose----------->") + }) + TarsCb := new(Callback) + prx.SetTarsCallback(TarsCb) + go func() { + ticker := time.NewTicker(time.Second * 10) + defer ticker.Stop() + for { + select { + case <-ticker.C: + prx.TarsPing() + } + } + }() + prx.TarsPing() + for true { + time.Sleep(time.Second * 10) + } +} diff --git a/examples/TarsPushServer/demo.tars b/examples/TarsPushServer/demo.tars new file mode 100644 index 00000000..91279455 --- /dev/null +++ b/examples/TarsPushServer/demo.tars @@ -0,0 +1,30 @@ +module demo +{ + struct Result + { + 0 optional int code; + 1 optional string msg; + }; + + struct RegReq + { + 0 optional string msg; + }; + + struct RegRsp + { + 0 optional string msg; + }; + + struct Notify + { + 0 optional string msg; + }; + + interface DemoObj + { + void Push(out string msg); + Result Reg(RegReq req, out RegRsp rsp); + Result Notify(Notify notify); + }; +}; \ No newline at end of file diff --git a/examples/TarsPushServer/demo/DemoObj.tars.go b/examples/TarsPushServer/demo/DemoObj.tars.go new file mode 100644 index 00000000..f090c0e6 --- /dev/null +++ b/examples/TarsPushServer/demo/DemoObj.tars.go @@ -0,0 +1,1384 @@ +// Package demo comment +// This file was generated by tars2go 1.1.10 +// Generated from demo.tars +package demo + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "github.com/TarsCloud/TarsGo/tars" + m "github.com/TarsCloud/TarsGo/tars/model" + "github.com/TarsCloud/TarsGo/tars/protocol/codec" + "github.com/TarsCloud/TarsGo/tars/protocol/res/basef" + "github.com/TarsCloud/TarsGo/tars/protocol/res/requestf" + "github.com/TarsCloud/TarsGo/tars/protocol/tup" + "github.com/TarsCloud/TarsGo/tars/util/current" + "github.com/TarsCloud/TarsGo/tars/util/tools" + "github.com/TarsCloud/TarsGo/tars/util/trace" + "net" + "unsafe" +) + +// Reference imports to suppress errors if they are not otherwise used. +var ( + _ = fmt.Errorf + _ = codec.FromInt8 + _ = unsafe.Pointer(nil) + _ = bytes.ErrTooLarge +) + +// DemoObj struct +type DemoObj struct { + servant m.Servant +} + +// Push is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) Push(msg *string, opts ...map[string]string) (err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = buf.WriteString(*msg, 1) + if err != nil { + return err + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + tarsResp := new(requestf.ResponsePacket) + tarsCtx := context.Background() + + err = obj.servant.TarsInvoke(tarsCtx, 0, "Push", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return err + } + + readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer)) + err = readBuf.ReadString(&(*msg), 1, true) + if err != nil { + return err + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return nil +} + +// PushWithContext is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) PushWithContext(tarsCtx context.Context, msg *string, opts ...map[string]string) (err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = buf.WriteString(*msg, 1) + if err != nil { + return err + } + + traceData, ok := current.GetTraceData(tarsCtx) + if ok && traceData.TraceCall { + traceData.NewSpan() + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstCS, uint(buf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCS), trace.TraceAnnotationCS, tars.GetClientConfig().ModuleName, obj.servant.Name(), "Push", 0, traceParam, "") + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + + tarsResp := new(requestf.ResponsePacket) + err = obj.servant.TarsInvoke(tarsCtx, 0, "Push", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return err + } + + readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer)) + err = readBuf.ReadString(&(*msg), 1, true) + if err != nil { + return err + } + + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstCR, uint(readBuf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value["msg"] = *msg + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCR), trace.TraceAnnotationCR, tars.GetClientConfig().ModuleName, obj.servant.Name(), "Push", int(tarsResp.IRet), traceParam, "") + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return nil +} + +// PushOneWayWithContext is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) PushOneWayWithContext(tarsCtx context.Context, msg *string, opts ...map[string]string) (err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = buf.WriteString(*msg, 1) + if err != nil { + return err + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + + tarsResp := new(requestf.ResponsePacket) + err = obj.servant.TarsInvoke(tarsCtx, 1, "Push", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return err + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return nil +} + +// Reg is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) Reg(req *RegReq, rsp *RegRsp, opts ...map[string]string) (ret Result, err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = req.WriteBlock(buf, 1) + if err != nil { + return ret, err + } + + err = (*rsp).WriteBlock(buf, 2) + if err != nil { + return ret, err + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + tarsResp := new(requestf.ResponsePacket) + tarsCtx := context.Background() + + err = obj.servant.TarsInvoke(tarsCtx, 0, "Reg", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return ret, err + } + + readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer)) + err = ret.ReadBlock(readBuf, 0, true) + if err != nil { + return ret, err + } + + err = (*rsp).ReadBlock(readBuf, 2, true) + if err != nil { + return ret, err + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return ret, nil +} + +// RegWithContext is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) RegWithContext(tarsCtx context.Context, req *RegReq, rsp *RegRsp, opts ...map[string]string) (ret Result, err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = req.WriteBlock(buf, 1) + if err != nil { + return ret, err + } + + err = (*rsp).WriteBlock(buf, 2) + if err != nil { + return ret, err + } + + traceData, ok := current.GetTraceData(tarsCtx) + if ok && traceData.TraceCall { + traceData.NewSpan() + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstCS, uint(buf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value["req"] = req + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCS), trace.TraceAnnotationCS, tars.GetClientConfig().ModuleName, obj.servant.Name(), "Reg", 0, traceParam, "") + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + + tarsResp := new(requestf.ResponsePacket) + err = obj.servant.TarsInvoke(tarsCtx, 0, "Reg", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return ret, err + } + + readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer)) + err = ret.ReadBlock(readBuf, 0, true) + if err != nil { + return ret, err + } + + err = (*rsp).ReadBlock(readBuf, 2, true) + if err != nil { + return ret, err + } + + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstCR, uint(readBuf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value[""] = ret + value["rsp"] = *rsp + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCR), trace.TraceAnnotationCR, tars.GetClientConfig().ModuleName, obj.servant.Name(), "Reg", int(tarsResp.IRet), traceParam, "") + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return ret, nil +} + +// RegOneWayWithContext is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) RegOneWayWithContext(tarsCtx context.Context, req *RegReq, rsp *RegRsp, opts ...map[string]string) (ret Result, err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = req.WriteBlock(buf, 1) + if err != nil { + return ret, err + } + + err = (*rsp).WriteBlock(buf, 2) + if err != nil { + return ret, err + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + + tarsResp := new(requestf.ResponsePacket) + err = obj.servant.TarsInvoke(tarsCtx, 1, "Reg", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return ret, err + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return ret, nil +} + +// Notify is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) Notify(notify *Notify, opts ...map[string]string) (ret Result, err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = notify.WriteBlock(buf, 1) + if err != nil { + return ret, err + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + tarsResp := new(requestf.ResponsePacket) + tarsCtx := context.Background() + + err = obj.servant.TarsInvoke(tarsCtx, 0, "Notify", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return ret, err + } + + readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer)) + err = ret.ReadBlock(readBuf, 0, true) + if err != nil { + return ret, err + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return ret, nil +} + +// NotifyWithContext is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) NotifyWithContext(tarsCtx context.Context, notify *Notify, opts ...map[string]string) (ret Result, err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = notify.WriteBlock(buf, 1) + if err != nil { + return ret, err + } + + traceData, ok := current.GetTraceData(tarsCtx) + if ok && traceData.TraceCall { + traceData.NewSpan() + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstCS, uint(buf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value["notify"] = notify + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCS), trace.TraceAnnotationCS, tars.GetClientConfig().ModuleName, obj.servant.Name(), "Notify", 0, traceParam, "") + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + + tarsResp := new(requestf.ResponsePacket) + err = obj.servant.TarsInvoke(tarsCtx, 0, "Notify", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return ret, err + } + + readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer)) + err = ret.ReadBlock(readBuf, 0, true) + if err != nil { + return ret, err + } + + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstCR, uint(readBuf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value[""] = ret + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCR), trace.TraceAnnotationCR, tars.GetClientConfig().ModuleName, obj.servant.Name(), "Notify", int(tarsResp.IRet), traceParam, "") + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return ret, nil +} + +// NotifyOneWayWithContext is the proxy function for the method defined in the tars file, with the context +func (obj *DemoObj) NotifyOneWayWithContext(tarsCtx context.Context, notify *Notify, opts ...map[string]string) (ret Result, err error) { + var ( + length int32 + have bool + ty byte + ) + buf := codec.NewBuffer() + err = notify.WriteBlock(buf, 1) + if err != nil { + return ret, err + } + + var statusMap map[string]string + var contextMap map[string]string + if len(opts) == 1 { + contextMap = opts[0] + } else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] + } + + tarsResp := new(requestf.ResponsePacket) + err = obj.servant.TarsInvoke(tarsCtx, 1, "Notify", buf.ToBytes(), statusMap, contextMap, tarsResp) + if err != nil { + return ret, err + } + + if len(opts) == 1 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + } else if len(opts) == 2 { + for k := range contextMap { + delete(contextMap, k) + } + for k, v := range tarsResp.Context { + contextMap[k] = v + } + for k := range statusMap { + delete(statusMap, k) + } + for k, v := range tarsResp.Status { + statusMap[k] = v + } + } + _ = length + _ = have + _ = ty + return ret, nil +} + +// SetServant sets servant for the service. +func (obj *DemoObj) SetServant(servant m.Servant) { + obj.servant = servant +} + +// GetServant gets servant for the service. +func (obj *DemoObj) GetServant() (servant *m.Servant) { + return &obj.servant +} + +// SetOnConnectCallback +func (obj *DemoObj) SetOnConnectCallback(callback func(string)) { + obj.servant.SetOnConnectCallback(callback) +} + +// SetOnCloseCallback +func (obj *DemoObj) SetOnCloseCallback(callback func(string)) { + obj.servant.SetOnCloseCallback(callback) +} + +// SetTarsCallback +func (obj *DemoObj) SetTarsCallback(callback DemoObjTarsCallback) { + var push DemoObjPushCallback + push.Cb = callback + obj.servant.SetTarsCallback(&push) +} + +// SetPushCallback +func (obj *DemoObj) SetPushCallback(callback func([]byte)) { + obj.servant.SetPushCallback(callback) +} +func (obj *DemoObj) req2Byte(rsp *requestf.ResponsePacket) []byte { + req := requestf.RequestPacket{} + req.IVersion = rsp.IVersion + req.IRequestId = rsp.IRequestId + req.IMessageType = rsp.IMessageType + req.CPacketType = rsp.CPacketType + req.Context = rsp.Context + req.Status = rsp.Status + req.SBuffer = rsp.SBuffer + + os := codec.NewBuffer() + req.WriteTo(os) + bs := os.ToBytes() + sbuf := bytes.NewBuffer(nil) + sbuf.Write(make([]byte, 4)) + sbuf.Write(bs) + length := sbuf.Len() + binary.BigEndian.PutUint32(sbuf.Bytes(), uint32(length)) + return sbuf.Bytes() +} + +func (obj *DemoObj) rsp2Byte(rsp *requestf.ResponsePacket) []byte { + if rsp.IVersion == basef.TUPVERSION { + return obj.req2Byte(rsp) + } + os := codec.NewBuffer() + rsp.WriteTo(os) + bs := os.ToBytes() + sbuf := bytes.NewBuffer(nil) + sbuf.Write(make([]byte, 4)) + sbuf.Write(bs) + length := sbuf.Len() + binary.BigEndian.PutUint32(sbuf.Bytes(), uint32(length)) + return sbuf.Bytes() +} + +// TarsPing +func (obj *DemoObj) TarsPing() { + ctx := context.Background() + obj.servant.TarsPing(ctx) +} + +// TarsSetTimeout sets the timeout for the servant which is in ms. +func (obj *DemoObj) TarsSetTimeout(timeout int) { + obj.servant.TarsSetTimeout(timeout) +} + +// TarsSetProtocol sets the protocol for the servant. +func (obj *DemoObj) TarsSetProtocol(p m.Protocol) { + obj.servant.TarsSetProtocol(p) +} + +// AddServant adds servant for the service. +func (obj *DemoObj) AddServant(imp DemoObjServant, servantObj string) { + tars.AddServant(obj, imp, servantObj) +} + +// AddServantWithContext adds servant for the service with context. +func (obj *DemoObj) AddServantWithContext(imp DemoObjServantWithContext, servantObj string) { + tars.AddServantWithContext(obj, imp, servantObj) +} + +type DemoObjServant interface { + Push(msg *string) (err error) + Reg(req *RegReq, rsp *RegRsp) (ret Result, err error) + Notify(notify *Notify) (ret Result, err error) +} +type DemoObjServantWithContext interface { + Push(tarsCtx context.Context, msg *string) (err error) + Reg(tarsCtx context.Context, req *RegReq, rsp *RegRsp) (ret Result, err error) + Notify(tarsCtx context.Context, notify *Notify) (ret Result, err error) +} + +// Dispatch is used to call the server side implement for the method defined in the tars file. withContext shows using context or not. +func (obj *DemoObj) Dispatch(tarsCtx context.Context, val interface{}, tarsReq *requestf.RequestPacket, tarsResp *requestf.ResponsePacket, withContext bool) (err error) { + var ( + length int32 + have bool + ty byte + ) + readBuf := codec.NewReader(tools.Int8ToByte(tarsReq.SBuffer)) + buf := codec.NewBuffer() + switch tarsReq.SFuncName { + case "Push": + var msg string + + traceData, ok := current.GetTraceData(tarsCtx) + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstSR, uint(readBuf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstSR), trace.TraceAnnotationSR, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "Push", 0, traceParam, "") + } + + if !withContext { + imp := val.(DemoObjServant) + err = imp.Push(&msg) + } else { + imp := val.(DemoObjServantWithContext) + err = imp.Push(tarsCtx, &msg) + } + + if err != nil { + return err + } + + if tarsReq.IVersion == basef.TARSVERSION { + buf.Reset() + + err = buf.WriteString(msg, 1) + if err != nil { + return err + } + + } else if tarsReq.IVersion == basef.TUPVERSION { + rspTup := tup.NewUniAttribute() + + buf.Reset() + err = buf.WriteString(msg, 0) + if err != nil { + return err + } + + rspTup.PutBuffer("msg", buf.ToBytes()) + + buf.Reset() + err = rspTup.Encode(buf) + if err != nil { + return err + } + } else if tarsReq.IVersion == basef.JSONVERSION { + rspJson := map[string]interface{}{} + rspJson["msg"] = msg + + var rspByte []byte + if rspByte, err = json.Marshal(rspJson); err != nil { + return err + } + + buf.Reset() + err = buf.WriteSliceUint8(rspByte) + if err != nil { + return err + } + } + + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstSS, uint(buf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value["msg"] = msg + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstSS), trace.TraceAnnotationSS, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "Push", 0, traceParam, "") + } + + case "Reg": + var req RegReq + var rsp RegRsp + + if tarsReq.IVersion == basef.TARSVERSION { + + err = req.ReadBlock(readBuf, 1, true) + if err != nil { + return err + } + + } else if tarsReq.IVersion == basef.TUPVERSION { + reqTup := tup.NewUniAttribute() + reqTup.Decode(readBuf) + + var tupBuffer []byte + + reqTup.GetBuffer("req", &tupBuffer) + readBuf.Reset(tupBuffer) + err = req.ReadBlock(readBuf, 0, true) + if err != nil { + return err + } + + } else if tarsReq.IVersion == basef.JSONVERSION { + var jsonData map[string]interface{} + decoder := json.NewDecoder(bytes.NewReader(readBuf.ToBytes())) + decoder.UseNumber() + err = decoder.Decode(&jsonData) + if err != nil { + return fmt.Errorf("decode reqpacket failed, error: %+v", err) + } + { + jsonStr, _ := json.Marshal(jsonData["req"]) + req.ResetDefault() + if err = json.Unmarshal(jsonStr, &req); err != nil { + return err + } + } + + } else { + err = fmt.Errorf("decode reqpacket fail, error version: %d", tarsReq.IVersion) + return err + } + + traceData, ok := current.GetTraceData(tarsCtx) + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstSR, uint(readBuf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value["req"] = req + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstSR), trace.TraceAnnotationSR, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "Reg", 0, traceParam, "") + } + + var funRet Result + if !withContext { + imp := val.(DemoObjServant) + funRet, err = imp.Reg(&req, &rsp) + } else { + imp := val.(DemoObjServantWithContext) + funRet, err = imp.Reg(tarsCtx, &req, &rsp) + } + + if err != nil { + return err + } + + if tarsReq.IVersion == basef.TARSVERSION { + buf.Reset() + + err = funRet.WriteBlock(buf, 0) + if err != nil { + return err + } + + err = rsp.WriteBlock(buf, 2) + if err != nil { + return err + } + + } else if tarsReq.IVersion == basef.TUPVERSION { + rspTup := tup.NewUniAttribute() + + err = funRet.WriteBlock(buf, 0) + if err != nil { + return err + } + + rspTup.PutBuffer("", buf.ToBytes()) + rspTup.PutBuffer("tars_ret", buf.ToBytes()) + + buf.Reset() + err = rsp.WriteBlock(buf, 0) + if err != nil { + return err + } + + rspTup.PutBuffer("rsp", buf.ToBytes()) + + buf.Reset() + err = rspTup.Encode(buf) + if err != nil { + return err + } + } else if tarsReq.IVersion == basef.JSONVERSION { + rspJson := map[string]interface{}{} + rspJson["tars_ret"] = funRet + rspJson["rsp"] = rsp + + var rspByte []byte + if rspByte, err = json.Marshal(rspJson); err != nil { + return err + } + + buf.Reset() + err = buf.WriteSliceUint8(rspByte) + if err != nil { + return err + } + } + + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstSS, uint(buf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value[""] = funRet + value["rsp"] = rsp + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstSS), trace.TraceAnnotationSS, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "Reg", 0, traceParam, "") + } + + case "Notify": + var notify Notify + + if tarsReq.IVersion == basef.TARSVERSION { + + err = notify.ReadBlock(readBuf, 1, true) + if err != nil { + return err + } + + } else if tarsReq.IVersion == basef.TUPVERSION { + reqTup := tup.NewUniAttribute() + reqTup.Decode(readBuf) + + var tupBuffer []byte + + reqTup.GetBuffer("notify", &tupBuffer) + readBuf.Reset(tupBuffer) + err = notify.ReadBlock(readBuf, 0, true) + if err != nil { + return err + } + + } else if tarsReq.IVersion == basef.JSONVERSION { + var jsonData map[string]interface{} + decoder := json.NewDecoder(bytes.NewReader(readBuf.ToBytes())) + decoder.UseNumber() + err = decoder.Decode(&jsonData) + if err != nil { + return fmt.Errorf("decode reqpacket failed, error: %+v", err) + } + { + jsonStr, _ := json.Marshal(jsonData["notify"]) + notify.ResetDefault() + if err = json.Unmarshal(jsonStr, ¬ify); err != nil { + return err + } + } + + } else { + err = fmt.Errorf("decode reqpacket fail, error version: %d", tarsReq.IVersion) + return err + } + + traceData, ok := current.GetTraceData(tarsCtx) + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstSR, uint(readBuf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value["notify"] = notify + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstSR), trace.TraceAnnotationSR, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "Notify", 0, traceParam, "") + } + + var funRet Result + if !withContext { + imp := val.(DemoObjServant) + funRet, err = imp.Notify(¬ify) + } else { + imp := val.(DemoObjServantWithContext) + funRet, err = imp.Notify(tarsCtx, ¬ify) + } + + if err != nil { + return err + } + + if tarsReq.IVersion == basef.TARSVERSION { + buf.Reset() + + err = funRet.WriteBlock(buf, 0) + if err != nil { + return err + } + + } else if tarsReq.IVersion == basef.TUPVERSION { + rspTup := tup.NewUniAttribute() + + err = funRet.WriteBlock(buf, 0) + if err != nil { + return err + } + + rspTup.PutBuffer("", buf.ToBytes()) + rspTup.PutBuffer("tars_ret", buf.ToBytes()) + + buf.Reset() + err = rspTup.Encode(buf) + if err != nil { + return err + } + } else if tarsReq.IVersion == basef.JSONVERSION { + rspJson := map[string]interface{}{} + rspJson["tars_ret"] = funRet + + var rspByte []byte + if rspByte, err = json.Marshal(rspJson); err != nil { + return err + } + + buf.Reset() + err = buf.WriteSliceUint8(rspByte) + if err != nil { + return err + } + } + + if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstSS, uint(buf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} + value[""] = funRet + p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstSS), trace.TraceAnnotationSS, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "Notify", 0, traceParam, "") + } + + default: + return fmt.Errorf("func mismatch") + } + var statusMap map[string]string + if status, ok := current.GetResponseStatus(tarsCtx); ok && status != nil { + statusMap = status + } + var contextMap map[string]string + if ctx, ok := current.GetResponseContext(tarsCtx); ok && ctx != nil { + contextMap = ctx + } + *tarsResp = requestf.ResponsePacket{ + IVersion: tarsReq.IVersion, + CPacketType: 0, + IRequestId: tarsReq.IRequestId, + IMessageType: 0, + IRet: 0, + SBuffer: tools.ByteToInt8(buf.ToBytes()), + Status: statusMap, + SResultDesc: "", + Context: contextMap, + } + + _ = readBuf + _ = buf + _ = length + _ = have + _ = ty + return nil +} + +type DemoObjTarsCallback interface { + Push_Callback(msg *string, opt ...map[string]string) + Push_ExceptionCallback(err error) + Reg_Callback(ret *Result, rsp *RegRsp, opt ...map[string]string) + Reg_ExceptionCallback(err error) + Notify_Callback(ret *Result, opt ...map[string]string) + Notify_ExceptionCallback(err error) +} + +// DemoObjPushCallback struct +type DemoObjPushCallback struct { + Cb DemoObjTarsCallback +} + +func (cb *DemoObjPushCallback) Ondispatch(resp *requestf.ResponsePacket) { + switch resp.SResultDesc { + case "Push": + err := func() error { + var err error + readBuf := codec.NewReader(tools.Int8ToByte(resp.SBuffer)) + var msg = new(string) + + err = readBuf.ReadString(&(*msg), 1, true) + if err != nil { + return err + } + + if resp.Context != nil { + cb.Cb.Push_Callback(msg, resp.Context) + return nil + } else { + cb.Cb.Push_Callback(msg) + return nil + } + }() + if err != nil { + cb.Cb.Push_ExceptionCallback(err) + } + case "Reg": + err := func() error { + var err error + readBuf := codec.NewReader(tools.Int8ToByte(resp.SBuffer)) + var ret = new(Result) + err = ret.ReadBlock(readBuf, 0, true) + if err != nil { + return err + } + + var rsp = new(RegRsp) + + err = (*rsp).ReadBlock(readBuf, 2, true) + if err != nil { + return err + } + + if resp.Context != nil { + cb.Cb.Reg_Callback(ret, rsp, resp.Context) + return nil + } else { + cb.Cb.Reg_Callback(ret, rsp) + return nil + } + }() + if err != nil { + cb.Cb.Reg_ExceptionCallback(err) + } + case "Notify": + err := func() error { + var err error + readBuf := codec.NewReader(tools.Int8ToByte(resp.SBuffer)) + var ret = new(Result) + err = ret.ReadBlock(readBuf, 0, true) + if err != nil { + return err + } + + if resp.Context != nil { + cb.Cb.Notify_Callback(ret, resp.Context) + return nil + } else { + cb.Cb.Notify_Callback(ret) + return nil + } + }() + if err != nil { + cb.Cb.Notify_ExceptionCallback(err) + } + } +} +func (obj *DemoObj) AsyncSendResponse_Push(ctx context.Context, msg *string, opt ...map[string]string) (err error) { + + conn, udpAddr, ok := current.GetRawConn(ctx) + if !ok { + return fmt.Errorf("connection not found") + } + buf := codec.NewBuffer() + + err = buf.WriteString(*msg, 1) + if err != nil { + return err + } + + resp := &requestf.ResponsePacket{ + SBuffer: tools.ByteToInt8(buf.ToBytes()), + } + resp.IVersion = basef.TARSVERSION + if resp.Status == nil { + resp.Status = make(map[string]string) + } + resp.Status["TARS_FUNC"] = "Push" + resp.SResultDesc = "Push" + if len(opt) > 0 { + if opt[0] != nil { + resp.Context = opt[0] + } + } + rspData := obj.rsp2Byte(resp) + if udpAddr != nil { + udpConn, _ := conn.(*net.UDPConn) + _, err = udpConn.WriteToUDP(rspData, udpAddr) + } else { + _, err = conn.Write(rspData) + } + return err +} +func (obj *DemoObj) AsyncSendResponse_Reg(ctx context.Context, ret *Result, rsp *RegRsp, opt ...map[string]string) (err error) { + + conn, udpAddr, ok := current.GetRawConn(ctx) + if !ok { + return fmt.Errorf("connection not found") + } + buf := codec.NewBuffer() + + err = ret.WriteBlock(buf, 0) + if err != nil { + return err + } + + err = (*rsp).WriteBlock(buf, 2) + if err != nil { + return err + } + + resp := &requestf.ResponsePacket{ + SBuffer: tools.ByteToInt8(buf.ToBytes()), + } + resp.IVersion = basef.TARSVERSION + if resp.Status == nil { + resp.Status = make(map[string]string) + } + resp.Status["TARS_FUNC"] = "Reg" + resp.SResultDesc = "Reg" + if len(opt) > 0 { + if opt[0] != nil { + resp.Context = opt[0] + } + } + rspData := obj.rsp2Byte(resp) + if udpAddr != nil { + udpConn, _ := conn.(*net.UDPConn) + _, err = udpConn.WriteToUDP(rspData, udpAddr) + } else { + _, err = conn.Write(rspData) + } + return err +} +func (obj *DemoObj) AsyncSendResponse_Notify(ctx context.Context, ret *Result, opt ...map[string]string) (err error) { + + conn, udpAddr, ok := current.GetRawConn(ctx) + if !ok { + return fmt.Errorf("connection not found") + } + buf := codec.NewBuffer() + + err = ret.WriteBlock(buf, 0) + if err != nil { + return err + } + + resp := &requestf.ResponsePacket{ + SBuffer: tools.ByteToInt8(buf.ToBytes()), + } + resp.IVersion = basef.TARSVERSION + if resp.Status == nil { + resp.Status = make(map[string]string) + } + resp.Status["TARS_FUNC"] = "Notify" + resp.SResultDesc = "Notify" + if len(opt) > 0 { + if opt[0] != nil { + resp.Context = opt[0] + } + } + rspData := obj.rsp2Byte(resp) + if udpAddr != nil { + udpConn, _ := conn.(*net.UDPConn) + _, err = udpConn.WriteToUDP(rspData, udpAddr) + } else { + _, err = conn.Write(rspData) + } + return err +} diff --git a/examples/TarsPushServer/demo/demo.go b/examples/TarsPushServer/demo/demo.go new file mode 100644 index 00000000..0a076ec2 --- /dev/null +++ b/examples/TarsPushServer/demo/demo.go @@ -0,0 +1,397 @@ +// Package demo comment +// This file was generated by tars2go 1.1.10 +// Generated from demo.tars +package demo + +import ( + "fmt" + + "github.com/TarsCloud/TarsGo/tars/protocol/codec" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = fmt.Errorf +var _ = codec.FromInt8 + +// Result struct implement +type Result struct { + Code int32 `json:"code"` + Msg string `json:"msg"` +} + +func (st *Result) ResetDefault() { +} + +// ReadFrom reads from readBuf and put into struct. +func (st *Result) ReadFrom(readBuf *codec.Reader) error { + var ( + err error + length int32 + have bool + ty byte + ) + st.ResetDefault() + + err = readBuf.ReadInt32(&st.Code, 0, false) + if err != nil { + return err + } + + err = readBuf.ReadString(&st.Msg, 1, false) + if err != nil { + return err + } + + _ = err + _ = length + _ = have + _ = ty + return nil +} + +// ReadBlock reads struct from the given tag , require or optional. +func (st *Result) ReadBlock(readBuf *codec.Reader, tag byte, require bool) error { + var ( + err error + have bool + ) + st.ResetDefault() + + have, err = readBuf.SkipTo(codec.StructBegin, tag, require) + if err != nil { + return err + } + if !have { + if require { + return fmt.Errorf("require Result, but not exist. tag %d", tag) + } + return nil + } + + err = st.ReadFrom(readBuf) + if err != nil { + return err + } + + err = readBuf.SkipToStructEnd() + if err != nil { + return err + } + _ = have + return nil +} + +// WriteTo encode struct to buffer +func (st *Result) WriteTo(buf *codec.Buffer) (err error) { + + err = buf.WriteInt32(st.Code, 0) + if err != nil { + return err + } + + err = buf.WriteString(st.Msg, 1) + if err != nil { + return err + } + + return err +} + +// WriteBlock encode struct +func (st *Result) WriteBlock(buf *codec.Buffer, tag byte) error { + var err error + err = buf.WriteHead(codec.StructBegin, tag) + if err != nil { + return err + } + + err = st.WriteTo(buf) + if err != nil { + return err + } + + err = buf.WriteHead(codec.StructEnd, 0) + if err != nil { + return err + } + return nil +} + +// RegReq struct implement +type RegReq struct { + Msg string `json:"msg"` +} + +func (st *RegReq) ResetDefault() { +} + +// ReadFrom reads from readBuf and put into struct. +func (st *RegReq) ReadFrom(readBuf *codec.Reader) error { + var ( + err error + length int32 + have bool + ty byte + ) + st.ResetDefault() + + err = readBuf.ReadString(&st.Msg, 0, false) + if err != nil { + return err + } + + _ = err + _ = length + _ = have + _ = ty + return nil +} + +// ReadBlock reads struct from the given tag , require or optional. +func (st *RegReq) ReadBlock(readBuf *codec.Reader, tag byte, require bool) error { + var ( + err error + have bool + ) + st.ResetDefault() + + have, err = readBuf.SkipTo(codec.StructBegin, tag, require) + if err != nil { + return err + } + if !have { + if require { + return fmt.Errorf("require RegReq, but not exist. tag %d", tag) + } + return nil + } + + err = st.ReadFrom(readBuf) + if err != nil { + return err + } + + err = readBuf.SkipToStructEnd() + if err != nil { + return err + } + _ = have + return nil +} + +// WriteTo encode struct to buffer +func (st *RegReq) WriteTo(buf *codec.Buffer) (err error) { + + err = buf.WriteString(st.Msg, 0) + if err != nil { + return err + } + + return err +} + +// WriteBlock encode struct +func (st *RegReq) WriteBlock(buf *codec.Buffer, tag byte) error { + var err error + err = buf.WriteHead(codec.StructBegin, tag) + if err != nil { + return err + } + + err = st.WriteTo(buf) + if err != nil { + return err + } + + err = buf.WriteHead(codec.StructEnd, 0) + if err != nil { + return err + } + return nil +} + +// RegRsp struct implement +type RegRsp struct { + Msg string `json:"msg"` +} + +func (st *RegRsp) ResetDefault() { +} + +// ReadFrom reads from readBuf and put into struct. +func (st *RegRsp) ReadFrom(readBuf *codec.Reader) error { + var ( + err error + length int32 + have bool + ty byte + ) + st.ResetDefault() + + err = readBuf.ReadString(&st.Msg, 0, false) + if err != nil { + return err + } + + _ = err + _ = length + _ = have + _ = ty + return nil +} + +// ReadBlock reads struct from the given tag , require or optional. +func (st *RegRsp) ReadBlock(readBuf *codec.Reader, tag byte, require bool) error { + var ( + err error + have bool + ) + st.ResetDefault() + + have, err = readBuf.SkipTo(codec.StructBegin, tag, require) + if err != nil { + return err + } + if !have { + if require { + return fmt.Errorf("require RegRsp, but not exist. tag %d", tag) + } + return nil + } + + err = st.ReadFrom(readBuf) + if err != nil { + return err + } + + err = readBuf.SkipToStructEnd() + if err != nil { + return err + } + _ = have + return nil +} + +// WriteTo encode struct to buffer +func (st *RegRsp) WriteTo(buf *codec.Buffer) (err error) { + + err = buf.WriteString(st.Msg, 0) + if err != nil { + return err + } + + return err +} + +// WriteBlock encode struct +func (st *RegRsp) WriteBlock(buf *codec.Buffer, tag byte) error { + var err error + err = buf.WriteHead(codec.StructBegin, tag) + if err != nil { + return err + } + + err = st.WriteTo(buf) + if err != nil { + return err + } + + err = buf.WriteHead(codec.StructEnd, 0) + if err != nil { + return err + } + return nil +} + +// Notify struct implement +type Notify struct { + Msg string `json:"msg"` +} + +func (st *Notify) ResetDefault() { +} + +// ReadFrom reads from readBuf and put into struct. +func (st *Notify) ReadFrom(readBuf *codec.Reader) error { + var ( + err error + length int32 + have bool + ty byte + ) + st.ResetDefault() + + err = readBuf.ReadString(&st.Msg, 0, false) + if err != nil { + return err + } + + _ = err + _ = length + _ = have + _ = ty + return nil +} + +// ReadBlock reads struct from the given tag , require or optional. +func (st *Notify) ReadBlock(readBuf *codec.Reader, tag byte, require bool) error { + var ( + err error + have bool + ) + st.ResetDefault() + + have, err = readBuf.SkipTo(codec.StructBegin, tag, require) + if err != nil { + return err + } + if !have { + if require { + return fmt.Errorf("require Notify, but not exist. tag %d", tag) + } + return nil + } + + err = st.ReadFrom(readBuf) + if err != nil { + return err + } + + err = readBuf.SkipToStructEnd() + if err != nil { + return err + } + _ = have + return nil +} + +// WriteTo encode struct to buffer +func (st *Notify) WriteTo(buf *codec.Buffer) (err error) { + + err = buf.WriteString(st.Msg, 0) + if err != nil { + return err + } + + return err +} + +// WriteBlock encode struct +func (st *Notify) WriteBlock(buf *codec.Buffer, tag byte) error { + var err error + err = buf.WriteHead(codec.StructBegin, tag) + if err != nil { + return err + } + + err = st.WriteTo(buf) + if err != nil { + return err + } + + err = buf.WriteHead(codec.StructEnd, 0) + if err != nil { + return err + } + return nil +} diff --git a/examples/TarsPushServer/server/DemoServer.go b/examples/TarsPushServer/server/DemoServer.go new file mode 100644 index 00000000..e4a7a058 --- /dev/null +++ b/examples/TarsPushServer/server/DemoServer.go @@ -0,0 +1,29 @@ +package main + +import ( + "TarPushServer/server/Impl" + "flag" + "fmt" + "github.com/TarsCloud/TarsGo/tars" +) + +func main() { + SrvConfig := "" // tars服务配置 + + flag.StringVar(&SrvConfig, "config", "", "path to server config file") + flag.Parse() + + fmt.Println("srv_conf: ", SrvConfig) + + // 这里赋值为了避免tars解析命令行参数导致二次解析报错 + tars.ServerConfigPath = SrvConfig + + cfg := tars.GetServerConfig() + imp := &Impl.DemoImp{} + Impl.GetApp().AddServantWithContext(imp, cfg.App+"."+cfg.Server+".DemoObj") + + //klog.Init(cfg.App, cfg.Server, cfg.LogPath, int(cfg.LogNum)) + + //klog.CONSOLE.Info("server running...") + tars.Run() +} diff --git a/examples/TarsPushServer/server/Impl/DemoImp.go b/examples/TarsPushServer/server/Impl/DemoImp.go new file mode 100644 index 00000000..d1624312 --- /dev/null +++ b/examples/TarsPushServer/server/Impl/DemoImp.go @@ -0,0 +1,58 @@ +package Impl + +import ( + "TarPushServer/demo" + "context" + "fmt" + "github.com/TarsCloud/TarsGo/tars/util/current" + "strconv" +) + +var app = &demo.DemoObj{} + +type DemoImp struct { +} + +func (d DemoImp) Notify(tarsCtx context.Context, notify *demo.Notify) (ret demo.Result, err error) { + //TODO implement me + panic("implement me") +} + +func GetApp() *demo.DemoObj { + return app +} +func (d DemoImp) Reg(tarsCtx context.Context, req *demo.RegReq, rsp *demo.RegRsp) (ret demo.Result, err error) { + rsp.Msg = req.Msg + go func() { + for i := 0; i < 10; i++ { + msg := fmt.Sprintf("push msg %d", i) + context := make(map[string]string, 1) + context["msg"] = "******" + strconv.Itoa(i) + uuid, _ := current.GetUUID(tarsCtx) + context["uuid"] = uuid + GetApp().AsyncSendResponse_Push(tarsCtx, &msg, context) + } + }() + return demo.Result{}, nil +} + +func (d DemoImp) Push(ctx context.Context, msg *string) (err error) { + return nil +} + +func (d DemoImp) Invoke(ctx context.Context, pkg []byte) []byte { + //TODO implement me + fmt.Println("implement me") + return []byte{} +} + +func (d DemoImp) GetCloseMsg() []byte { + //TODO implement me + fmt.Println("implement me") + return nil +} + +func (d DemoImp) DoClose(ctx context.Context) { + //TODO implement me + fmt.Println("implement me") +} diff --git a/examples/TarsPushServer/server/conf/Base.DemoServer.servant.conf b/examples/TarsPushServer/server/conf/Base.DemoServer.servant.conf new file mode 100644 index 00000000..9cec252b --- /dev/null +++ b/examples/TarsPushServer/server/conf/Base.DemoServer.servant.conf @@ -0,0 +1,37 @@ + + + enableset=n + setdivision=NULL + + app=Base + server=DemoServer + basepath=./ + datapath=./ + logpath=./ + logsize=10M + lognum=10 + logLevel=DEBUG + + allow + endpoint=tcp -h 0.0.0.0 -p 8888 -t 60000 + maxconns=100000 + protocol=tars + queuecap=500000 + queuetimeout=20000 + servant=Base.DemoServer.DemoObj + threads=1 + + + + locator=tars.tarsregistry.QueryObj@tcp -h 10.253.50.57 -p 17890 + sync-invoke-timeout=3000 + async-invoke-timeout=5000 + refresh-endpoint-interval=60000 + stat=tars.tarsstat.StatObj + property=tars.tarsproperty.PropertyObj + report-interval=60000 + asyncthread=3 + modulename=Base.DemoServer + + + diff --git a/examples/trace/TarsTraceBackServer/go.sum b/examples/trace/TarsTraceBackServer/go.sum index 06bfcefb..701cc061 100644 --- a/examples/trace/TarsTraceBackServer/go.sum +++ b/examples/trace/TarsTraceBackServer/go.sum @@ -164,6 +164,7 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= @@ -176,6 +177,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.uber.org/automaxprocs v1.5.1 h1:e1YG66Lrk73dn4qhg8WFSvhF0JuFQF0ERIp4rpuV8Qk= go.uber.org/automaxprocs v1.5.1/go.mod h1:BF4eumQw0P9GtnuxxovUd06vwm1o18oMzFtK66vU6XU= +go.uber.org/automaxprocs v1.5.2/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/examples/trace/TarsTraceFrontServer/go.sum b/examples/trace/TarsTraceFrontServer/go.sum index 06bfcefb..701cc061 100644 --- a/examples/trace/TarsTraceFrontServer/go.sum +++ b/examples/trace/TarsTraceFrontServer/go.sum @@ -164,6 +164,7 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= @@ -176,6 +177,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.uber.org/automaxprocs v1.5.1 h1:e1YG66Lrk73dn4qhg8WFSvhF0JuFQF0ERIp4rpuV8Qk= go.uber.org/automaxprocs v1.5.1/go.mod h1:BF4eumQw0P9GtnuxxovUd06vwm1o18oMzFtK66vU6XU= +go.uber.org/automaxprocs v1.5.2/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/go.mod b/go.mod index b7c8fb0a..09f1f5b9 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/TarsCloud/TarsGo go 1.14 require ( + github.com/google/uuid v1.6.0 github.com/kr/pretty v0.3.0 // indirect github.com/rogpeppe/go-internal v1.8.0 // indirect github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index c3350db6..6d3d81ef 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= diff --git a/tars/adapter.go b/tars/adapter.go index 843a4ca3..65604770 100755 --- a/tars/adapter.go +++ b/tars/adapter.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "time" + "github.com/TarsCloud/TarsGo/tars/model" "github.com/TarsCloud/TarsGo/tars/protocol/res/basef" "github.com/TarsCloud/TarsGo/tars/protocol/res/endpointf" "github.com/TarsCloud/TarsGo/tars/protocol/res/requestf" @@ -35,13 +36,17 @@ type AdapterProxy struct { lastKeepAliveTime int64 pushCallback func([]byte) onceKeepAlive sync.Once - - closed bool + onDispatch model.Ondispatch + asyncPushCh chan *requestf.ResponsePacket // + closed bool } // NewAdapterProxy create an adapter proxy func NewAdapterProxy(objName string, point *endpointf.EndpointF, comm *Communicator) *AdapterProxy { c := &AdapterProxy{} + if c.asyncPushCh == nil { + c.asyncPushCh = make(chan *requestf.ResponsePacket) + } c.comm = comm c.point = point proto := "tcp" @@ -68,6 +73,16 @@ func NewAdapterProxy(objName string, point *endpointf.EndpointF, comm *Communica c.conf = conf c.tarsClient = transport.NewTarsClient(fmt.Sprintf("%s:%d", point.Host, point.Port), c, conf) c.status = true + + // push queue listenning + go func() { + for { + select { + case resp := <-c.asyncPushCh: + c.onPush(resp) + } + } + }() return c } @@ -91,7 +106,7 @@ func (c *AdapterProxy) Recv(pkg []byte) { return } if packet.IRequestId == 0 { - c.onPush(packet) + c.asyncPushCh <- packet return } if packet.CPacketType == basef.TARSONEWAY { @@ -214,12 +229,23 @@ func (c *AdapterProxy) onPush(pkg *requestf.ResponsePacket) { oldClient.GraceClose(ctx) // grace shutdown return } - // Support push msg - if c.pushCallback == nil { - return + + switch pkg.IVersion { + case 0: + // raw socket push + if c.pushCallback == nil { + return + } + data := tools.Int8ToByte(pkg.SBuffer) + c.pushCallback(data) + + case basef.TARSVERSION: + if c.onDispatch == nil { + return + } + // tars proto push + c.onDispatch.Ondispatch(pkg) } - data := tools.Int8ToByte(pkg.SBuffer) - c.pushCallback(data) } func (c *AdapterProxy) autoKeepAlive() { @@ -235,6 +261,18 @@ func (c *AdapterProxy) autoKeepAlive() { } } +func (c *AdapterProxy) OnConnect(address string) { + if c.servantProxy.onConnectCallback != nil { + c.servantProxy.onConnectCallback(address) + } +} + +func (c *AdapterProxy) OnClose(address string) { + if c.servantProxy.onCloseCallback != nil { + c.servantProxy.onCloseCallback(address) + } +} + func (c *AdapterProxy) doKeepAlive() { if c.closed { return diff --git a/tars/model/servant.go b/tars/model/servant.go index 22e43f68..bcb94524 100755 --- a/tars/model/servant.go +++ b/tars/model/servant.go @@ -8,6 +8,10 @@ import ( "github.com/TarsCloud/TarsGo/tars/protocol/res/requestf" ) +type Ondispatch interface { + Ondispatch(resp *requestf.ResponsePacket) +} + // Servant is interface for call the remote server. type Servant interface { Name() string @@ -19,8 +23,12 @@ type Servant interface { resp *requestf.ResponsePacket) error TarsSetTimeout(t int) TarsSetProtocol(Protocol) + TarsPing(ctx context.Context) Endpoints() []*endpoint.Endpoint SetPushCallback(callback func([]byte)) + SetTarsCallback(callback Ondispatch) + SetOnCloseCallback(callback func(string)) + SetOnConnectCallback(callback func(string)) } type Protocol interface { diff --git a/tars/protocol/codec/codec.go b/tars/protocol/codec/codec.go index 86460de6..6c48c215 100755 --- a/tars/protocol/codec/codec.go +++ b/tars/protocol/codec/codec.go @@ -385,7 +385,6 @@ func (b *Reader) unreadHead(curTag byte) { } // Next return the []byte of next n . -// //go:nosplit func (b *Reader) Next(n int) []byte { if n <= 0 { @@ -398,7 +397,6 @@ func (b *Reader) Next(n int) []byte { } // Skip the next n byte. -// //go:nosplit func (b *Reader) Skip(n int) { if n <= 0 { diff --git a/tars/protocol/res/basef/BaseF.go b/tars/protocol/res/basef/BaseF.go index d12d746d..becf8add 100644 --- a/tars/protocol/res/basef/BaseF.go +++ b/tars/protocol/res/basef/BaseF.go @@ -13,7 +13,7 @@ import ( var _ = fmt.Errorf var _ = codec.FromInt8 -//const as define in tars file +// const as define in tars file const ( TARSVERSION int16 = 0x01 TUPVERSION int16 = 0x03 diff --git a/tars/servant.go b/tars/servant.go index 8911db4a..46300e02 100755 --- a/tars/servant.go +++ b/tars/servant.go @@ -30,15 +30,17 @@ const ( // ServantProxy tars servant proxy instance type ServantProxy struct { - name string - comm *Communicator - manager EndpointManager - timeout int - version int16 - proto model.Protocol - queueLen int32 - - pushCallback func([]byte) + name string + comm *Communicator + manager EndpointManager + timeout int + version int16 + proto model.Protocol + queueLen int32 + ondispatch model.Ondispatch + pushCallback func([]byte) + onCloseCallback func(string) + onConnectCallback func(string) } // NewServantProxy creates and initializes a servant proxy @@ -107,11 +109,44 @@ func (s *ServantProxy) genRequestID() int32 { } } +func (s *ServantProxy) TarsPing(ctx context.Context) { + req := requestf.RequestPacket{ + IVersion: s.version, + CPacketType: basef.TARSONEWAY, + IRequestId: s.genRequestID(), + SServantName: s.name, + SFuncName: "tars_ping", + ITimeout: int32(s.timeout), + } + msg := &Message{Req: &req, Ser: s} + msg.Init() + timeout := time.Duration(s.timeout) * time.Millisecond + s.manager.preInvoke() + err := s.doInvoke(ctx, msg, timeout) + s.manager.postInvoke() + if err != nil { + TLOG.Errorf("KsfPing err: %v", err) + } + msg.End() +} + +func (s *ServantProxy) SetTarsCallback(ondispatch model.Ondispatch) { + s.ondispatch = ondispatch +} + // SetPushCallback set callback function for pushing func (s *ServantProxy) SetPushCallback(callback func([]byte)) { s.pushCallback = callback } +func (s *ServantProxy) SetOnConnectCallback(callback func(string)) { + s.onConnectCallback = callback +} + +func (s *ServantProxy) SetOnCloseCallback(callback func(string)) { + s.onCloseCallback = callback +} + // TarsInvoke is used for client invoking server. func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte, sFuncName string, @@ -244,6 +279,11 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time. adp.pushCallback = s.pushCallback } + if s.ondispatch != nil { + go adp.onceKeepAlive.Do(adp.autoKeepAlive) + adp.onDispatch = s.ondispatch + } + atomic.AddInt32(&s.queueLen, 1) readCh := make(chan *requestf.ResponsePacket) adp.resp.Store(msg.Req.IRequestId, readCh) diff --git a/tars/tools/tars2go/gen_go.go b/tars/tools/tars2go/gen_go.go new file mode 100755 index 00000000..7de4876a --- /dev/null +++ b/tars/tools/tars2go/gen_go.go @@ -0,0 +1,2111 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" +) + +var gE = flag.Bool("E", false, "Generate code before fmt for troubleshooting") +var gAddServant = flag.Bool("add-servant", true, "Generate AddServant function") +var gModuleCycle = flag.Bool("module-cycle", false, "support jce module cycle include(do not support jce file cycle include)") +var gModuleUpper = flag.Bool("module-upper", false, "native module names are supported, otherwise the system will upper the first letter of the module name") +var gJsonOmitEmpty = flag.Bool("json-omitempty", false, "Generate json omitempty support") +var dispatchReporter = flag.Bool("dispatch-reporter", false, "Dispatch reporter support") +var debug = flag.Bool("debug", false, "enable debug mode") + +var gFileMap map[string]bool + +func init() { + gFileMap = make(map[string]bool) +} + +//GenGo record go code information. +type GenGo struct { + I []string // imports with path + code bytes.Buffer + vc int // var count. Used to generate unique variable names + path string + tarsPath string + module string + prefix string + p *Parse + + // proto file name(not include .tars) + ProtoName string +} + +//NewGenGo build up a new path +func NewGenGo(path string, module string, outdir string) *GenGo { + if outdir != "" { + b := []byte(outdir) + last := b[len(b)-1:] + if string(last) != "/" { + outdir += "/" + } + } + + return &GenGo{path: path, module: module, prefix: outdir, ProtoName: path2ProtoName(path)} +} + +func path2ProtoName(path string) string { + iBegin := strings.LastIndex(path, "/") + if iBegin == -1 || iBegin >= len(path)-1 { + iBegin = 0 + } else { + iBegin++ + } + iEnd := strings.LastIndex(path, ".tars") + if iEnd == -1 { + iEnd = len(path) + } + + return path[iBegin:iEnd] +} + +//Initial capitalization +func upperFirstLetter(s string) string { + if len(s) == 0 { + return "" + } + if len(s) == 1 { + return strings.ToUpper(string(s[0])) + } + return strings.ToUpper(string(s[0])) + s[1:] +} + +func getShortTypeName(src string) string { + vec := strings.Split(src, "::") + return vec[len(vec)-1] +} + +func errString(hasRet bool) string { + var retStr string + if hasRet { + retStr = "return ret, err" + } else { + retStr = "return err" + } + return `if err != nil { + ` + retStr + ` + }` + "\n" +} + +func genForHead(vc string) string { + i := `i` + vc + e := `e` + vc + return ` for ` + i + `,` + e + ` := int32(0), length;` + i + `<` + e + `;` + i + `++ ` +} + +// === rename area === +// 0. rename module +func (p *Parse) rename() { + p.OriginModule = p.Module + if *gModuleUpper { + p.Module = upperFirstLetter(p.Module) + } +} + +// 1. struct rename +// struct Name { 1 require Mb type} +func (st *StructInfo) rename() { + st.OriginName = st.Name + st.Name = upperFirstLetter(st.Name) + for i := range st.Mb { + st.Mb[i].OriginKey = st.Mb[i].Key + st.Mb[i].Key = upperFirstLetter(st.Mb[i].Key) + } +} + +// 1. interface rename +// interface Name { Fun } +func (itf *InterfaceInfo) rename() { + itf.OriginName = itf.Name + itf.Name = upperFirstLetter(itf.Name) + for i := range itf.Fun { + itf.Fun[i].rename() + } +} + +func (en *EnumInfo) rename() { + en.OriginName = en.Name + en.Name = upperFirstLetter(en.Name) + for i := range en.Mb { + en.Mb[i].Key = upperFirstLetter(en.Mb[i].Key) + } +} + +func (cst *ConstInfo) rename() { + cst.OriginName = cst.Name + cst.Name = upperFirstLetter(cst.Name) +} + +// 2. func rename +// type Fun (arg ArgType), in case keyword and name conflicts,argname need to capitalize. +// Fun (type int32) +func (fun *FunInfo) rename() { + fun.OriginName = fun.Name + fun.Name = upperFirstLetter(fun.Name) + for i := range fun.Args { + fun.Args[i].OriginName = fun.Args[i].Name + // func args donot upper firs + //fun.Args[i].Name = upperFirstLetter(fun.Args[i].Name) + } +} + +// 3. genType rename all Type + +// === rename end === + +// Gen to parse file. +func (gen *GenGo) Gen() { + defer func() { + if err := recover(); err != nil { + fmt.Println(err) + // set exit code + os.Exit(1) + } + }() + + gen.p = ParseFile(gen.path, make([]string, 0)) + gen.genAll() +} + +func (gen *GenGo) genAll() { + if gFileMap[gen.path] { + // already compiled + return + } + gFileMap[gen.path] = true + + gen.p.rename() + gen.genInclude(gen.p.IncParse) + + gen.code.Reset() + gen.genHead() + gen.genPackage() + + for _, v := range gen.p.Enum { + gen.genEnum(&v) + } + + gen.genConst(gen.p.Const) + + for _, v := range gen.p.Struct { + gen.genStruct(&v) + } + if len(gen.p.Enum) > 0 || len(gen.p.Const) > 0 || len(gen.p.Struct) > 0 { + gen.saveToSourceFile(path2ProtoName(gen.path) + ".go") + } + + for _, v := range gen.p.Interface { + gen.genInterface(&v) + } +} + +func (gen *GenGo) genErr(err string) { + panic(err) +} + +func (gen *GenGo) saveToSourceFile(filename string) { + var beauty []byte + var err error + prefix := gen.prefix + + if !*gE { + beauty, err = format.Source(gen.code.Bytes()) + if err != nil { + if *debug { + fmt.Println("------------------") + fmt.Println(string(gen.code.Bytes())) + fmt.Println("------------------") + } + gen.genErr("go fmt fail. " + filename + " " + err.Error()) + } + } else { + beauty = gen.code.Bytes() + } + + if filename == "stdout" { + fmt.Println(string(beauty)) + } else { + var mkPath string + if *gModuleCycle == true { + mkPath = prefix + gen.ProtoName + "/" + gen.p.Module + } else { + mkPath = prefix + gen.p.Module + } + err = os.MkdirAll(mkPath, 0766) + + if err != nil { + gen.genErr(err.Error()) + } + err = ioutil.WriteFile(mkPath+"/"+filename, beauty, 0666) + + if err != nil { + gen.genErr(err.Error()) + } + } +} + +func (gen *GenGo) genVariableName(prefix, name string) string { + if strings.HasPrefix(name, "(*") && strings.HasSuffix(name, ")") { + return strings.Trim(name, "()") + } + return prefix + name +} + +func (gen *GenGo) genHead() { + gen.code.WriteString(`// Package ` + gen.p.Module + ` comment +// This file was generated by tars2go ` + VERSION + ` +// Generated from ` + filepath.Base(gen.path) + ` +`) +} + +func (gen *GenGo) genPackage() { + gen.code.WriteString("package " + gen.p.Module + "\n\n") + gen.code.WriteString(` +import ( + "fmt" + +`) + gen.code.WriteString("\"" + gen.tarsPath + "/protocol/codec\"\n") + + mImports := make(map[string]bool) + for _, st := range gen.p.Struct { + if *gModuleCycle == true { + for k, v := range st.DependModuleWithJce { + gen.genStructImport(k, v, mImports) + } + } else { + for k := range st.DependModule { + gen.genStructImport(k, "", mImports) + } + } + } + for path := range mImports { + gen.code.WriteString(path + "\n") + } + + gen.code.WriteString(`) + + // Reference imports to suppress errors if they are not otherwise used. + var _ = fmt.Errorf + var _ = codec.FromInt8 + +`) +} + +func (gen *GenGo) genStructImport(module string, protoName string, mImports map[string]bool) { + var moduleStr string + var jcePath string + var moduleAlia string + if *gModuleCycle == true { + moduleStr = module[len(protoName)+1:] + jcePath = protoName + "/" + moduleAlia = module + " " + } else { + moduleStr = module + } + + for _, p := range gen.I { + if strings.HasSuffix(p, "/"+moduleStr) { + mImports[`"`+p+`"`] = true + return + } + } + + if *gModuleUpper { + moduleAlia = upperFirstLetter(moduleAlia) + } + + // example: + // TarsTest.tars, MyApp + // gomod: + // github.com/xxx/yyy/tars-protocol/MyApp + // github.com/xxx/yyy/tars-protocol/TarsTest/MyApp + // + // gopath: + // MyApp + // TarsTest/MyApp + var modulePath string + if gen.module != "" { + mf := filepath.Clean(filepath.Join(gen.module, gen.prefix)) + if runtime.GOOS == "windows" { + mf = strings.ReplaceAll(mf, string(os.PathSeparator), string('/')) + } + modulePath = fmt.Sprintf("%s/%s%s", mf, jcePath, moduleStr) + } else { + modulePath = fmt.Sprintf("%s%s", jcePath, moduleStr) + } + mImports[moduleAlia+`"`+modulePath+`"`] = true +} + +func (gen *GenGo) genIFPackage(itf *InterfaceInfo) { + gen.code.WriteString("package " + gen.p.Module + "\n\n") + gen.code.WriteString(` +import ( + "bytes" + "context" + "fmt" + "unsafe" + "net" + "encoding/json" + "encoding/binary" +`) + if *gAddServant { + gen.code.WriteString("\"" + gen.tarsPath + "\"\n") + } + + gen.code.WriteString("\"" + gen.tarsPath + "/protocol/res/requestf\"\n") + gen.code.WriteString("m \"" + gen.tarsPath + "/model\"\n") + gen.code.WriteString("\"" + gen.tarsPath + "/protocol/codec\"\n") + gen.code.WriteString("\"" + gen.tarsPath + "/protocol/tup\"\n") + gen.code.WriteString("\"" + gen.tarsPath + "/protocol/res/basef\"\n") + gen.code.WriteString("\"" + gen.tarsPath + "/util/tools\"\n") + if !withoutTrace { + gen.code.WriteString("\"" + gen.tarsPath + "/util/trace\"\n") + } + gen.code.WriteString("\"" + gen.tarsPath + "/util/current\"\n") + + if *gModuleCycle == true { + for k, v := range itf.DependModuleWithJce { + gen.genIFImport(k, v) + } + } else { + for k := range itf.DependModule { + gen.genIFImport(k, "") + } + } + gen.code.WriteString(`) + + // Reference imports to suppress errors if they are not otherwise used. + var ( + _ = fmt.Errorf + _ = codec.FromInt8 + _ = unsafe.Pointer(nil) + _ = bytes.ErrTooLarge + ) +`) +} + +func (gen *GenGo) genIFImport(module string, protoName string) { + var moduleStr string + var jcePath string + var moduleAlia string + if *gModuleCycle == true { + moduleStr = module[len(protoName)+1:] + jcePath = protoName + "/" + moduleAlia = module + " " + } else { + moduleStr = module + } + for _, p := range gen.I { + if strings.HasSuffix(p, "/"+moduleStr) { + gen.code.WriteString(`"` + p + `"` + "\n") + return + } + } + + if *gModuleUpper { + moduleAlia = upperFirstLetter(moduleAlia) + } + + // example: + // TarsTest.tars, MyApp + // gomod: + // github.com/xxx/yyy/tars-protocol/MyApp + // github.com/xxx/yyy/tars-protocol/TarsTest/MyApp + // + // gopath: + // MyApp + // TarsTest/MyApp + var modulePath string + if gen.module != "" { + mf := filepath.Clean(filepath.Join(gen.module, gen.prefix)) + if runtime.GOOS == "windows" { + mf = strings.ReplaceAll(mf, string(os.PathSeparator), string('/')) + } + modulePath = fmt.Sprintf("%s/%s%s", mf, jcePath, moduleStr) + } else { + modulePath = fmt.Sprintf("%s%s", jcePath, moduleStr) + } + gen.code.WriteString(moduleAlia + `"` + modulePath + `"` + "\n") +} + +func (gen *GenGo) genType(ty *VarType) string { + ret := "" + switch ty.Type { + case tkTBool: + ret = "bool" + case tkTInt: + if ty.Unsigned { + ret = "uint32" + } else { + ret = "int32" + } + case tkTShort: + if ty.Unsigned { + ret = "uint16" + } else { + ret = "int16" + } + case tkTByte: + if ty.Unsigned { + ret = "uint8" + } else { + ret = "int8" + } + case tkTLong: + if ty.Unsigned { + ret = "uint64" + } else { + ret = "int64" + } + case tkTFloat: + ret = "float32" + case tkTDouble: + ret = "float64" + case tkTString: + ret = "string" + case tkTVector: + ret = "[]" + gen.genType(ty.TypeK) + case tkTMap: + ret = "map[" + gen.genType(ty.TypeK) + "]" + gen.genType(ty.TypeV) + case tkName: + ret = strings.Replace(ty.TypeSt, "::", ".", -1) + vec := strings.Split(ty.TypeSt, "::") + for i := range vec { + if *gModuleUpper { + vec[i] = upperFirstLetter(vec[i]) + } else { + if i == (len(vec) - 1) { + vec[i] = upperFirstLetter(vec[i]) + } + } + } + ret = strings.Join(vec, ".") + case tkTArray: + ret = "[" + fmt.Sprintf("%v", ty.TypeL) + "]" + gen.genType(ty.TypeK) + default: + gen.genErr("Unknown Type " + TokenMap[ty.Type]) + } + return ret +} + +func (gen *GenGo) genStructDefine(st *StructInfo) { + c := &gen.code + c.WriteString("// " + st.Name + " struct implement\n") + c.WriteString("type " + st.Name + " struct {\n") + + for _, v := range st.Mb { + if *gJsonOmitEmpty { + c.WriteString("\t" + v.Key + " " + gen.genType(v.Type) + " `json:\"" + v.OriginKey + ",omitempty\"`\n") + } else { + c.WriteString("\t" + v.Key + " " + gen.genType(v.Type) + " `json:\"" + v.OriginKey + "\"`\n") + } + } + c.WriteString("}\n") +} + +func (gen *GenGo) genFunResetDefault(st *StructInfo) { + c := &gen.code + + c.WriteString("func (st *" + st.Name + ") ResetDefault() {\n") + + for _, v := range st.Mb { + if v.Type.CType == tkStruct { + c.WriteString("st." + v.Key + ".ResetDefault()\n") + } + if v.Default == "" { + continue + } + c.WriteString("st." + v.Key + " = " + v.Default + "\n") + } + c.WriteString("}\n") +} + +func (gen *GenGo) genWriteSimpleList(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + tag := strconv.Itoa(int(mb.Tag)) + unsign := "Int8" + if mb.Type.TypeK.Unsigned { + unsign = "Uint8" + } + errStr := errString(hasRet) + c.WriteString(` +err = buf.WriteHead(codec.SimpleList, ` + tag + `) +` + errStr + ` +err = buf.WriteHead(codec.BYTE, 0) +` + errStr + ` +err = buf.WriteInt32(int32(len(` + gen.genVariableName(prefix, mb.Key) + `)), 0) +` + errStr + ` +err = buf.WriteSlice` + unsign + `(` + gen.genVariableName(prefix, mb.Key) + `) +` + errStr + ` +`) +} + +func (gen *GenGo) genWriteVector(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + + // SimpleList + if mb.Type.TypeK.Type == tkTByte && !mb.Type.TypeK.Unsigned { + gen.genWriteSimpleList(mb, prefix, hasRet) + return + } + errStr := errString(hasRet) + + // LIST + tag := strconv.Itoa(int(mb.Tag)) + c.WriteString(` +err = buf.WriteHead(codec.LIST, ` + tag + `) +` + errStr + ` +err = buf.WriteInt32(int32(len(` + gen.genVariableName(prefix, mb.Key) + `)), 0) +` + errStr + ` +for _, v := range ` + gen.genVariableName(prefix, mb.Key) + ` { +`) + // for _, v := range can nesting for _, v := range,does not conflict, support multidimensional arrays + + dummy := &StructMember{} + dummy.Type = mb.Type.TypeK + dummy.Key = "v" + gen.genWriteVar(dummy, "", hasRet) + + c.WriteString("}\n") +} + +func (gen *GenGo) genWriteArray(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + + // SimpleList + if mb.Type.TypeK.Type == tkTByte && !mb.Type.TypeK.Unsigned { + gen.genWriteSimpleList(mb, prefix, hasRet) + return + } + errStr := errString(hasRet) + + // LIST + tag := strconv.Itoa(int(mb.Tag)) + c.WriteString(` +err = buf.WriteHead(codec.LIST, ` + tag + `) +` + errStr + ` +err = buf.WriteInt32(int32(len(` + gen.genVariableName(prefix, mb.Key) + `)), 0) +` + errStr + ` +for _, v := range ` + gen.genVariableName(prefix, mb.Key) + ` { +`) + // for _, v := range can nesting for _, v := range,does not conflict, support multidimensional arrays + + dummy := &StructMember{} + dummy.Type = mb.Type.TypeK + dummy.Key = "v" + gen.genWriteVar(dummy, "", hasRet) + + c.WriteString("}\n") +} + +func (gen *GenGo) genWriteStruct(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + tag := strconv.Itoa(int(mb.Tag)) + c.WriteString(` +err = ` + prefix + mb.Key + `.WriteBlock(buf, ` + tag + `) +` + errString(hasRet) + ` +`) +} + +func (gen *GenGo) genWriteMap(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + tag := strconv.Itoa(int(mb.Tag)) + vc := strconv.Itoa(gen.vc) + gen.vc++ + errStr := errString(hasRet) + c.WriteString(` +err = buf.WriteHead(codec.MAP, ` + tag + `) +` + errStr + ` +err = buf.WriteInt32(int32(len(` + gen.genVariableName(prefix, mb.Key) + `)), 0) +` + errStr + ` +for k` + vc + `, v` + vc + ` := range ` + gen.genVariableName(prefix, mb.Key) + ` { +`) + // for _, v := range can nesting for _, v := range,does not conflict, support multidimensional arrays + + dummy := &StructMember{} + dummy.Type = mb.Type.TypeK + dummy.Key = "k" + vc + gen.genWriteVar(dummy, "", hasRet) + + dummy = &StructMember{} + dummy.Type = mb.Type.TypeV + dummy.Key = "v" + vc + dummy.Tag = 1 + gen.genWriteVar(dummy, "", hasRet) + + c.WriteString("}\n") +} + +func (gen *GenGo) genWriteVar(v *StructMember, prefix string, hasRet bool) { + c := &gen.code + + switch v.Type.Type { + case tkTVector: + gen.genWriteVector(v, prefix, hasRet) + case tkTArray: + gen.genWriteArray(v, prefix, hasRet) + case tkTMap: + gen.genWriteMap(v, prefix, hasRet) + case tkName: + if v.Type.CType == tkEnum { + // tkEnum enumeration processing + tag := strconv.Itoa(int(v.Tag)) + c.WriteString(` +err = buf.WriteInt32(int32(` + gen.genVariableName(prefix, v.Key) + `),` + tag + `) +` + errString(hasRet) + ` +`) + } else { + gen.genWriteStruct(v, prefix, hasRet) + } + default: + tag := strconv.Itoa(int(v.Tag)) + c.WriteString(` +err = buf.Write` + upperFirstLetter(gen.genType(v.Type)) + `(` + gen.genVariableName(prefix, v.Key) + `, ` + tag + `) +` + errString(hasRet) + ` +`) + } +} + +func (gen *GenGo) genFunWriteBlock(st *StructInfo) { + c := &gen.code + + // WriteBlock function head + c.WriteString(`// WriteBlock encode struct +func (st *` + st.Name + `) WriteBlock(buf *codec.Buffer, tag byte) error { + var err error + err = buf.WriteHead(codec.StructBegin, tag) + if err != nil { + return err + } + + err = st.WriteTo(buf) + if err != nil { + return err + } + + err = buf.WriteHead(codec.StructEnd, 0) + if err != nil { + return err + } + return nil +} +`) +} + +func (gen *GenGo) genFunWriteTo(st *StructInfo) { + c := &gen.code + + c.WriteString(`// WriteTo encode struct to buffer +func (st *` + st.Name + `) WriteTo(buf *codec.Buffer) (err error) { +`) + for _, v := range st.Mb { + gen.genWriteVar(&v, "st.", false) + } + + c.WriteString(` + return err +} +`) +} + +func (gen *GenGo) genReadSimpleList(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + unsign := "Int8" + if mb.Type.TypeK.Unsigned { + unsign = "Uint8" + } + errStr := errString(hasRet) + + c.WriteString(` +_, err = readBuf.SkipTo(codec.BYTE, 0, true) +` + errStr + ` +err = readBuf.ReadInt32(&length, 0, true) +` + errStr + ` +err = readBuf.ReadSlice` + unsign + `(&` + prefix + mb.Key + `, length, true) +` + errStr + ` +`) +} + +func (gen *GenGo) genReadVector(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + errStr := errString(hasRet) + + // LIST + tag := strconv.Itoa(int(mb.Tag)) + vc := strconv.Itoa(gen.vc) + gen.vc++ + if mb.Require { + c.WriteString(` +_, ty, err = readBuf.SkipToNoCheck(` + tag + `, true) +` + errStr + ` +`) + } else { + c.WriteString(` +have, ty, err = readBuf.SkipToNoCheck(` + tag + `, false) +` + errStr + ` +if have {`) + // 结束标记 + defer func() { + c.WriteString("}\n") + }() + } + + c.WriteString(` +if ty == codec.LIST { + err = readBuf.ReadInt32(&length, 0, true) + ` + errStr + ` + ` + gen.genVariableName(prefix, mb.Key) + ` = make(` + gen.genType(mb.Type) + `, length) + ` + genForHead(vc) + `{ +`) + + dummy := &StructMember{} + dummy.Type = mb.Type.TypeK + dummy.Key = mb.Key + "[i" + vc + "]" + gen.genReadVar(dummy, prefix, hasRet) + + c.WriteString(`} +} else if ty == codec.SimpleList { +`) + if mb.Type.TypeK.Type == tkTByte { + gen.genReadSimpleList(mb, prefix, hasRet) + } else { + c.WriteString(`err = fmt.Errorf("not support SimpleList type") + ` + errStr) + } + c.WriteString(` +} else { + err = fmt.Errorf("require vector, but not") + ` + errStr + ` +} +`) +} + +func (gen *GenGo) genReadArray(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + errStr := errString(hasRet) + + // LIST + tag := strconv.Itoa(int(mb.Tag)) + vc := strconv.Itoa(gen.vc) + gen.vc++ + + if mb.Require { + c.WriteString(` +_, ty, err = readBuf.SkipToNoCheck(` + tag + `, true) +` + errStr + ` +`) + } else { + c.WriteString(` +have, ty, err = readBuf.SkipToNoCheck(` + tag + `, false) +` + errStr + ` +if have {`) + // 结束标记 + defer func() { + c.WriteString("}\n") + }() + } + + c.WriteString(` +if ty == codec.LIST { + err = readBuf.ReadInt32(&length, 0, true) + ` + errStr + ` + ` + genForHead(vc) + `{ +`) + + dummy := &StructMember{} + dummy.Type = mb.Type.TypeK + dummy.Key = mb.Key + "[i" + vc + "]" + gen.genReadVar(dummy, prefix, hasRet) + + c.WriteString(`} +} else if ty == codec.SimpleList { +`) + if mb.Type.TypeK.Type == tkTByte { + gen.genReadSimpleList(mb, prefix, hasRet) + } else { + c.WriteString(`err = fmt.Errorf("not support SimpleList type") + ` + errStr) + } + c.WriteString(` +} else { + err = fmt.Errorf("require array, but not") + ` + errStr + ` +} +`) +} + +func (gen *GenGo) genReadStruct(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + tag := strconv.Itoa(int(mb.Tag)) + require := "false" + if mb.Require { + require = "true" + } + c.WriteString(` +err = ` + prefix + mb.Key + `.ReadBlock(readBuf, ` + tag + `, ` + require + `) +` + errString(hasRet) + ` +`) +} + +func (gen *GenGo) genReadMap(mb *StructMember, prefix string, hasRet bool) { + c := &gen.code + tag := strconv.Itoa(int(mb.Tag)) + errStr := errString(hasRet) + vc := strconv.Itoa(gen.vc) + gen.vc++ + + if mb.Require { + c.WriteString(` +_, err = readBuf.SkipTo(codec.MAP, ` + tag + `, true) +` + errStr + ` +`) + } else { + c.WriteString(` +have, err = readBuf.SkipTo(codec.MAP, ` + tag + `, false) +` + errStr + ` +if have {`) + // 结束标记 + defer func() { + c.WriteString("}\n") + }() + } + + c.WriteString(` +err = readBuf.ReadInt32(&length, 0, true) +` + errStr + ` +` + gen.genVariableName(prefix, mb.Key) + ` = make(` + gen.genType(mb.Type) + `) +` + genForHead(vc) + `{ + var k` + vc + ` ` + gen.genType(mb.Type.TypeK) + ` + var v` + vc + ` ` + gen.genType(mb.Type.TypeV) + ` +`) + + dummy := &StructMember{} + dummy.Type = mb.Type.TypeK + dummy.Key = "k" + vc + gen.genReadVar(dummy, "", hasRet) + + dummy = &StructMember{} + dummy.Type = mb.Type.TypeV + dummy.Key = "v" + vc + dummy.Tag = 1 + gen.genReadVar(dummy, "", hasRet) + + c.WriteString(` + ` + prefix + mb.Key + `[k` + vc + `] = v` + vc + ` +} +`) +} + +func (gen *GenGo) genReadVar(v *StructMember, prefix string, hasRet bool) { + c := &gen.code + + switch v.Type.Type { + case tkTVector: + gen.genReadVector(v, prefix, hasRet) + case tkTArray: + gen.genReadArray(v, prefix, hasRet) + case tkTMap: + gen.genReadMap(v, prefix, hasRet) + case tkName: + if v.Type.CType == tkEnum { + require := "false" + if v.Require { + require = "true" + } + tag := strconv.Itoa(int(v.Tag)) + c.WriteString(` +err = readBuf.ReadInt32((*int32)(&` + prefix + v.Key + `),` + tag + `, ` + require + `) +` + errString(hasRet) + ` +`) + } else { + gen.genReadStruct(v, prefix, hasRet) + } + default: + require := "false" + if v.Require { + require = "true" + } + tag := strconv.Itoa(int(v.Tag)) + c.WriteString(` +err = readBuf.Read` + upperFirstLetter(gen.genType(v.Type)) + `(&` + prefix + v.Key + `, ` + tag + `, ` + require + `) +` + errString(hasRet) + ` +`) + } +} + +func (gen *GenGo) genFunReadFrom(st *StructInfo) { + c := &gen.code + + c.WriteString(`// ReadFrom reads from readBuf and put into struct. +func (st *` + st.Name + `) ReadFrom(readBuf *codec.Reader) error { + var ( + err error + length int32 + have bool + ty byte + ) + st.ResetDefault() + +`) + + for _, v := range st.Mb { + gen.genReadVar(&v, "st.", false) + } + + c.WriteString(` + _ = err + _ = length + _ = have + _ = ty + return nil +} +`) +} + +func (gen *GenGo) genFunReadBlock(st *StructInfo) { + c := &gen.code + + c.WriteString(`// ReadBlock reads struct from the given tag , require or optional. +func (st *` + st.Name + `) ReadBlock(readBuf *codec.Reader, tag byte, require bool) error { + var ( + err error + have bool + ) + st.ResetDefault() + + have, err = readBuf.SkipTo(codec.StructBegin, tag, require) + if err != nil { + return err + } + if !have { + if require { + return fmt.Errorf("require ` + st.Name + `, but not exist. tag %d", tag) + } + return nil + } + + err = st.ReadFrom(readBuf) + if err != nil { + return err + } + + err = readBuf.SkipToStructEnd() + if err != nil { + return err + } + _ = have + return nil +} +`) +} + +func (gen *GenGo) genStruct(st *StructInfo) { + gen.vc = 0 + st.rename() + + gen.genStructDefine(st) + gen.genFunResetDefault(st) + + gen.genFunReadFrom(st) + gen.genFunReadBlock(st) + + gen.genFunWriteTo(st) + gen.genFunWriteBlock(st) +} + +func (gen *GenGo) makeEnumName(en *EnumInfo, mb *EnumMember) string { + return upperFirstLetter(en.Name) + "_" + upperFirstLetter(mb.Key) +} + +func (gen *GenGo) genEnum(en *EnumInfo) { + if len(en.Mb) == 0 { + return + } + + en.rename() + + c := &gen.code + c.WriteString("type " + en.Name + " int32\n") + c.WriteString("const (\n") + var it int32 + for _, v := range en.Mb { + if v.Type == 0 { + //use value + c.WriteString(gen.makeEnumName(en, &v) + ` = ` + strconv.Itoa(int(v.Value)) + "\n") + it = v.Value + 1 + } else if v.Type == 1 { + // use name + find := false + for _, ref := range en.Mb { + if ref.Key == v.Name { + find = true + c.WriteString(gen.makeEnumName(en, &v) + ` = ` + gen.makeEnumName(en, &ref) + "\n") + it = ref.Value + 1 + break + } + if ref.Key == v.Key { + break + } + } + if !find { + gen.genErr(v.Name + " not define before use.") + } + } else { + // use auto add + c.WriteString(gen.makeEnumName(en, &v) + ` = ` + strconv.Itoa(int(it)) + "\n") + it++ + } + + } + + c.WriteString(")\n") +} + +func (gen *GenGo) genConst(cst []ConstInfo) { + if len(cst) == 0 { + return + } + + c := &gen.code + c.WriteString("//const as define in tars file\n") + c.WriteString("const (\n") + + for _, v := range gen.p.Const { + v.rename() + c.WriteString(v.Name + " " + gen.genType(v.Type) + " = " + v.Value + "\n") + } + + c.WriteString(")\n") +} + +func (gen *GenGo) genInclude(ps []*Parse) { + for _, v := range ps { + gen2 := &GenGo{ + path: v.Source, + module: gen.module, + prefix: gen.prefix, + tarsPath: gTarsPath, + ProtoName: path2ProtoName(v.Source), + } + gen2.p = v + gen2.genAll() + } +} + +func (gen *GenGo) genInterface(itf *InterfaceInfo) { + gen.code.Reset() + itf.rename() + + gen.genHead() + gen.genIFPackage(itf) + + gen.genIFProxy(itf) + + gen.genIFServer(itf) + gen.genIFServerWithContext(itf) + + gen.genIFDispatch(itf) + + gen.genTarsCallback(itf) + + gen.genSendPushResponse(itf) + + gen.saveToSourceFile(itf.Name + ".tars.go") +} + +func (gen *GenGo) genIFProxy(itf *InterfaceInfo) { + c := &gen.code + c.WriteString("// " + itf.Name + " struct\n") + c.WriteString("type " + itf.Name + " struct {" + "\n") + c.WriteString("servant m.Servant" + "\n") + c.WriteString("}" + "\n") + + for _, v := range itf.Fun { + gen.genIFProxyFun(itf.Name, &v, false, false) + gen.genIFProxyFun(itf.Name, &v, true, false) + gen.genIFProxyFun(itf.Name, &v, true, true) + } + + c.WriteString(`// SetServant sets servant for the service. +func (obj *` + itf.Name + `) SetServant(servant m.Servant) { + obj.servant = servant +} +`) + c.WriteString(`// GetServant gets servant for the service. +func (obj *` + itf.Name + `) GetServant()(servant *m.Servant) { + return &obj.servant +} +`) + c.WriteString(` // SetOnConnectCallback +func (obj *` + itf.Name + `) SetOnConnectCallback(callback func(string)) { + obj.servant.SetOnConnectCallback(callback) +} +`) + + c.WriteString(` // SetOnCloseCallback +func (obj *` + itf.Name + `) SetOnCloseCallback(callback func(string)) { + obj.servant.SetOnCloseCallback(callback) +} +`) + c.WriteString(` // SetTarsCallback +func (obj *` + itf.Name + `) SetTarsCallback(callback ` + itf.Name + `TarsCallback) { + var push ` + itf.Name + `PushCallback + push.Cb = callback + obj.servant.SetTarsCallback(&push) +} +`) + + c.WriteString(` // SetPushCallback +func (obj *` + itf.Name + `) SetPushCallback(callback func([]byte)) { + obj.servant.SetPushCallback(callback) +} +`) + + c.WriteString(`func (obj *` + itf.Name + `) req2Byte(rsp *requestf.ResponsePacket) []byte { + req := requestf.RequestPacket{} + req.IVersion = rsp.IVersion + req.IRequestId = rsp.IRequestId + req.IMessageType = rsp.IMessageType + req.CPacketType = rsp.CPacketType + req.Context = rsp.Context + req.Status = rsp.Status + req.SBuffer = rsp.SBuffer + + os := codec.NewBuffer() + req.WriteTo(os) + bs := os.ToBytes() + sbuf := bytes.NewBuffer(nil) + sbuf.Write(make([]byte, 4)) + sbuf.Write(bs) + length := sbuf.Len() + binary.BigEndian.PutUint32(sbuf.Bytes(), uint32(length)) + return sbuf.Bytes() +} + +func (obj * ` + itf.Name + `) rsp2Byte(rsp *requestf.ResponsePacket) []byte { + if rsp.IVersion == basef.TUPVERSION { + return obj.req2Byte(rsp) + } + os := codec.NewBuffer() + rsp.WriteTo(os) + bs := os.ToBytes() + sbuf := bytes.NewBuffer(nil) + sbuf.Write(make([]byte, 4)) + sbuf.Write(bs) + length := sbuf.Len() + binary.BigEndian.PutUint32(sbuf.Bytes(), uint32(length)) + return sbuf.Bytes() +} +`) + + c.WriteString(` // TarsPing +func (obj *` + itf.Name + `) TarsPing() { + ctx := context.Background() + obj.servant.TarsPing(ctx) +} +`) + + c.WriteString(`// TarsSetTimeout sets the timeout for the servant which is in ms. +func (obj *` + itf.Name + `) TarsSetTimeout(timeout int) { + obj.servant.TarsSetTimeout(timeout) +} +`) + + c.WriteString(`// TarsSetProtocol sets the protocol for the servant. +func (obj *` + itf.Name + `) TarsSetProtocol(p m.Protocol) { + obj.servant.TarsSetProtocol(p) +} +`) + + if *gAddServant { + c.WriteString(`// AddServant adds servant for the service. +func (obj *` + itf.Name + `) AddServant(imp ` + itf.Name + `Servant, servantObj string) { + tars.AddServant(obj, imp, servantObj) +} +`) + c.WriteString(`// AddServantWithContext adds servant for the service with context. +func (obj *` + itf.Name + `) AddServantWithContext(imp ` + itf.Name + `ServantWithContext, servantObj string) { + tars.AddServantWithContext(obj, imp, servantObj) +} +`) + } +} + +func (gen *GenGo) genIFProxyFun(interfName string, fun *FunInfo, withContext bool, isOneWay bool) { + c := &gen.code + if withContext == true { + if isOneWay { + c.WriteString("// " + fun.Name + "OneWayWithContext is the proxy function for the method defined in the tars file, with the context\n") + c.WriteString("func (obj *" + interfName + ") " + fun.Name + "OneWayWithContext(tarsCtx context.Context,") + } else { + c.WriteString("// " + fun.Name + "WithContext is the proxy function for the method defined in the tars file, with the context\n") + c.WriteString("func (obj *" + interfName + ") " + fun.Name + "WithContext(tarsCtx context.Context,") + } + } else { + c.WriteString("// " + fun.Name + " is the proxy function for the method defined in the tars file, with the context\n") + c.WriteString("func (obj *" + interfName + ") " + fun.Name + "(") + } + for _, v := range fun.Args { + gen.genArgs(&v) + } + + c.WriteString(" opts ...map[string]string)") + if fun.HasRet { + c.WriteString("(ret " + gen.genType(fun.RetType) + ", err error){" + "\n") + } else { + c.WriteString("(err error)" + "{" + "\n") + } + + c.WriteString(` var ( + length int32 + have bool + ty byte + ) + `) + c.WriteString("buf := codec.NewBuffer()") + var isOut bool + for k, v := range fun.Args { + if v.IsOut { + isOut = true + } + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = v.Name + dummy.Tag = int32(k + 1) + if v.IsOut { + dummy.Key = "(*" + dummy.Key + ")" + } + gen.genWriteVar(dummy, "", fun.HasRet) + } + // empty args and below separate + c.WriteString("\n") + errStr := errString(fun.HasRet) + + if !withContext { + c.WriteString(` +var statusMap map[string]string +var contextMap map[string]string +if len(opts) == 1{ + contextMap =opts[0] +}else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] +} +tarsResp := new(requestf.ResponsePacket) +tarsCtx := context.Background() +`) + } else { + // trace + if !isOneWay && !withoutTrace { + c.WriteString(` +traceData, ok := current.GetTraceData(tarsCtx) +if ok && traceData.TraceCall { + traceData.NewSpan() + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstCS, uint(buf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} +`) + for _, v := range fun.Args { + if !v.IsOut { + c.WriteString(`value["` + v.Name + `"] = ` + v.Name + "\n") + } + } + c.WriteString(`p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCS), trace.TraceAnnotationCS, tars.GetClientConfig().ModuleName, obj.servant.Name(), "` + fun.Name + `", 0, traceParam, "") +}`) + c.WriteString("\n\n") + } + c.WriteString(`var statusMap map[string]string +var contextMap map[string]string +if len(opts) == 1{ + contextMap =opts[0] +}else if len(opts) == 2 { + contextMap = opts[0] + statusMap = opts[1] +} + +tarsResp := new(requestf.ResponsePacket)`) + } + + if isOneWay { + c.WriteString(` + err = obj.servant.TarsInvoke(tarsCtx, 1, "` + fun.OriginName + `", buf.ToBytes(), statusMap, contextMap, tarsResp) + ` + errStr + ` + `) + } else { + c.WriteString(` + err = obj.servant.TarsInvoke(tarsCtx, 0, "` + fun.OriginName + `", buf.ToBytes(), statusMap, contextMap, tarsResp) + ` + errStr + ` + `) + } + + if (isOut || fun.HasRet) && !isOneWay { + c.WriteString("readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer))") + } + if fun.HasRet && !isOneWay { + dummy := &StructMember{} + dummy.Type = fun.RetType + dummy.Key = "ret" + dummy.Tag = 0 + dummy.Require = true + gen.genReadVar(dummy, "", fun.HasRet) + } + + if !isOneWay { + for k, v := range fun.Args { + if v.IsOut { + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = "(*" + v.Name + ")" + dummy.Tag = int32(k + 1) + dummy.Require = true + gen.genReadVar(dummy, "", fun.HasRet) + } + } + if withContext && !withoutTrace { + traceParamFlag := "traceParamFlag := traceData.NeedTraceParam(trace.EstCR, uint(0))" + if isOut || fun.HasRet { + traceParamFlag = "traceParamFlag := traceData.NeedTraceParam(trace.EstCR, uint(readBuf.Len()))" + } + c.WriteString(` +if ok && traceData.TraceCall { + var traceParam string + ` + traceParamFlag + ` + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} +`) + if fun.HasRet { + c.WriteString(`value[""] = ret` + "\n") + } + for _, v := range fun.Args { + if v.IsOut { + c.WriteString(`value["` + v.Name + `"] = *` + v.Name + "\n") + } + } + c.WriteString(`p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCR), trace.TraceAnnotationCR, tars.GetClientConfig().ModuleName, obj.servant.Name(), "` + fun.Name + `", int(tarsResp.IRet), traceParam, "") +}`) + c.WriteString("\n\n") + } + } + + c.WriteString(` +if len(opts) == 1 { + for k := range(contextMap){ + delete(contextMap, k) + } + for k, v := range(tarsResp.Context){ + contextMap[k] = v + } +} else if len(opts) == 2 { + for k := range(contextMap){ + delete(contextMap, k) + } + for k, v := range(tarsResp.Context){ + contextMap[k] = v + } + for k := range(statusMap){ + delete(statusMap, k) + } + for k, v := range(tarsResp.Status){ + statusMap[k] = v + } +} + _ = length + _ = have + _ = ty + `) + + if fun.HasRet { + c.WriteString("return ret, nil" + "\n") + } else { + c.WriteString("return nil" + "\n") + } + + c.WriteString("}" + "\n") +} + +func (gen *GenGo) genArgs(arg *ArgInfo) { + c := &gen.code + c.WriteString(arg.Name + " ") + if arg.IsOut || arg.Type.CType == tkStruct { + c.WriteString("*") + } + + c.WriteString(gen.genType(arg.Type) + ",") +} + +func (gen *GenGo) genIFServer(itf *InterfaceInfo) { + c := &gen.code + c.WriteString("type " + itf.Name + "Servant interface {" + "\n") + for _, v := range itf.Fun { + gen.genIFServerFun(&v) + } + c.WriteString("}" + "\n") +} + +func (gen *GenGo) genIFServerWithContext(itf *InterfaceInfo) { + c := &gen.code + c.WriteString("type " + itf.Name + "ServantWithContext interface {" + "\n") + for _, v := range itf.Fun { + gen.genIFServerFunWithContext(&v) + } + c.WriteString("}" + "\n") +} + +func (gen *GenGo) genIFServerFun(fun *FunInfo) { + c := &gen.code + c.WriteString(fun.Name + "(") + for _, v := range fun.Args { + gen.genArgs(&v) + } + c.WriteString(")(") + + if fun.HasRet { + c.WriteString("ret " + gen.genType(fun.RetType) + ", ") + } + c.WriteString("err error)" + "\n") +} + +func (gen *GenGo) genIFServerFunWithContext(fun *FunInfo) { + c := &gen.code + c.WriteString(fun.Name + "(tarsCtx context.Context, ") + for _, v := range fun.Args { + gen.genArgs(&v) + } + c.WriteString(")(") + + if fun.HasRet { + c.WriteString("ret " + gen.genType(fun.RetType) + ", ") + } + c.WriteString("err error)" + "\n") +} + +func (gen *GenGo) genIFDispatch(itf *InterfaceInfo) { + c := &gen.code + c.WriteString("// Dispatch is used to call the server side implement for the method defined in the tars file. withContext shows using context or not. \n") + c.WriteString("func(obj *" + itf.Name + `) Dispatch(tarsCtx context.Context, val interface{}, tarsReq *requestf.RequestPacket, tarsResp *requestf.ResponsePacket, withContext bool) (err error) { + var ( + length int32 + have bool + ty byte + ) + `) + + var param bool + for _, v := range itf.Fun { + if len(v.Args) > 0 { + param = true + break + } + } + + if param { + c.WriteString("readBuf := codec.NewReader(tools.Int8ToByte(tarsReq.SBuffer))") + } else { + c.WriteString("readBuf := codec.NewReader(nil)") + } + c.WriteString(` + buf := codec.NewBuffer() + switch tarsReq.SFuncName { +`) + + for _, v := range itf.Fun { + gen.genSwitchCase(itf.Name, &v) + } + + c.WriteString(` + default: + return fmt.Errorf("func mismatch") + } + var statusMap map[string]string + if status, ok := current.GetResponseStatus(tarsCtx); ok && status != nil { + statusMap = status + } + var contextMap map[string]string + if ctx, ok := current.GetResponseContext(tarsCtx); ok && ctx != nil { + contextMap = ctx + } + *tarsResp = requestf.ResponsePacket{ + IVersion: tarsReq.IVersion, + CPacketType: 0, + IRequestId: tarsReq.IRequestId, + IMessageType: 0, + IRet: 0, + SBuffer: tools.ByteToInt8(buf.ToBytes()), + Status: statusMap, + SResultDesc: "", + Context: contextMap, + } + + _ = readBuf + _ = buf + _ = length + _ = have + _ = ty + return nil +} +`) +} + +func (gen *GenGo) genSwitchCase(tname string, fun *FunInfo) { + c := &gen.code + c.WriteString(`case "` + fun.OriginName + `":` + "\n") + + inArgsCount := 0 + outArgsCount := 0 + for _, v := range fun.Args { + c.WriteString("var " + v.Name + " " + gen.genType(v.Type) + "\n") + if v.Type.Type == tkTMap { + c.WriteString(v.Name + " = make(" + gen.genType(v.Type) + ")\n") + } else if v.Type.Type == tkTVector { + c.WriteString(v.Name + " = make(" + gen.genType(v.Type) + ", 0)\n") + } + if v.IsOut { + outArgsCount++ + } else { + inArgsCount++ + } + } + + //fmt.Println("args count, in, out:", inArgsCount, outArgsCount) + + c.WriteString("\n") + + if inArgsCount > 0 { + c.WriteString("if tarsReq.IVersion == basef.TARSVERSION {" + "\n") + + for k, v := range fun.Args { + + if !v.IsOut { + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = v.Name + dummy.Tag = int32(k + 1) + dummy.Require = true + gen.genReadVar(dummy, "", false) + } + } + + c.WriteString(`} else if tarsReq.IVersion == basef.TUPVERSION { + reqTup := tup.NewUniAttribute() + reqTup.Decode(readBuf) + + var tupBuffer []byte + + `) + for _, v := range fun.Args { + if !v.IsOut { + c.WriteString("\n") + c.WriteString(`reqTup.GetBuffer("` + v.Name + `", &tupBuffer)` + "\n") + c.WriteString("readBuf.Reset(tupBuffer)") + + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = v.Name + dummy.Tag = 0 + dummy.Require = true + gen.genReadVar(dummy, "", false) + } + } + + c.WriteString(`} else if tarsReq.IVersion == basef.JSONVERSION { + var jsonData map[string]interface{} + decoder := json.NewDecoder(bytes.NewReader(readBuf.ToBytes())) + decoder.UseNumber() + err = decoder.Decode(&jsonData) + if err != nil { + return fmt.Errorf("decode reqpacket failed, error: %+v", err) + } + `) + + for _, v := range fun.Args { + if !v.IsOut { + c.WriteString("{\n") + c.WriteString(`jsonStr, _ := json.Marshal(jsonData["` + v.Name + `"])` + "\n") + if v.Type.CType == tkStruct { + c.WriteString(v.Name + ".ResetDefault()\n") + } + c.WriteString("if err = json.Unmarshal(jsonStr, &" + v.Name + "); err != nil {") + c.WriteString(` + return err + } + } + `) + } + } + + c.WriteString(` + } else { + err = fmt.Errorf("decode reqpacket fail, error version: %d", tarsReq.IVersion) + return err + }`) + + c.WriteString("\n\n") + } + if !withoutTrace { + c.WriteString(` +traceData, ok := current.GetTraceData(tarsCtx) +if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstSR, uint(readBuf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} +`) + for _, v := range fun.Args { + if !v.IsOut { + c.WriteString(`value["` + v.Name + `"] = ` + v.Name + "\n") + } + } + c.WriteString(`p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstSR), trace.TraceAnnotationSR, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "` + fun.OriginName + `", 0, traceParam, "") +}`) + c.WriteString("\n\n") + } + + if fun.HasRet { + c.WriteString("var funRet " + gen.genType(fun.RetType) + "\n") + + c.WriteString(`if !withContext { + imp := val.(` + tname + `Servant) + funRet, err = imp.` + fun.Name + `(`) + for _, v := range fun.Args { + if v.IsOut || v.Type.CType == tkStruct { + c.WriteString("&" + v.Name + ",") + } else { + c.WriteString(v.Name + ",") + } + } + c.WriteString(")") + + c.WriteString(` + } else { + imp := val.(` + tname + `ServantWithContext) + funRet, err = imp.` + fun.Name + `(tarsCtx ,`) + for _, v := range fun.Args { + if v.IsOut || v.Type.CType == tkStruct { + c.WriteString("&" + v.Name + ",") + } else { + c.WriteString(v.Name + ",") + } + } + c.WriteString(")" + "\n } \n") + + } else { + c.WriteString(`if !withContext { + imp := val.(` + tname + `Servant) + err = imp.` + fun.Name + `(`) + for _, v := range fun.Args { + if v.IsOut || v.Type.CType == tkStruct { + c.WriteString("&" + v.Name + ",") + } else { + c.WriteString(v.Name + ",") + } + } + c.WriteString(")") + + c.WriteString(` + } else { + imp := val.(` + tname + `ServantWithContext) + err = imp.` + fun.Name + `(tarsCtx ,`) + for _, v := range fun.Args { + if v.IsOut || v.Type.CType == tkStruct { + c.WriteString("&" + v.Name + ",") + } else { + c.WriteString(v.Name + ",") + } + } + c.WriteString(") \n}\n") + } + + if *dispatchReporter { + var inArgStr, outArgStr, retArgStr string + if fun.HasRet { + retArgStr = "funRet, err" + } else { + retArgStr = "err" + } + for _, v := range fun.Args { + prefix := "" + if v.Type.CType == tkStruct { + prefix = "&" + } + if v.IsOut { + outArgStr += prefix + v.Name + "," + } else { + inArgStr += prefix + v.Name + "," + } + } + c.WriteString(`if dp := tars.GetDispatchReporter(); dp != nil { + dp(tarsCtx, []interface{}{` + inArgStr + `}, []interface{}{` + outArgStr + `}, []interface{}{` + retArgStr + `}) + }`) + + } + c.WriteString(` + if err != nil { + return err + } + `) + + c.WriteString(` + if tarsReq.IVersion == basef.TARSVERSION { + buf.Reset() + `) + + // if fun.HasRet { + // c.WriteString(` + // err = buf.WriteInt32(funRet, 0) + // if err != nil { + // return err + // } + //`) + // } + + if fun.HasRet { + dummy := &StructMember{} + dummy.Type = fun.RetType + dummy.Key = "funRet" + dummy.Tag = 0 + dummy.Require = true + gen.genWriteVar(dummy, "", false) + } + + for k, v := range fun.Args { + if v.IsOut { + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = v.Name + dummy.Tag = int32(k + 1) + dummy.Require = true + gen.genWriteVar(dummy, "", false) + } + } + + c.WriteString(` +} else if tarsReq.IVersion == basef.TUPVERSION { +rspTup := tup.NewUniAttribute() +`) + if fun.HasRet { + dummy := &StructMember{} + dummy.Type = fun.RetType + dummy.Key = "funRet" + dummy.Tag = 0 + dummy.Require = true + gen.genWriteVar(dummy, "", false) + + c.WriteString(` + rspTup.PutBuffer("", buf.ToBytes()) + rspTup.PutBuffer("tars_ret", buf.ToBytes()) +`) + } + + for _, v := range fun.Args { + if v.IsOut { + c.WriteString(` + buf.Reset()`) + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = v.Name + dummy.Tag = 0 + dummy.Require = true + gen.genWriteVar(dummy, "", false) + + c.WriteString(`rspTup.PutBuffer("` + v.Name + `", buf.ToBytes())` + "\n") + } + } + + c.WriteString(` + buf.Reset() + err = rspTup.Encode(buf) + if err != nil { + return err + } +} else if tarsReq.IVersion == basef.JSONVERSION { + rspJson := map[string]interface{}{} +`) + if fun.HasRet { + c.WriteString(`rspJson["tars_ret"] = funRet` + "\n") + } + + for _, v := range fun.Args { + if v.IsOut { + c.WriteString(`rspJson["` + v.Name + `"] = ` + v.Name + "\n") + } + } + + c.WriteString(` + var rspByte []byte + if rspByte, err = json.Marshal(rspJson); err != nil { + return err + } + + buf.Reset() + err = buf.WriteSliceUint8(rspByte) + if err != nil { + return err + } +}`) + + c.WriteString("\n") + if !withoutTrace { + c.WriteString(` +if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstSS, uint(buf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} +`) + if fun.HasRet { + c.WriteString(`value[""] = funRet` + "\n") + } + for _, v := range fun.Args { + if v.IsOut { + c.WriteString(`value["` + v.Name + `"] = ` + v.Name + "\n") + } + } + c.WriteString(`p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstSS), trace.TraceAnnotationSS, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "` + fun.OriginName + `", 0, traceParam, "") +}`) + c.WriteString("\n\n") + } +} + +func (gen *GenGo) genArgsForPush(arg *ArgInfo) { + c := &gen.code + if arg.IsOut { + c.WriteString(arg.Name + " ") + c.WriteString("*") + c.WriteString(gen.genType(arg.Type) + ",") + } +} + +func (gen *GenGo) genTarsCallback(itf *InterfaceInfo) { + c := &gen.code + + c.WriteString("type " + itf.Name + "TarsCallback interface {" + "\n") + for _, v := range itf.Fun { + gen.genIFPushCallback(itf.Name, &v) + gen.genIFPushExceptionCallback(itf.Name, &v) + } + c.WriteString("}" + "\n") + + c.WriteString("// " + itf.Name + "PushCallback struct\n") + c.WriteString("type " + itf.Name + "PushCallback struct {" + "\n") + c.WriteString("Cb " + itf.Name + "TarsCallback" + "\n") + c.WriteString("}" + "\n") + + // 生成PushCallback 接口Ondispatch + c.WriteString("func (cb *" + itf.Name + "PushCallback) Ondispatch(resp *requestf.ResponsePacket) {" + "\n") + c.WriteString("switch resp.SResultDesc {" + "\n") + for _, v := range itf.Fun { + var hasOut bool + for _, v := range v.Args { + if v.IsOut { + hasOut = true + } + } + if !hasOut && !v.HasRet { + continue + } + c.WriteString("case \"" + v.Name + "\":" + "\n") + c.WriteString("err := func() error {" + "\n") + c.WriteString("var err error" + "\n") + c.WriteString("readBuf := codec.NewReader(tools.Int8ToByte(resp.SBuffer))" + "\n") + if v.HasRet { + c.WriteString(" var ret = new(" + gen.genType(v.RetType) + ")") + dummy := &StructMember{} + dummy.Type = v.RetType + dummy.Key = "ret" + dummy.Tag = 0 + dummy.Require = true + gen.genReadVar(dummy, "", false) + } + for k, arg := range v.Args { + if arg.IsOut { + hasOut = true + if len(arg.Type.TypeSt) == 0 { + c.WriteString("var " + arg.Name + "= new(" + gen.genType(arg.Type) + ")" + "\n") + } else { + c.WriteString("var " + arg.Name + "= new(" + arg.Type.TypeSt + ")" + "\n") + } + dummy := &StructMember{} + dummy.Type = arg.Type + dummy.Key = "(*" + arg.Name + ")" + dummy.Tag = int32(k + 1) + dummy.Require = true + gen.genReadVar(dummy, "", false) + } + } + c.WriteString(` if resp.Context != nil { + cb.Cb.` + v.Name + `_Callback(`) + k := 0 + if v.HasRet { + c.WriteString("ret") + k++ + } + for _, arg := range v.Args { + if arg.IsOut { + if k == 0 { + c.WriteString(arg.Name) + } else { + c.WriteString(`, ` + arg.Name) + } + k++ + } + } + c.WriteString(", resp.Context)" + "\n") + c.WriteString("return nil") + c.WriteString(`} else { `) + c.WriteString(`cb.Cb.` + v.Name + `_Callback(`) + { + k := 0 + if v.HasRet { + c.WriteString("ret") + k++ + } + for _, arg := range v.Args { + if arg.IsOut { + if k == 0 { + c.WriteString(arg.Name) + } else { + c.WriteString(`, ` + arg.Name) + } + k++ + } + } + } + c.WriteString(")" + "\n") + c.WriteString("return nil") + c.WriteString("}" + "\n") + c.WriteString("}()" + "\n") + c.WriteString(`if err != nil { + cb.Cb.` + v.Name + `_ExceptionCallback(err) +}` + "\n") + } + + c.WriteString("}" + "\n") + c.WriteString("}" + "\n") +} + +func (gen *GenGo) genIFPushCallback(name string, f *FunInfo) { + c := &gen.code + c.WriteString(f.Name + "_Callback(") + if f.HasRet { + c.WriteString("ret *" + gen.genType(f.RetType) + ", ") + } + for _, v := range f.Args { + gen.genArgsForPush(&v) + } + + c.WriteString(" opt ...map[string]string)" + "\n") +} + +func (gen *GenGo) genIFPushExceptionCallback(name string, f *FunInfo) { + c := &gen.code + c.WriteString(f.Name + "_ExceptionCallback(err error)" + "\n") +} + +func (gen *GenGo) genSendPushResponse(itf *InterfaceInfo) { + c := &gen.code + for _, fun := range itf.Fun { + var hasOut bool + for _, v := range fun.Args { + if v.IsOut { + hasOut = true + } + } + if !hasOut && !fun.HasRet { + continue + } + + c.WriteString(`func (obj *` + itf.Name + `)AsyncSendResponse_` + fun.Name + `(ctx context.Context, `) + if fun.HasRet { + c.WriteString("ret *" + gen.genType(fun.RetType) + ", ") + } + for _, arg := range fun.Args { + gen.genArgsForPush(&arg) + } + c.WriteString(" opt ... map[string]string") + c.WriteString(`)(err error) {` + "\n") + c.WriteString(` + conn, udpAddr, ok := current.GetRawConn(ctx) + if !ok { + return fmt.Errorf("connection not found") + } +`) + c.WriteString("buf := codec.NewBuffer()" + "\n") + if fun.HasRet { + dummy := &StructMember{} + dummy.Type = fun.RetType + dummy.Key = "ret" + dummy.Tag = 0 + dummy.Require = true + gen.genWriteVar(dummy, "", false) + } + for k, v := range fun.Args { + if !v.IsOut { + continue + } + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = v.Name + dummy.Tag = int32(k + 1) + if v.IsOut { + dummy.Key = "(*" + dummy.Key + ")" + } + gen.genWriteVar(dummy, "", false) + } + c.WriteString(`resp := &requestf.ResponsePacket{ + SBuffer: tools.ByteToInt8(buf.ToBytes()), + } + resp.IVersion = basef.TARSVERSION + if resp.Status == nil { + resp.Status = make(map[string]string) + } + resp.Status["TARS_FUNC"] = "` + fun.Name + `" + resp.SResultDesc = "` + fun.Name + `" + if len(opt) > 0 { + if opt[0] != nil{ + resp.Context= opt[0] + } + } + rspData := obj.rsp2Byte(resp) + if udpAddr != nil { + udpConn, _ := conn.(*net.UDPConn) + _, err = udpConn.WriteToUDP(rspData, udpAddr) + } else { + _, err = conn.Write(rspData) + } + return err`) + c.WriteString("}" + "\n") + } +} diff --git a/tars/transport/common.go b/tars/transport/common.go index e6de10d5..0ab2ec9b 100755 --- a/tars/transport/common.go +++ b/tars/transport/common.go @@ -35,6 +35,8 @@ type ServerProtocol interface { type ClientProtocol interface { Recv(pkg []byte) ParsePackage(buff []byte) (int, int) + OnClose(address string) + OnConnect(address string) } func isNoDataError(err error) bool { diff --git a/tars/transport/tarsclient.go b/tars/transport/tarsclient.go index c589f1c8..6b74e10e 100755 --- a/tars/transport/tarsclient.go +++ b/tars/transport/tarsclient.go @@ -135,15 +135,17 @@ func (c *connection) ReConnect() (err error) { } if err != nil { + go c.client.protocol.OnClose(c.client.address) return err } if c.client.config.Proto == "tcp" { if c.conn != nil { - _ = c.conn.(*net.TCPConn).SetKeepAlive(true) + c.conn.(*net.TCPConn).SetKeepAlive(true) } } c.idleTime = time.Now() c.isClosed = false + go c.client.protocol.OnConnect(c.client.address) connDone := make(chan bool, 1) go c.recv(c.conn, connDone) go c.send(c.conn, connDone) @@ -270,4 +272,5 @@ func (c *connection) close(conn net.Conn) { if conn != nil { _ = conn.Close() } + go c.client.protocol.OnClose(c.client.address) } diff --git a/tars/transport/tcphandler.go b/tars/transport/tcphandler.go index 380a3197..652c0d69 100755 --- a/tars/transport/tcphandler.go +++ b/tars/transport/tcphandler.go @@ -3,6 +3,7 @@ package transport import ( "context" "crypto/tls" + "github.com/google/uuid" "io" "net" "os" @@ -32,12 +33,14 @@ type tcpHandler struct { } type connInfo struct { + uuid string conn net.Conn idleTime int64 numInvoke int32 } func (t *tcpHandler) Listen() (err error) { + uuid.EnableRandPool() cfg := t.config t.listener, err = grace.CreateListener("tcp", cfg.Address) if err != nil { @@ -62,6 +65,7 @@ func (t *tcpHandler) Listen() (err error) { func (t *tcpHandler) getConnContext(connSt *connInfo) context.Context { ctx := current.ContextWithTarsCurrent(context.Background()) ipPort := strings.Split(connSt.conn.RemoteAddr().String(), ":") + current.SetUUIDWithContext(ctx, connSt.uuid) current.SetClientIPWithContext(ctx, ipPort[0]) current.SetClientPortWithContext(ctx, ipPort[1]) current.SetRecvPkgTsFromContext(ctx, time.Now().UnixNano()/1e6) @@ -137,7 +141,7 @@ func (t *tcpHandler) Handle() error { case *tls.Conn: TLOG.Debugf("TLS accept: %s, %d", conn.RemoteAddr(), os.Getpid()) } - cf := &connInfo{conn: conn} + cf := &connInfo{conn: conn, uuid: uuid.New().String()} t.conns.Store(key, cf) t.recv(cf) t.conns.Delete(key) diff --git a/tars/util/current/tarscurrent.go b/tars/util/current/tarscurrent.go index 47645def..cf61c309 100644 --- a/tars/util/current/tarscurrent.go +++ b/tars/util/current/tarscurrent.go @@ -14,6 +14,7 @@ var tcKey = tarsCurrentKey(0x484900) // Current contains message for the specify request. // This current is used for server side. type Current struct { + uuid string clientIP string clientPort string recvPkgTs int64 @@ -84,6 +85,23 @@ func currentFromContext(ctx context.Context) (*Current, bool) { return tc, ok } +func SetUUIDWithContext(ctx context.Context, uid string) bool { + tc, ok := currentFromContext(ctx) + if ok { + tc.uuid = uid + } + return ok +} + +// GetUUID get current uuid using for connect uniqe mark +func GetUUID(ctx context.Context) (uuid string, ret bool) { + tc, ok := currentFromContext(ctx) + if ok { + return tc.uuid, true + } + return +} + // SetResponseStatus set the response package' status . func SetResponseStatus(ctx context.Context, s map[string]string) bool { tc, ok := currentFromContext(ctx)