From 94f05ee02c7dfac7aea9cdc86e88a084d9e8262b Mon Sep 17 00:00:00 2001 From: Kevin Valk Date: Thu, 24 Aug 2023 20:05:01 +0200 Subject: [PATCH] feat: support sqlc.embed --- internal/gen.go | 102 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 93 insertions(+), 9 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index ebe34b0..65d5d87 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -53,6 +53,8 @@ type Field struct { Name string Type pyType Comment string + // EmbedFields contains the embedded fields that require scanning. + EmbedFields []Field } type Struct struct { @@ -105,14 +107,42 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node { call := &pyast.Call{ Func: v.Annotation(), } - for i, f := range v.Struct.Fields { - call.Keywords = append(call.Keywords, &pyast.Keyword{ - Arg: f.Name, - Value: subscriptNode( + rowIndex := 0 // We need to keep track of the index in the row variable. + for _, f := range v.Struct.Fields { + + var valueNode *pyast.Node + // Check if we are using sqlc.embed, if so we need to create a new object. + if len(f.EmbedFields) > 0 { + // We keep this separate so we can easily add all arguments. + embed_call := &pyast.Call{Func: f.Type.Annotation()} + + // Now add all field Initializers for the embedded model that index into the original row. + for i, embedField := range f.EmbedFields { + embed_call.Keywords = append(embed_call.Keywords, &pyast.Keyword{ + Arg: embedField.Name, + Value: subscriptNode( + rowVar, + constantInt(rowIndex+i), + ), + }) + } + + valueNode = &pyast.Node{ + Node: &pyast.Node_Call{ + Call: embed_call, + }, + } + + rowIndex += len(f.EmbedFields) + } else { + valueNode = subscriptNode( rowVar, - constantInt(i), - ), - }) + constantInt(rowIndex), + ) + rowIndex++ + } + + call.Keywords = append(call.Keywords, &pyast.Keyword{Arg: f.Name, Value: valueNode}) } return &pyast.Node{ Node: &pyast.Node_Call{ @@ -336,6 +366,47 @@ func paramName(p *plugin.Parameter) string { type pyColumn struct { id int32 *plugin.Column + embed *pyEmbed +} + +type pyEmbed struct { + modelType string + modelName string + fields []Field +} + +// Taken from https://github.com/sqlc-dev/sqlc/blob/8c59fbb9938a0bad3d9971fc2c10ea1f83cc1d0b/internal/codegen/golang/result.go#L123-L126 +// look through all the structs and attempt to find a matching one to embed +// We need the name of the struct and its field names. +func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *pyEmbed { + if embed == nil { + return nil + } + + for _, s := range structs { + embedSchema := defaultSchema + if embed.Schema != "" { + embedSchema = embed.Schema + } + + // compare the other attributes + if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema { + continue + } + + fields := make([]Field, len(s.Fields)) + for i, f := range s.Fields { + fields[i] = f + } + + return &pyEmbed{ + modelType: s.Name, + modelName: s.Name, + fields: fields, + } + } + + return nil } func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn) *Struct { @@ -359,10 +430,22 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn if suffix > 0 { fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) } - gs.Fields = append(gs.Fields, Field{ + + f := Field{ Name: fieldName, Type: makePyType(req, c.Column), - }) + } + + if c.embed != nil { + f.Type = pyType{ + InnerType: "models." + modelName(c.embed.modelType, req.Settings), + IsArray: false, + IsNull: false, + } + f.EmbedFields = c.embed.fields + } + + gs.Fields = append(gs.Fields, f) seen[colName]++ } return &gs @@ -476,6 +559,7 @@ func buildQueries(conf Config, req *plugin.CodeGenRequest, structs []Struct) ([] columns = append(columns, pyColumn{ id: int32(i), Column: c, + embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema), }) } gs = columnsToStruct(req, query.Name+"Row", columns)