|
| 1 | +package db |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "reflect" |
| 7 | + "strings" |
| 8 | + |
| 9 | + "github.com/domonda/go-sqldb" |
| 10 | + "github.com/domonda/go-sqldb/impl" |
| 11 | +) |
| 12 | + |
| 13 | +func writeInsertQuery(w *strings.Builder, table string, names []string, format sqldb.PlaceholderFormatter) { |
| 14 | + fmt.Fprintf(w, `INSERT INTO %s(`, table) |
| 15 | + for i, name := range names { |
| 16 | + if i > 0 { |
| 17 | + w.WriteByte(',') |
| 18 | + } |
| 19 | + w.WriteByte('"') |
| 20 | + w.WriteString(name) |
| 21 | + w.WriteByte('"') |
| 22 | + } |
| 23 | + w.WriteString(`) VALUES(`) |
| 24 | + for i := range names { |
| 25 | + if i > 0 { |
| 26 | + w.WriteByte(',') |
| 27 | + } |
| 28 | + w.WriteString(format.Placeholder(i)) |
| 29 | + } |
| 30 | + w.WriteByte(')') |
| 31 | +} |
| 32 | + |
| 33 | +func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { |
| 34 | + v := reflect.ValueOf(rowStruct) |
| 35 | + for v.Kind() == reflect.Ptr && !v.IsNil() { |
| 36 | + v = v.Elem() |
| 37 | + } |
| 38 | + switch { |
| 39 | + case v.Kind() == reflect.Ptr && v.IsNil(): |
| 40 | + return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table) |
| 41 | + case v.Kind() != reflect.Struct: |
| 42 | + return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) |
| 43 | + } |
| 44 | + |
| 45 | + columns, _, vals = impl.ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) |
| 46 | + return columns, vals, nil |
| 47 | +} |
| 48 | + |
| 49 | +// Insert a new row into table using the values. |
| 50 | +func Insert(ctx context.Context, table string, values sqldb.Values) error { |
| 51 | + if len(values) == 0 { |
| 52 | + return fmt.Errorf("Insert into table %s: no values", table) |
| 53 | + } |
| 54 | + conn := Conn(ctx) |
| 55 | + |
| 56 | + var query strings.Builder |
| 57 | + names, vals := values.Sorted() |
| 58 | + writeInsertQuery(&query, table, names, conn) |
| 59 | + |
| 60 | + err := conn.Exec(query.String(), vals...) |
| 61 | + if err != nil { |
| 62 | + return wrapErrorWithQuery(err, query.String(), vals, conn) |
| 63 | + } |
| 64 | + return nil |
| 65 | +} |
| 66 | + |
| 67 | +// InsertUnique inserts a new row into table using the passed values |
| 68 | +// or does nothing if the onConflict statement applies. |
| 69 | +// Returns if a row was inserted. |
| 70 | +func InsertUnique(ctx context.Context, table string, values sqldb.Values, onConflict string) (inserted bool, err error) { |
| 71 | + if len(values) == 0 { |
| 72 | + return false, fmt.Errorf("InsertUnique into table %s: no values", table) |
| 73 | + } |
| 74 | + conn := Conn(ctx) |
| 75 | + |
| 76 | + if strings.HasPrefix(onConflict, "(") && strings.HasSuffix(onConflict, ")") { |
| 77 | + onConflict = onConflict[1 : len(onConflict)-1] |
| 78 | + } |
| 79 | + |
| 80 | + var query strings.Builder |
| 81 | + names, vals := values.Sorted() |
| 82 | + writeInsertQuery(&query, table, names, conn) |
| 83 | + fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) |
| 84 | + |
| 85 | + err = conn.QueryRow(query.String(), vals...).Scan(&inserted) |
| 86 | + err = sqldb.ReplaceErrNoRows(err, nil) |
| 87 | + if err != nil { |
| 88 | + return false, wrapErrorWithQuery(err, query.String(), vals, conn) |
| 89 | + } |
| 90 | + return inserted, err |
| 91 | +} |
| 92 | + |
| 93 | +// InsertReturning inserts a new row into table using values |
| 94 | +// and returns values from the inserted row listed in returning. |
| 95 | +func InsertReturning(ctx context.Context, table string, values sqldb.Values, returning string) sqldb.RowScanner { |
| 96 | + if len(values) == 0 { |
| 97 | + return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) |
| 98 | + } |
| 99 | + conn := Conn(ctx) |
| 100 | + |
| 101 | + var query strings.Builder |
| 102 | + names, vals := values.Sorted() |
| 103 | + writeInsertQuery(&query, table, names, conn) |
| 104 | + query.WriteString(" RETURNING ") |
| 105 | + query.WriteString(returning) |
| 106 | + return conn.QueryRow(query.String(), vals...) // TODO wrap error with query |
| 107 | +} |
| 108 | + |
| 109 | +// InsertStruct inserts a new row into table using the connection's |
| 110 | +// StructFieldMapper to map struct fields to column names. |
| 111 | +// Optional ColumnFilter can be passed to ignore mapped columns. |
| 112 | +func InsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { |
| 113 | + conn := Conn(ctx) |
| 114 | + columns, vals, err := insertStructValues(table, rowStruct, conn.StructFieldMapper(), ignoreColumns) |
| 115 | + if err != nil { |
| 116 | + return err |
| 117 | + } |
| 118 | + |
| 119 | + var query strings.Builder |
| 120 | + writeInsertQuery(&query, table, columns, conn) |
| 121 | + |
| 122 | + err = conn.Exec(query.String(), vals...) |
| 123 | + if err != nil { |
| 124 | + return wrapErrorWithQuery(err, query.String(), vals, conn) |
| 125 | + } |
| 126 | + return nil |
| 127 | +} |
| 128 | + |
| 129 | +// InsertUniqueStruct inserts a new row into table using the connection's |
| 130 | +// StructFieldMapper to map struct fields to column names. |
| 131 | +// Optional ColumnFilter can be passed to ignore mapped columns. |
| 132 | +// Does nothing if the onConflict statement applies |
| 133 | +// and returns if a row was inserted. |
| 134 | +func InsertUniqueStruct(ctx context.Context, table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { |
| 135 | + conn := Conn(ctx) |
| 136 | + columns, vals, err := insertStructValues(table, rowStruct, conn.StructFieldMapper(), ignoreColumns) |
| 137 | + if err != nil { |
| 138 | + return false, err |
| 139 | + } |
| 140 | + |
| 141 | + if strings.HasPrefix(onConflict, "(") && strings.HasSuffix(onConflict, ")") { |
| 142 | + onConflict = onConflict[1 : len(onConflict)-1] |
| 143 | + } |
| 144 | + |
| 145 | + var query strings.Builder |
| 146 | + writeInsertQuery(&query, table, columns, conn) |
| 147 | + fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) |
| 148 | + |
| 149 | + err = conn.QueryRow(query.String(), vals...).Scan(&inserted) |
| 150 | + err = sqldb.ReplaceErrNoRows(err, nil) |
| 151 | + if err != nil { |
| 152 | + return false, wrapErrorWithQuery(err, query.String(), vals, conn) |
| 153 | + } |
| 154 | + return inserted, err |
| 155 | +} |
| 156 | + |
| 157 | +// InsertStructs inserts a slice or array of structs |
| 158 | +// as new rows into table using the connection's |
| 159 | +// StructFieldMapper to map struct fields to column names. |
| 160 | +// Optional ColumnFilter can be passed to ignore mapped columns. |
| 161 | +// |
| 162 | +// TODO optimized version with single query if possible |
| 163 | +// split into multiple queries depending or maxArgs for query |
| 164 | +func InsertStructs(ctx context.Context, table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error { |
| 165 | + v := reflect.ValueOf(rowStructs) |
| 166 | + if k := v.Type().Kind(); k != reflect.Slice && k != reflect.Array { |
| 167 | + return fmt.Errorf("InsertStructs expects a slice or array as rowStructs, got %T", rowStructs) |
| 168 | + } |
| 169 | + numRows := v.Len() |
| 170 | + return Transaction(ctx, func(ctx context.Context) error { |
| 171 | + for i := 0; i < numRows; i++ { |
| 172 | + err := InsertStruct(ctx, table, v.Index(i).Interface(), ignoreColumns...) |
| 173 | + if err != nil { |
| 174 | + return err |
| 175 | + } |
| 176 | + } |
| 177 | + return nil |
| 178 | + }) |
| 179 | +} |
0 commit comments