diff --git a/README.md b/README.md index 477ee32..f19e3df 100644 --- a/README.md +++ b/README.md @@ -72,4 +72,6 @@ Other supported options are: * `enum_as_int_or_string` * when set to `true`, the openapi schema will include `x-kubernetes-int-or-string` on enums. * `additional_empty_schemas` - * a `+` separated list of message names (`core.solo.io.Status`), whose generated schema should be an empty object that accepts all values. \ No newline at end of file + * a `+` separated list of message names (`core.solo.io.Status`), whose generated schema should be an empty object that accepts all values. +* `strict_proto3_optional` + * when set to `true`, the openapi schema will include `required` for object properties that were not explicitly marked as `optional` in proto3 field definition. \ No newline at end of file diff --git a/go.mod b/go.mod index 942ce10..7ebd1b6 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,13 @@ go 1.18 require ( github.com/getkin/kin-openapi v0.80.0 github.com/ghodss/yaml v1.0.0 - github.com/golang/protobuf v1.3.2 + github.com/golang/protobuf v1.5.2 ) require ( github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/swag v0.19.5 // indirect github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e // indirect + google.golang.org/protobuf v1.26.0 // indirect gopkg.in/yaml.v2 v2.3.0 // indirect ) diff --git a/go.sum b/go.sum index 38d55f5..be1f732 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,11 @@ github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUe github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -26,6 +29,11 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index c36f798..fc1967e 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,6 @@ package main import ( "fmt" - "strings" "github.com/solo-io/protoc-gen-openapi/pkg/protocgen" "github.com/solo-io/protoc-gen-openapi/pkg/protomodel" @@ -24,94 +23,13 @@ import ( plugin "github.com/golang/protobuf/protoc-gen-go/plugin" ) -// Breaks the comma-separated list of key=value pairs -// in the parameter string into an easy to use map. -func extractParams(parameter string) map[string]string { - m := make(map[string]string) - for _, p := range strings.Split(parameter, ",") { - if p == "" { - continue - } - - if i := strings.Index(p, "="); i < 0 { - m[p] = "" - } else { - m[p[0:i]] = p[i+1:] - } - } - - return m -} - func generate(request plugin.CodeGeneratorRequest) (*plugin.CodeGeneratorResponse, error) { - perFile := false - singleFile := false - yaml := false - useRef := false - includeDescription := true - enumAsIntOrString := false - var messagesWithEmptySchema []string - - p := extractParams(request.GetParameter()) - for k, v := range p { - if k == "per_file" { - switch strings.ToLower(v) { - case "true": - perFile = true - case "false": - perFile = false - default: - return nil, fmt.Errorf("unknown value '%s' for per_file", v) - } - } else if k == "single_file" { - switch strings.ToLower(v) { - case "true": - if perFile { - return nil, fmt.Errorf("output is already to be generated per file, cannot output to a single file") - } - singleFile = true - case "false": - singleFile = false - default: - return nil, fmt.Errorf("unknown value '%s' for single_file", v) - } - } else if k == "yaml" { - yaml = true - } else if k == "use_ref" { - switch strings.ToLower(v) { - case "true": - useRef = true - case "false": - useRef = false - default: - return nil, fmt.Errorf("unknown value '%s' for use_ref", v) - } - } else if k == "include_description" { - switch strings.ToLower(v) { - case "true": - includeDescription = true - case "false": - includeDescription = false - default: - return nil, fmt.Errorf("unknown value '%s' for include_description", v) - } - } else if k == "enum_as_int_or_string" { - switch strings.ToLower(v) { - case "true": - enumAsIntOrString = true - case "false": - enumAsIntOrString = false - default: - return nil, fmt.Errorf("unknown value '%s' for enum_as_int_or_string", v) - } - } else if k == "additional_empty_schema" { - messagesWithEmptySchema = strings.Split(v, "+") - } else { - return nil, fmt.Errorf("unknown argument '%s' specified", k) - } + options := newGenerationOptions() + if err := options.parseParameters(request.GetParameter()); err != nil { + return nil, err } - m := protomodel.NewModel(&request, perFile) + m := protomodel.NewModel(&request, options.perFile) filesToGen := make(map[*protomodel.FileDescriptor]bool) for _, fileName := range request.FileToGenerate { @@ -123,18 +41,10 @@ func generate(request plugin.CodeGeneratorRequest) (*plugin.CodeGeneratorRespons } descriptionConfiguration := &DescriptionConfiguration{ - IncludeDescriptionInSchema: includeDescription, + IncludeDescriptionInSchema: options.includeDescription, } - g := newOpenAPIGenerator( - m, - perFile, - singleFile, - yaml, - useRef, - descriptionConfiguration, - enumAsIntOrString, - messagesWithEmptySchema) + g := newOpenAPIGenerator(options, m, descriptionConfiguration) return g.generateOutput(filesToGen) } diff --git a/openapiGenerator.go b/openapiGenerator.go index 89033a5..a02691c 100644 --- a/openapiGenerator.go +++ b/openapiGenerator.go @@ -81,12 +81,13 @@ var specialSoloTypes = map[string]openapi3.Schema{ } type openapiGenerator struct { - buffer bytes.Buffer - model *protomodel.Model - perFile bool - singleFile bool - yaml bool - useRef bool + buffer bytes.Buffer + model *protomodel.Model + perFile bool + singleFile bool + yaml bool + useRef bool + strictProto3Optional bool // transient state as individual files are processed currentPackage *protomodel.PackageDescriptor @@ -111,24 +112,19 @@ type DescriptionConfiguration struct { } func newOpenAPIGenerator( + options *generationOptions, model *protomodel.Model, - perFile bool, - singleFile bool, - yaml bool, - useRef bool, - descriptionConfiguration *DescriptionConfiguration, - enumAsIntOrString bool, - messagesWithEmptySchema []string, -) *openapiGenerator { + descriptionConfiguration *DescriptionConfiguration) *openapiGenerator { return &openapiGenerator{ model: model, - perFile: perFile, - singleFile: singleFile, - yaml: yaml, - useRef: useRef, + perFile: options.perFile, + singleFile: options.singleFile, + yaml: options.yaml, + useRef: options.useRef, + strictProto3Optional: options.strictProto3Optional, descriptionConfiguration: descriptionConfiguration, - enumAsIntOrString: enumAsIntOrString, - customSchemasByMessageName: buildCustomSchemasByMessageName(messagesWithEmptySchema), + enumAsIntOrString: options.enumAsIntOrString, + customSchemasByMessageName: buildCustomSchemasByMessageName(options.messagesWithEmptySchema), } } @@ -163,7 +159,8 @@ func buildCustomSchemasByMessageName(messagesWithEmptySchema []string) map[strin } func (g *openapiGenerator) generateOutput(filesToGen map[*protomodel.FileDescriptor]bool) (*plugin.CodeGeneratorResponse, error) { - response := plugin.CodeGeneratorResponse{} + supportedFeatures := uint64(plugin.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) + response := plugin.CodeGeneratorResponse{SupportedFeatures: &supportedFeatures} if g.singleFile { g.generateSingleFileOutput(filesToGen, &response) @@ -393,6 +390,9 @@ func (g *openapiGenerator) generateMessageSchema(message *protomodel.MessageDesc for _, field := range message.Fields { sr := g.fieldTypeRef(field) o.WithProperty(g.fieldName(field), sr.Value) + if g.strictProto3Optional && !field.IsOptional() { + o.Required = append(o.Required, g.fieldName(field)) + } } return o diff --git a/options.go b/options.go new file mode 100644 index 0000000..b6ac1dd --- /dev/null +++ b/options.go @@ -0,0 +1,97 @@ +package main + +import ( + "fmt" + "strconv" + "strings" +) + +type generationOptions struct { + perFile bool + singleFile bool + yaml bool + useRef bool + includeDescription bool + enumAsIntOrString bool + messagesWithEmptySchema []string + strictProto3Optional bool +} + +func newGenerationOptions() *generationOptions { + return &generationOptions{ + includeDescription: true, + } +} + +func (o *generationOptions) parseParameters(args string) error { + p := extractParams(args) + for k, v := range p { + if k == "per_file" { + if val, err := strconv.ParseBool(v); err != nil { + return fmt.Errorf("unknown value '%s' for per_file", v) + } else { + o.perFile = val + } + } else if k == "single_file" { + if val, err := strconv.ParseBool(v); err != nil { + return fmt.Errorf("unknown value '%s' for single_file", v) + } else { + o.singleFile = val + } + if o.perFile { + return fmt.Errorf("output is already to be generated per file, cannot output to a single file") + } + } else if k == "yaml" { + o.yaml = true + } else if k == "use_ref" { + if val, err := strconv.ParseBool(v); err != nil { + return fmt.Errorf("unknown value '%s' for use_ref", v) + } else { + o.useRef = val + } + } else if k == "include_description" { + if val, err := strconv.ParseBool(v); err != nil { + return fmt.Errorf("unknown value '%s' for include_description", v) + } else { + o.includeDescription = val + } + } else if k == "enum_as_int_or_string" { + if val, err := strconv.ParseBool(v); err != nil { + return fmt.Errorf("unknown value '%s' for enum_as_int_or_string", v) + } else { + o.enumAsIntOrString = val + } + } else if k == "additional_empty_schema" { + o.messagesWithEmptySchema = strings.Split(v, "+") + } else if k == "strict_proto3_optional" { + if val, err := strconv.ParseBool(v); err != nil { + return fmt.Errorf("unknown value '%s' for strict_proto3_optional", v) + } else { + o.strictProto3Optional = val + } + } else { + return fmt.Errorf("unknown argument '%s' specified", k) + } + } + + return nil +} + +// Breaks the comma-separated list of key=value pairs +// in the parameter string into an easy to use map. +func extractParams(parameter string) map[string]string { + m := make(map[string]string) + for _, p := range strings.Split(parameter, ",") { + if p == "" { + continue + } + + if i := strings.Index(p, "="); i < 0 { + m[p] = "" + } else { + m[p[0:i]] = p[i+1:] + } + } + + return m +} diff --git a/pkg/protomodel/messageDescriptor.go b/pkg/protomodel/messageDescriptor.go index f423fff..77700dc 100644 --- a/pkg/protomodel/messageDescriptor.go +++ b/pkg/protomodel/messageDescriptor.go @@ -76,3 +76,7 @@ func newMessageDescriptor(desc *descriptor.DescriptorProto, parent *MessageDescr func (f *FieldDescriptor) IsRepeated() bool { return f.Label != nil && *f.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED } + +func (f *FieldDescriptor) IsOptional() bool { + return f.GetProto3Optional() +}