|
| 1 | +// Copyright 2018 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +// Package gengogrpc contains the gRPC code generator. |
| 6 | +package gengogrpc |
| 7 | + |
| 8 | +import ( |
| 9 | + "fmt" |
| 10 | + "strconv" |
| 11 | + "strings" |
| 12 | + |
| 13 | + "google.golang.org/protobuf/compiler/protogen" |
| 14 | + |
| 15 | + "google.golang.org/protobuf/types/descriptorpb" |
| 16 | +) |
| 17 | + |
| 18 | +const ( |
| 19 | + contextPackage = protogen.GoImportPath("context") |
| 20 | + grpcPackage = protogen.GoImportPath("google.golang.org/grpc") |
| 21 | + codesPackage = protogen.GoImportPath("google.golang.org/grpc/codes") |
| 22 | + statusPackage = protogen.GoImportPath("google.golang.org/grpc/status") |
| 23 | +) |
| 24 | + |
| 25 | +// GenerateFile generates a _grpc.pb.go file containing gRPC service definitions. |
| 26 | +func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile { |
| 27 | + if len(file.Services) == 0 { |
| 28 | + return nil |
| 29 | + } |
| 30 | + filename := file.GeneratedFilenamePrefix + "_grpc.pb.go" |
| 31 | + g := gen.NewGeneratedFile(filename, file.GoImportPath) |
| 32 | + g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.") |
| 33 | + g.P() |
| 34 | + g.P("package ", file.GoPackageName) |
| 35 | + g.P() |
| 36 | + GenerateFileContent(gen, file, g) |
| 37 | + return g |
| 38 | +} |
| 39 | + |
| 40 | +// GenerateFileContent generates the gRPC service definitions, excluding the package statement. |
| 41 | +func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) { |
| 42 | + if len(file.Services) == 0 { |
| 43 | + return |
| 44 | + } |
| 45 | + |
| 46 | + // TODO: Remove this. We don't need to include these references any more. |
| 47 | + g.P("// Reference imports to suppress errors if they are not otherwise used.") |
| 48 | + g.P("var _ ", contextPackage.Ident("Context")) |
| 49 | + g.P("var _ ", grpcPackage.Ident("ClientConnInterface")) |
| 50 | + g.P() |
| 51 | + |
| 52 | + g.P("// This is a compile-time assertion to ensure that this generated file") |
| 53 | + g.P("// is compatible with the grpc package it is being compiled against.") |
| 54 | + g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion6")) |
| 55 | + g.P() |
| 56 | + for _, service := range file.Services { |
| 57 | + genService(gen, file, g, service) |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) { |
| 62 | + clientName := service.GoName + "Client" |
| 63 | + |
| 64 | + g.P("// ", clientName, " is the client API for ", service.GoName, " service.") |
| 65 | + g.P("//") |
| 66 | + g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.") |
| 67 | + |
| 68 | + // Client interface. |
| 69 | + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { |
| 70 | + g.P("//") |
| 71 | + g.P(deprecationComment) |
| 72 | + } |
| 73 | + g.Annotate(clientName, service.Location) |
| 74 | + g.P("type ", clientName, " interface {") |
| 75 | + for _, method := range service.Methods { |
| 76 | + g.Annotate(clientName+"."+method.GoName, method.Location) |
| 77 | + if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { |
| 78 | + g.P(deprecationComment) |
| 79 | + } |
| 80 | + g.P(method.Comments.Leading, |
| 81 | + clientSignature(g, method)) |
| 82 | + } |
| 83 | + g.P("}") |
| 84 | + g.P() |
| 85 | + |
| 86 | + // Client structure. |
| 87 | + g.P("type ", unexport(clientName), " struct {") |
| 88 | + g.P("cc ", grpcPackage.Ident("ClientConnInterface")) |
| 89 | + g.P("}") |
| 90 | + g.P() |
| 91 | + |
| 92 | + // NewClient factory. |
| 93 | + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { |
| 94 | + g.P(deprecationComment) |
| 95 | + } |
| 96 | + g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {") |
| 97 | + g.P("return &", unexport(clientName), "{cc}") |
| 98 | + g.P("}") |
| 99 | + g.P() |
| 100 | + |
| 101 | + var methodIndex, streamIndex int |
| 102 | + // Client method implementations. |
| 103 | + for _, method := range service.Methods { |
| 104 | + if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { |
| 105 | + // Unary RPC method |
| 106 | + genClientMethod(gen, file, g, method, methodIndex) |
| 107 | + methodIndex++ |
| 108 | + } else { |
| 109 | + // Streaming RPC method |
| 110 | + genClientMethod(gen, file, g, method, streamIndex) |
| 111 | + streamIndex++ |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + // Server interface. |
| 116 | + serverType := service.GoName + "Server" |
| 117 | + g.P("// ", serverType, " is the server API for ", service.GoName, " service.") |
| 118 | + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { |
| 119 | + g.P("//") |
| 120 | + g.P(deprecationComment) |
| 121 | + } |
| 122 | + g.Annotate(serverType, service.Location) |
| 123 | + g.P("type ", serverType, " interface {") |
| 124 | + for _, method := range service.Methods { |
| 125 | + g.Annotate(serverType+"."+method.GoName, method.Location) |
| 126 | + if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { |
| 127 | + g.P(deprecationComment) |
| 128 | + } |
| 129 | + g.P(method.Comments.Leading, |
| 130 | + serverSignature(g, method)) |
| 131 | + } |
| 132 | + g.P("}") |
| 133 | + g.P() |
| 134 | + |
| 135 | + // Server Unimplemented struct for forward compatibility. |
| 136 | + g.P("// Unimplemented", serverType, " can be embedded to have forward compatible implementations.") |
| 137 | + g.P("type Unimplemented", serverType, " struct {") |
| 138 | + g.P("}") |
| 139 | + g.P() |
| 140 | + for _, method := range service.Methods { |
| 141 | + nilArg := "" |
| 142 | + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| 143 | + nilArg = "nil," |
| 144 | + } |
| 145 | + g.P("func (*Unimplemented", serverType, ") ", serverSignature(g, method), "{") |
| 146 | + g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`) |
| 147 | + g.P("}") |
| 148 | + } |
| 149 | + g.P() |
| 150 | + |
| 151 | + // Server registration. |
| 152 | + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { |
| 153 | + g.P(deprecationComment) |
| 154 | + } |
| 155 | + serviceDescVar := "_" + service.GoName + "_serviceDesc" |
| 156 | + g.P("func Register", service.GoName, "Server(s *", grpcPackage.Ident("Server"), ", srv ", serverType, ") {") |
| 157 | + g.P("s.RegisterService(&", serviceDescVar, `, srv)`) |
| 158 | + g.P("}") |
| 159 | + g.P() |
| 160 | + |
| 161 | + // Server handler implementations. |
| 162 | + var handlerNames []string |
| 163 | + for _, method := range service.Methods { |
| 164 | + hname := genServerMethod(gen, file, g, method) |
| 165 | + handlerNames = append(handlerNames, hname) |
| 166 | + } |
| 167 | + |
| 168 | + // Service descriptor. |
| 169 | + g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {") |
| 170 | + g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",") |
| 171 | + g.P("HandlerType: (*", serverType, ")(nil),") |
| 172 | + g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{") |
| 173 | + for i, method := range service.Methods { |
| 174 | + if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { |
| 175 | + continue |
| 176 | + } |
| 177 | + g.P("{") |
| 178 | + g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",") |
| 179 | + g.P("Handler: ", handlerNames[i], ",") |
| 180 | + g.P("},") |
| 181 | + } |
| 182 | + g.P("},") |
| 183 | + g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{") |
| 184 | + for i, method := range service.Methods { |
| 185 | + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| 186 | + continue |
| 187 | + } |
| 188 | + g.P("{") |
| 189 | + g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",") |
| 190 | + g.P("Handler: ", handlerNames[i], ",") |
| 191 | + if method.Desc.IsStreamingServer() { |
| 192 | + g.P("ServerStreams: true,") |
| 193 | + } |
| 194 | + if method.Desc.IsStreamingClient() { |
| 195 | + g.P("ClientStreams: true,") |
| 196 | + } |
| 197 | + g.P("},") |
| 198 | + } |
| 199 | + g.P("},") |
| 200 | + g.P("Metadata: \"", file.Desc.Path(), "\",") |
| 201 | + g.P("}") |
| 202 | + g.P() |
| 203 | +} |
| 204 | + |
| 205 | +func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string { |
| 206 | + s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context")) |
| 207 | + if !method.Desc.IsStreamingClient() { |
| 208 | + s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent) |
| 209 | + } |
| 210 | + s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") (" |
| 211 | + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| 212 | + s += "*" + g.QualifiedGoIdent(method.Output.GoIdent) |
| 213 | + } else { |
| 214 | + s += method.Parent.GoName + "_" + method.GoName + "Client" |
| 215 | + } |
| 216 | + s += ", error)" |
| 217 | + return s |
| 218 | +} |
| 219 | + |
| 220 | +func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) { |
| 221 | + service := method.Parent |
| 222 | + sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()) |
| 223 | + |
| 224 | + if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { |
| 225 | + g.P(deprecationComment) |
| 226 | + } |
| 227 | + g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{") |
| 228 | + if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { |
| 229 | + g.P("out := new(", method.Output.GoIdent, ")") |
| 230 | + g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`) |
| 231 | + g.P("if err != nil { return nil, err }") |
| 232 | + g.P("return out, nil") |
| 233 | + g.P("}") |
| 234 | + g.P() |
| 235 | + return |
| 236 | + } |
| 237 | + streamType := unexport(service.GoName) + method.GoName + "Client" |
| 238 | + serviceDescVar := "_" + service.GoName + "_serviceDesc" |
| 239 | + g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`) |
| 240 | + g.P("if err != nil { return nil, err }") |
| 241 | + g.P("x := &", streamType, "{stream}") |
| 242 | + if !method.Desc.IsStreamingClient() { |
| 243 | + g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }") |
| 244 | + g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") |
| 245 | + } |
| 246 | + g.P("return x, nil") |
| 247 | + g.P("}") |
| 248 | + g.P() |
| 249 | + |
| 250 | + genSend := method.Desc.IsStreamingClient() |
| 251 | + genRecv := method.Desc.IsStreamingServer() |
| 252 | + genCloseAndRecv := !method.Desc.IsStreamingServer() |
| 253 | + |
| 254 | + // Stream auxiliary types and methods. |
| 255 | + g.P("type ", service.GoName, "_", method.GoName, "Client interface {") |
| 256 | + if genSend { |
| 257 | + g.P("Send(*", method.Input.GoIdent, ") error") |
| 258 | + } |
| 259 | + if genRecv { |
| 260 | + g.P("Recv() (*", method.Output.GoIdent, ", error)") |
| 261 | + } |
| 262 | + if genCloseAndRecv { |
| 263 | + g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)") |
| 264 | + } |
| 265 | + g.P(grpcPackage.Ident("ClientStream")) |
| 266 | + g.P("}") |
| 267 | + g.P() |
| 268 | + |
| 269 | + g.P("type ", streamType, " struct {") |
| 270 | + g.P(grpcPackage.Ident("ClientStream")) |
| 271 | + g.P("}") |
| 272 | + g.P() |
| 273 | + |
| 274 | + if genSend { |
| 275 | + g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {") |
| 276 | + g.P("return x.ClientStream.SendMsg(m)") |
| 277 | + g.P("}") |
| 278 | + g.P() |
| 279 | + } |
| 280 | + if genRecv { |
| 281 | + g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {") |
| 282 | + g.P("m := new(", method.Output.GoIdent, ")") |
| 283 | + g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") |
| 284 | + g.P("return m, nil") |
| 285 | + g.P("}") |
| 286 | + g.P() |
| 287 | + } |
| 288 | + if genCloseAndRecv { |
| 289 | + g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {") |
| 290 | + g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") |
| 291 | + g.P("m := new(", method.Output.GoIdent, ")") |
| 292 | + g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") |
| 293 | + g.P("return m, nil") |
| 294 | + g.P("}") |
| 295 | + g.P() |
| 296 | + } |
| 297 | +} |
| 298 | + |
| 299 | +func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string { |
| 300 | + var reqArgs []string |
| 301 | + ret := "error" |
| 302 | + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| 303 | + reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context"))) |
| 304 | + ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)" |
| 305 | + } |
| 306 | + if !method.Desc.IsStreamingClient() { |
| 307 | + reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent)) |
| 308 | + } |
| 309 | + if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { |
| 310 | + reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server") |
| 311 | + } |
| 312 | + return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret |
| 313 | +} |
| 314 | + |
| 315 | +func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string { |
| 316 | + service := method.Parent |
| 317 | + hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName) |
| 318 | + |
| 319 | + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| 320 | + g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {") |
| 321 | + g.P("in := new(", method.Input.GoIdent, ")") |
| 322 | + g.P("if err := dec(in); err != nil { return nil, err }") |
| 323 | + g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }") |
| 324 | + g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{") |
| 325 | + g.P("Server: srv,") |
| 326 | + g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.GoName)), ",") |
| 327 | + g.P("}") |
| 328 | + g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {") |
| 329 | + g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))") |
| 330 | + g.P("}") |
| 331 | + g.P("return interceptor(ctx, in, info, handler)") |
| 332 | + g.P("}") |
| 333 | + g.P() |
| 334 | + return hname |
| 335 | + } |
| 336 | + streamType := unexport(service.GoName) + method.GoName + "Server" |
| 337 | + g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {") |
| 338 | + if !method.Desc.IsStreamingClient() { |
| 339 | + g.P("m := new(", method.Input.GoIdent, ")") |
| 340 | + g.P("if err := stream.RecvMsg(m); err != nil { return err }") |
| 341 | + g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})") |
| 342 | + } else { |
| 343 | + g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})") |
| 344 | + } |
| 345 | + g.P("}") |
| 346 | + g.P() |
| 347 | + |
| 348 | + genSend := method.Desc.IsStreamingServer() |
| 349 | + genSendAndClose := !method.Desc.IsStreamingServer() |
| 350 | + genRecv := method.Desc.IsStreamingClient() |
| 351 | + |
| 352 | + // Stream auxiliary types and methods. |
| 353 | + g.P("type ", service.GoName, "_", method.GoName, "Server interface {") |
| 354 | + if genSend { |
| 355 | + g.P("Send(*", method.Output.GoIdent, ") error") |
| 356 | + } |
| 357 | + if genSendAndClose { |
| 358 | + g.P("SendAndClose(*", method.Output.GoIdent, ") error") |
| 359 | + } |
| 360 | + if genRecv { |
| 361 | + g.P("Recv() (*", method.Input.GoIdent, ", error)") |
| 362 | + } |
| 363 | + g.P(grpcPackage.Ident("ServerStream")) |
| 364 | + g.P("}") |
| 365 | + g.P() |
| 366 | + |
| 367 | + g.P("type ", streamType, " struct {") |
| 368 | + g.P(grpcPackage.Ident("ServerStream")) |
| 369 | + g.P("}") |
| 370 | + g.P() |
| 371 | + |
| 372 | + if genSend { |
| 373 | + g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {") |
| 374 | + g.P("return x.ServerStream.SendMsg(m)") |
| 375 | + g.P("}") |
| 376 | + g.P() |
| 377 | + } |
| 378 | + if genSendAndClose { |
| 379 | + g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {") |
| 380 | + g.P("return x.ServerStream.SendMsg(m)") |
| 381 | + g.P("}") |
| 382 | + g.P() |
| 383 | + } |
| 384 | + if genRecv { |
| 385 | + g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {") |
| 386 | + g.P("m := new(", method.Input.GoIdent, ")") |
| 387 | + g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }") |
| 388 | + g.P("return m, nil") |
| 389 | + g.P("}") |
| 390 | + g.P() |
| 391 | + } |
| 392 | + |
| 393 | + return hname |
| 394 | +} |
| 395 | + |
| 396 | +const deprecationComment = "// Deprecated: Do not use." |
| 397 | + |
| 398 | +func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] } |
0 commit comments