diff --git a/joiner.go b/joiner.go index e492f80..5e3949c 100644 --- a/joiner.go +++ b/joiner.go @@ -19,3 +19,12 @@ func appendJoin(join Joiner, builder *strings.Builder, args *[]any) { builder.WriteString("\n") *args = append(*args, jArgs...) } + +type joinStringOption struct { + join string + args []any +} + +func (j *joinStringOption) Join() (string, []any) { + return j.join, j.args +} diff --git a/patch_opts.go b/patch_opts.go index 695d4df..84c30e5 100644 --- a/patch_opts.go +++ b/patch_opts.go @@ -34,6 +34,20 @@ func WithWhere(where Wherer) PatchOpt { } } +// WithWhereStr takes a string and args to set the where clause to use in the SQL statement. This is useful when you +// want to use a simple where clause. +// +// Note. The where string should not contain the "WHERE" keyword. We recommend using the WhereTyper interface if you +// want to specify the WHERE type or do a more complex WHERE clause. +func WithWhereStr(where string, args ...any) PatchOpt { + return func(s *SQLPatch) { + appendWhere(&whereStringOption{ + where: where, + args: args, + }, s.whereSql, &s.whereArgs) + } +} + // WithJoin sets the join clause to use in the SQL statement func WithJoin(join Joiner) PatchOpt { return func(s *SQLPatch) { @@ -41,6 +55,20 @@ func WithJoin(join Joiner) PatchOpt { } } +// WithJoinStr takes a string and args to set the join clause to use in the SQL statement. This is useful when you +// want to use a simple join clause. +// +// Note. The join string should not contain the "JOIN" keyword. We recommend using the Joiner interface if you +// want to specify the JOIN type or do a more complex JOIN clause. +func WithJoinStr(join string, args ...any) PatchOpt { + return func(s *SQLPatch) { + appendJoin(&joinStringOption{ + join: join, + args: args, + }, s.joinSql, &s.joinArgs) + } +} + // WithDB sets the database connection to use func WithDB(db *sql.DB) PatchOpt { return func(s *SQLPatch) { diff --git a/sql_test.go b/sql_test.go index 9f35157..72b6f34 100644 --- a/sql_test.go +++ b/sql_test.go @@ -56,6 +56,46 @@ func (s *newSQLPatchSuite) TestNewSQLPatch_Success_MultiFilter() { s.Equal([]any{int64(1), "test"}, patch.args) } +func (s *newSQLPatchSuite) TestNewSQLPatch_WhereString() { + type testObj struct { + Id *int `db:"id_tag"` + Name *string `db:"name_tag"` + } + + obj := testObj{ + Id: ptr(1), + Name: ptr("test"), + } + + patch := NewSQLPatch(obj, WithWhereStr("age = ?", 18)) + + s.Equal([]string{"id_tag = ?", "name_tag = ?"}, patch.fields) + s.Equal([]any{int64(1), "test"}, patch.args) + + s.Equal("AND age = ?\n", patch.whereSql.String()) + s.Equal([]any{18}, patch.whereArgs) +} + +func (s *newSQLPatchSuite) TestNewSQLPatch_JoinString() { + type testObj struct { + Id *int `db:"id_tag"` + Name *string `db:"name_tag"` + } + + obj := testObj{ + Id: ptr(1), + Name: ptr("test"), + } + + patch := NewSQLPatch(obj, WithJoinStr("JOIN table2 ON table1.id = table2.id")) + + s.Equal([]string{"id_tag = ?", "name_tag = ?"}, patch.fields) + s.Equal([]any{int64(1), "test"}, patch.args) + + s.Equal("JOIN table2 ON table1.id = table2.id\n", patch.joinSql.String()) + s.Empty(patch.joinArgs) +} + func (s *newSQLPatchSuite) TestNewSQLPatch_Fields_Args_Getters() { type testObj struct { Id *int `db:"id_tag"` @@ -581,6 +621,46 @@ func (s *generateSQLSuite) TestGenerateSQL_Success() { mw.AssertExpectations(s.T()) } +func (s *generateSQLSuite) TestGenerateSQL_Success_WhereString() { + type testObj struct { + Id *int `db:"id"` + Name *string `db:"name"` + } + + obj := testObj{ + Id: ptr(1), + Name: ptr("test"), + } + + sqlStr, args, err := GenerateSQL(obj, + WithTable("test_table"), + WithWhereStr("age = ?", 18), + ) + s.NoError(err) + s.Equal("UPDATE test_table\nSET id = ?, name = ?\nWHERE (1=1)\nAND (\nage = ?\n)", sqlStr) + s.Equal([]any{int64(1), "test", 18}, args) +} + +func (s *generateSQLSuite) TestGenerateSQL_Success_JoinString() { + type testObj struct { + Id *int `db:"id"` + Name *string `db:"name"` + } + + obj := testObj{ + Id: ptr(1), + Name: ptr("test"), + } + + sqlStr, args, err := GenerateSQL(obj, + WithTable("test_table"), + WithJoinStr("JOIN table2 ON table1.id = table2.id"), + ) + s.NoError(err) + s.Equal("UPDATE test_table\nSET id = ?, name = ?\nJOIN table2 ON table1.id = table2.id\n", sqlStr) + s.Equal([]any{int64(1), "test"}, args) +} + func (s *generateSQLSuite) TestGenerateSQL_Success_NoWhereArgs() { type testObj struct { Id *int `db:"id"` diff --git a/wherer.go b/wherer.go index 39ef3fb..207f14f 100644 --- a/wherer.go +++ b/wherer.go @@ -49,3 +49,12 @@ func appendWhere(where Wherer, builder *strings.Builder, args *[]any) { builder.WriteString("\n") *args = append(*args, fwArgs...) } + +type whereStringOption struct { + where string + args []any +} + +func (w *whereStringOption) Where() (string, []any) { + return w.where, w.args +}