Skip to content

Commit

Permalink
Merge pull request #807 from alnr/cockroach-uuid
Browse files Browse the repository at this point in the history
feat: use gen_random_uuid() for CockroachDB INSERTs
  • Loading branch information
sio4 authored Jan 24, 2023
2 parents fbb24cf + cb7479c commit ec9229d
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 50 deletions.
56 changes: 43 additions & 13 deletions dialect_cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pop

import (
"bytes"
"errors"
"fmt"
"io"
"net/url"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/gobuffalo/pop/v6/columns"
"github.com/gobuffalo/pop/v6/internal/defaults"
"github.com/gobuffalo/pop/v6/logging"
"github.com/gofrs/uuid"
_ "github.com/jackc/pgx/v4/stdlib" // Import PostgreSQL driver
"github.com/jmoiron/sqlx"
)
Expand Down Expand Up @@ -77,27 +79,55 @@ func (p *cockroach) Create(c *Connection, model *Model, cols columns.Columns) er
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) RETURNING %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning %s", p.Quote(model.TableName()), model.IDField())
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", p.Quote(model.TableName()), model.IDField())
}
txlog(logging.SQL, c, query, model.Value)
stmt, err := c.Store.PrepareNamed(query)
rows, err := c.Store.NamedQueryContext(model.ctx, query, model.Value)
if err != nil {
return err
return fmt.Errorf("named insert: %w", err)
}
id := map[string]interface{}{}
err = stmt.QueryRow(model.Value).MapScan(id)
if err != nil {
if closeErr := stmt.Close(); closeErr != nil {
return fmt.Errorf("failed to close prepared statement: %s: %w", closeErr, err)
defer rows.Close()
if !rows.Next() {
return errors.New("named insert: no rows")
}
var id interface{}
if err := rows.Scan(&id); err != nil {
return fmt.Errorf("named insert: scan: %w", err)
}
model.setID(id)
return nil

case "UUID":
var query string
if model.ID() == emptyUUID {
cols.Remove(model.IDField())
w := cols.Writeable()
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (gen_random_uuid(), %s) RETURNING %s", p.Quote(model.TableName()), model.IDField(), w.QuotedString(p), w.SymbolizedString(), model.IDField())
} else {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (gen_random_uuid()) RETURNING %s", p.Quote(model.TableName()), model.IDField(), model.IDField())
}
return err
} else {
w := cols.Writeable()
w.Add(model.IDField())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) RETURNING %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
}
txlog(logging.SQL, c, query, model.Value)
rows, err := c.Store.NamedQueryContext(model.ctx, query, model.Value)
if err != nil {
return fmt.Errorf("named insert: %w", err)
}
defer rows.Close()
if !rows.Next() {
return errors.New("named insert: no rows")
}
model.setID(id[model.IDField()])
if err := stmt.Close(); err != nil {
return fmt.Errorf("failed to close statement: %w", err)
var id uuid.UUID
if err := rows.Scan(&id); err != nil {
return fmt.Errorf("named insert: scan: %w", err)
}
model.setID(id)
return nil
}
return genericCreate(c, model, cols, p)
Expand Down
25 changes: 7 additions & 18 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func genericCreate(c *Connection, model *Model, cols columns.Columns, quoter quo
w := cols.Writeable()
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoter.Quote(model.TableName()), w.QuotedString(quoter), w.SymbolizedString())
txlog(logging.SQL, c, query, model.Value)
res, err := c.Store.NamedExec(query, model.Value)
res, err := c.Store.NamedExecContext(model.ctx, query, model.Value)
if err != nil {
return err
}
Expand Down Expand Up @@ -82,19 +82,8 @@ func genericCreate(c *Connection, model *Model, cols columns.Columns, quoter quo
w.Add(model.IDField())
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoter.Quote(model.TableName()), w.QuotedString(quoter), w.SymbolizedString())
txlog(logging.SQL, c, query, model.Value)
stmt, err := c.Store.PrepareNamed(query)
if err != nil {
return err
}
_, err = stmt.ExecContext(model.ctx, model.Value)
if err != nil {
if closeErr := stmt.Close(); closeErr != nil {
return fmt.Errorf("failed to close prepared statement: %s: %w", closeErr, err)
}
return err
}
if err := stmt.Close(); err != nil {
return fmt.Errorf("failed to close statement: %w", err)
if _, err := c.Store.NamedExecContext(model.ctx, query, model.Value); err != nil {
return fmt.Errorf("named insert: %w", err)
}
return nil
}
Expand All @@ -104,7 +93,7 @@ func genericCreate(c *Connection, model *Model, cols columns.Columns, quoter quo
func genericUpdate(c *Connection, model *Model, cols columns.Columns, quoter quotable) error {
stmt := fmt.Sprintf("UPDATE %s AS %s SET %s WHERE %s", quoter.Quote(model.TableName()), model.Alias(), cols.Writeable().QuotedUpdateString(quoter), model.WhereNamedID())
txlog(logging.SQL, c, stmt, model.ID())
_, err := c.Store.NamedExec(stmt, model.Value)
_, err := c.Store.NamedExecContext(model.ctx, stmt, model.Value)
if err != nil {
return err
}
Expand Down Expand Up @@ -154,14 +143,14 @@ func genericDelete(c *Connection, model *Model, query Query) error {

func genericExec(c *Connection, stmt string, args ...interface{}) (sql.Result, error) {
txlog(logging.SQL, c, stmt, args...)
res, err := c.Store.Exec(stmt, args...)
res, err := c.Store.ExecContext(c.Context(), stmt, args...)
return res, err
}

func genericSelectOne(c *Connection, model *Model, query Query) error {
sqlQuery, args := query.ToSQL(model)
txlog(logging.SQL, query.Connection, sqlQuery, args...)
err := c.Store.Get(model.Value, sqlQuery, args...)
err := c.Store.GetContext(model.ctx, model.Value, sqlQuery, args...)
if err != nil {
return err
}
Expand All @@ -171,7 +160,7 @@ func genericSelectOne(c *Connection, model *Model, query Query) error {
func genericSelectMany(c *Connection, models *Model, query Query) error {
sqlQuery, args := query.ToSQL(models)
txlog(logging.SQL, query.Connection, sqlQuery, args...)
err := c.Store.Select(models.Value, sqlQuery, args...)
err := c.Store.SelectContext(models.ctx, models.Value, sqlQuery, args...)
if err != nil {
return err
}
Expand Down
26 changes: 12 additions & 14 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pop

import (
"errors"
"fmt"
"io"
"net/url"
Expand Down Expand Up @@ -61,27 +62,24 @@ func (p *postgresql) Create(c *Connection, model *Model, cols columns.Columns) e
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) RETURNING %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning %s", p.Quote(model.TableName()), model.IDField())
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES RETURNING %s", p.Quote(model.TableName()), model.IDField())
}
txlog(logging.SQL, c, query, model.Value)
stmt, err := c.Store.PrepareNamed(query)
rows, err := c.Store.NamedQueryContext(model.ctx, query, model.Value)
if err != nil {
return err
return fmt.Errorf("named insert: %w", err)
}
id := map[string]interface{}{}
err = stmt.QueryRow(model.Value).MapScan(id)
if err != nil {
if closeErr := stmt.Close(); closeErr != nil {
return fmt.Errorf("failed to close prepared statement: %s: %w", closeErr, err)
}
return err
defer rows.Close()
if !rows.Next() {
return errors.New("named insert: no rows")
}
model.setID(id[model.IDField()])
if closeErr := stmt.Close(); closeErr != nil {
return fmt.Errorf("failed to close statement: %w", closeErr)
var id interface{}
if err := rows.Scan(&id); err != nil {
return fmt.Errorf("named insert: scan: %w", err)
}
model.setID(id)
return nil
}
return genericCreate(c, model, cols, p)
Expand Down
2 changes: 1 addition & 1 deletion dialect_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (m *sqlite) Create(c *Connection, model *Model, cols columns.Columns) error
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", m.Quote(model.TableName()))
}
txlog(logging.SQL, c, query, model.Value)
res, err := c.Store.NamedExec(query, model.Value)
res, err := c.Store.NamedExecContext(model.ctx, query, model.Value)
if err != nil {
return err
}
Expand Down
7 changes: 6 additions & 1 deletion executors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1405,9 +1405,11 @@ func Test_Create_UUID(t *testing.T) {

count, _ := tx.Count(&Song{})
song := Song{Title: "Automatic Buffalo"}
r.Equal(uuid.Nil, song.ID)
err := tx.Create(&song)
r.NoError(err)
r.NotZero(song.ID)
r.NotEqual(uuid.Nil, song.ID)

ctx, _ := tx.Count(&Song{})
r.Equal(count+1, ctx)
Expand Down Expand Up @@ -1520,10 +1522,13 @@ func Test_UpdateQuery_NoUpdatedAt(t *testing.T) {
}
transaction(func(tx *Connection) {
r := require.New(t)
existing, err := PDB.Count(&NonStandardID{}) // from previous test runs
r.NoError(err)
r.GreaterOrEqual(existing, 0)
r.NoError(PDB.Create(&NonStandardID{OutfacingID: "must-change"}))
count, err := PDB.Where("true").UpdateQuery(&NonStandardID{OutfacingID: "has-changed"}, "id")
r.NoError(err)
r.Equal(int64(1), count)
r.Equal(int64(existing+1), count)
entity := NonStandardID{}
r.NoError(PDB.First(&entity))
r.Equal("has-changed", entity.OutfacingID)
Expand Down
3 changes: 3 additions & 0 deletions model_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ func Test_ModelContext(t *testing.T) {
t.Run("cache_busting", func(t *testing.T) {
r := require.New(t)

r.NoError(PDB.WithContext(context.WithValue(context.Background(), "prefix", "a")).Destroy(&ContextTable{ID: "expectedA"}))
r.NoError(PDB.WithContext(context.WithValue(context.Background(), "prefix", "b")).Destroy(&ContextTable{ID: "expectedB"}))

var expectedA, expectedB ContextTable
expectedA.ID = "expectedA"
expectedB.ID = "expectedB"
Expand Down
6 changes: 3 additions & 3 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,9 @@ func Test_ToSQL_RawQuery(t *testing.T) {
}

func Test_RawQuery_Empty(t *testing.T) {
Debug = true
defer func() { Debug = false }()

if PDB == nil {
t.Skip("skipping integration tests")
}
t.Run("EmptyQuery", func(t *testing.T) {
r := require.New(t)
transaction(func(tx *Connection) {
Expand Down
2 changes: 2 additions & 0 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type store interface {
Select(interface{}, string, ...interface{}) error
Get(interface{}, string, ...interface{}) error
NamedExec(string, interface{}) (sql.Result, error)
NamedQuery(query string, arg interface{}) (*sqlx.Rows, error)
Exec(string, ...interface{}) (sql.Result, error)
PrepareNamed(string) (*sqlx.NamedStmt, error)
Transaction() (*Tx, error)
Expand All @@ -24,6 +25,7 @@ type store interface {
SelectContext(context.Context, interface{}, string, ...interface{}) error
GetContext(context.Context, interface{}, string, ...interface{}) error
NamedExecContext(context.Context, string, interface{}) (sql.Result, error)
NamedQueryContext(ctx context.Context, query string, arg interface{}) (*sqlx.Rows, error)
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error)
TransactionContext(context.Context) (*Tx, error)
Expand Down
5 changes: 5 additions & 0 deletions tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,8 @@ func (tx *Tx) Transaction() (*Tx, error) {
func (tx *Tx) Close() error {
return nil
}

// Workaround for https://github.com/jmoiron/sqlx/issues/447
func (tx *Tx) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*sqlx.Rows, error) {
return sqlx.NamedQueryContext(ctx, tx, query, arg)
}

0 comments on commit ec9229d

Please sign in to comment.