diff --git a/loader.go b/loader.go index 0f2e077..55260c8 100644 --- a/loader.go +++ b/loader.go @@ -34,35 +34,38 @@ func (s *SQLPatch) loadDiff(old, newT any) error { nElem := reflect.ValueOf(newT).Elem() for i := 0; i < oElem.NumField(); i++ { + oField := oElem.Field(i) + nField := nElem.Field(i) + // Include only exported fields - if !oElem.Field(i).CanSet() || !nElem.Field(i).CanSet() { + if !oField.CanSet() || !nField.CanSet() { continue } // Handle embedded structs (Anonymous fields) if oElem.Type().Field(i).Anonymous { // If the embedded field is a pointer, dereference it - if oElem.Field(i).Kind() == reflect.Ptr { - if !oElem.Field(i).IsNil() && !nElem.Field(i).IsNil() { // If both are not nil, we need to recursively call LoadDiff - if err := s.loadDiff(oElem.Field(i).Interface(), nElem.Field(i).Interface()); err != nil { + if oField.Kind() == reflect.Ptr { + if !oField.IsNil() && !nField.IsNil() { // If both are not nil, we need to recursively call LoadDiff + if err := s.loadDiff(oField.Interface(), nField.Interface()); err != nil { return err } - } else if nElem.Field(i).IsValid() && !nElem.Field(i).IsNil() { - oElem.Field(i).Set(nElem.Field(i)) + } else if nElem.Field(i).IsValid() && !nField.IsNil() { + oField.Set(nField) } continue } - if err := s.loadDiff(oElem.Field(i).Addr().Interface(), nElem.Field(i).Addr().Interface()); err != nil { + if err := s.loadDiff(oField.Addr().Interface(), nField.Addr().Interface()); err != nil { return err } continue } // If the field is a struct, we need to recursively call LoadDiff - if oElem.Field(i).Kind() == reflect.Struct { - if err := s.loadDiff(oElem.Field(i).Addr().Interface(), nElem.Field(i).Addr().Interface()); err != nil { + if oField.Kind() == reflect.Struct { + if err := s.loadDiff(oField.Addr().Interface(), nField.Addr().Interface()); err != nil { return err } continue @@ -73,13 +76,15 @@ func (s *SQLPatch) loadDiff(old, newT any) error { continue } + patcherOptsTag := oElem.Type().Field(i).Tag.Get(TagOptsName) + // Compare the old and new fields. // // New fields take priority over old fields if they are provided based on the configuration. - if nElem.Field(i).Kind() != reflect.Ptr && (!nElem.Field(i).IsZero() || s.includeZeroValues) { - oElem.Field(i).Set(nElem.Field(i)) - } else if nElem.Field(i).Kind() == reflect.Ptr && (!nElem.Field(i).IsNil() || s.includeNilValues) { + if nElem.Field(i).Kind() != reflect.Ptr && (!nField.IsZero() || s.shouldIncludeZero(patcherOptsTag)) { oElem.Field(i).Set(nElem.Field(i)) + } else if nElem.Field(i).Kind() == reflect.Ptr && (!nField.IsNil() || s.shouldIncludeNil(patcherOptsTag)) { + oField.Set(nElem.Field(i)) } } diff --git a/loader_test.go b/loader_test.go index 3e49bc2..4cc35be 100644 --- a/loader_test.go +++ b/loader_test.go @@ -48,28 +48,92 @@ func (s *loadDiffSuite) TestLoadDiff_Success() { s.Equal(26, old.Age) } -func (s *loadDiffSuite) TestLoadDiff_Success_Pointed_Fields() { - s.patch.includeNilValues = true - +func (s *loadDiffSuite) TestLoadDiff_Success_StructOpt_IncludeNilField() { type testStruct struct { - Name *string - Age *int + Name string + Age *int `patcher:"nil"` } old := testStruct{ - Name: ptr("John"), + Name: "John", Age: ptr(25), } n := testStruct{ - Name: ptr("John Smith"), - Age: ptr(26), + Name: "John Smith", + Age: nil, + } + + err := s.patch.loadDiff(&old, &n) + s.NoError(err) + s.Equal("John Smith", old.Name) + s.Nil(old.Age) +} + +func (s *loadDiffSuite) TestLoadDiff_Success_StructOpt_IncludeZeroField() { + type testStruct struct { + Name string + Age int `patcher:"zero"` + } + + old := testStruct{ + Name: "John", + Age: 25, + } + + n := testStruct{ + Name: "John Smith", + Age: 0, } err := s.patch.loadDiff(&old, &n) s.NoError(err) - s.Equal("John Smith", *old.Name) - s.Equal(26, *old.Age) + s.Equal("John Smith", old.Name) + s.Equal(0, old.Age) +} + +func (s *loadDiffSuite) TestLoadDiff_Success_NoStructOpts() { + type testStruct struct { + Name string + Age int + } + + old := testStruct{ + Name: "John", + Age: 25, + } + + n := testStruct{ + Name: "John Smith", + Age: 0, + } + + err := s.patch.loadDiff(&old, &n) + s.NoError(err) + s.Equal("John Smith", old.Name) + s.Equal(25, old.Age) +} + +func (s *loadDiffSuite) TestLoadDiff_Success_Pointed_Fields() { + type testStruct struct { + Name string + Age int + } + + old := testStruct{ + Name: "John", + Age: 25, + } + + n := testStruct{ + Name: "John Smith", + Age: 0, + } + + err := s.patch.loadDiff(&old, &n) + s.NoError(err) + s.Equal("John Smith", old.Name) + s.Equal(25, old.Age) } func (s *loadDiffSuite) TestLoadDiff_Success_ZeroValue() { diff --git a/patch.go b/patch.go index 4196677..9dffffe 100644 --- a/patch.go +++ b/patch.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "reflect" + "slices" "strings" ) @@ -132,3 +133,33 @@ func (s *SQLPatch) validateSQLGen() error { return nil } + +func (s *SQLPatch) shouldIncludeNil(tag string) bool { + if s.includeNilValues { + return true + } + + if tag != "" { + tags := strings.Split(tag, TagOptSeparator) + if slices.Contains(tags, TagOptAllowNil) { + return true + } + } + + return false +} + +func (s *SQLPatch) shouldIncludeZero(tag string) bool { + if s.includeZeroValues { + return true + } + + if tag != "" { + tagOpts := strings.Split(tag, TagOptSeparator) + if slices.Contains(tagOpts, TagOptAllowZero) { + return true + } + } + + return false +} diff --git a/patch_opts.go b/patch_opts.go index a2a16f7..b050b81 100644 --- a/patch_opts.go +++ b/patch_opts.go @@ -9,6 +9,8 @@ const ( TagOptsName = "patcher" TagOptSeparator = "," TagOptSkip = "-" + TagOptAllowNil = "nil" + TagOptAllowZero = "zero" ) type PatchOpt func(*SQLPatch) diff --git a/sql.go b/sql.go index ac38723..67de572 100644 --- a/sql.go +++ b/sql.go @@ -53,16 +53,17 @@ func (s *SQLPatch) patchGen(resource any) { continue } + patcherOptsTag := fType.Tag.Get(TagOptsName) + // Skip fields that are to be ignored if s.checkSkipField(fType) { continue - } else if fVal.Kind() == reflect.Ptr && (fVal.IsNil() && !s.includeNilValues) { + } else if fVal.Kind() == reflect.Ptr && (fVal.IsNil() && !s.shouldIncludeNil(patcherOptsTag)) { continue - } else if fVal.Kind() != reflect.Ptr && (fVal.IsZero() && !s.includeZeroValues) { + } else if fVal.Kind() != reflect.Ptr && (fVal.IsZero() && !s.shouldIncludeZero(patcherOptsTag)) { continue } - patcherOptsTag := fType.Tag.Get(TagOptsName) if patcherOptsTag != "" { patcherOpts := strings.Split(patcherOptsTag, TagOptSeparator) if slices.Contains(patcherOpts, TagOptSkip) { @@ -79,7 +80,7 @@ func (s *SQLPatch) patchGen(resource any) { s.fields = append(s.fields, tag+" = ?") } - if fVal.Kind() == reflect.Ptr && fVal.IsNil() && s.includeNilValues { + if fVal.Kind() == reflect.Ptr && fVal.IsNil() && s.shouldIncludeNil(patcherOptsTag) { s.args = append(s.args, nil) addField() continue @@ -215,9 +216,11 @@ func NewDiffSQLPatch[T any](old, newT *T, opts ...PatchOpt) (*SQLPatch, error) { oldField := oldElem.Field(i) copyField := oldCopyElem.Field(i) - if oldField.Kind() == reflect.Ptr && (oldField.IsNil() && copyField.IsNil() && !patch.includeZeroValues) { + patcherOptsTag := oldElem.Type().Field(i).Tag.Get(TagOptsName) + + if oldField.Kind() == reflect.Ptr && (oldField.IsNil() && copyField.IsNil() && !patch.shouldIncludeNil(patcherOptsTag)) { continue - } else if oldField.Kind() != reflect.Ptr && (oldField.IsZero() && copyField.IsZero() && !patch.includeZeroValues) { + } else if oldField.Kind() != reflect.Ptr && (oldField.IsZero() && copyField.IsZero() && !patch.shouldIncludeZero(patcherOptsTag)) { continue } diff --git a/sql_test.go b/sql_test.go index 7d10e93..c2ef63b 100644 --- a/sql_test.go +++ b/sql_test.go @@ -33,6 +33,40 @@ func (s *newSQLPatchSuite) TestNewSQLPatch_Success() { s.Equal([]any{int64(1), "test"}, patch.args) } +func (s *newSQLPatchSuite) TestNewSQLPatch_Success_Struct_opt_IncludeNilFields() { + type testObj struct { + Id *int `db:"id_tag"` + Name *string `db:"name_tag" patcher:"nil"` + } + + obj := testObj{ + Id: ptr(1), + Name: nil, + } + + patch := NewSQLPatch(obj) + + s.Equal([]string{"id_tag = ?", "name_tag = ?"}, patch.fields) + s.Equal([]any{int64(1), nil}, patch.args) +} + +func (s *newSQLPatchSuite) TestNewSQLPatch_Success_Struct_opt_IncludeZeroFields() { + type testObj struct { + Id int `db:"id_tag"` + Name string `db:"name_tag" patcher:"zero"` + } + + obj := testObj{ + Id: 1, + Name: "", + } + + patch := NewSQLPatch(obj) + + s.Equal([]string{"id_tag = ?", "name_tag = ?"}, patch.fields) + s.Equal([]any{int64(1), ""}, patch.args) +} + func (s *newSQLPatchSuite) TestNewSQLPatch_Skip() { type testObj struct { Id *int `db:"id_tag" patcher:"-"` @@ -348,6 +382,56 @@ func (s *generateSQLSuite) TestGenerateSQL_Success() { mw.AssertExpectations(s.T()) } +func (s *generateSQLSuite) TestGenerateSQL_Success_Stuct_opt_IncludeNilFields() { + type testObj struct { + Id *int `db:"id"` + Name *string `db:"name" patcher:"nil"` + } + + obj := testObj{ + Id: ptr(1), + Name: nil, + } + + mw := NewMockWherer(s.T()) + mw.On("Where").Return("age = ?", []any{18}) + + sqlStr, args, err := GenerateSQL(obj, + WithTable("test_table"), + WithWhere(mw), + ) + s.NoError(err) + s.Equal("UPDATE test_table\nSET id = ?, name = ?\nWHERE (1=1)\nAND (\nage = ?\n)", sqlStr) + s.Equal([]any{int64(1), nil, 18}, args) + + mw.AssertExpectations(s.T()) +} + +func (s *generateSQLSuite) TestGenerateSQL_Success_Struct_opt_IncludeZeroFields() { + type testObj struct { + Id int `db:"id"` + Name string `db:"name" patcher:"zero"` + } + + obj := testObj{ + Id: 1, + Name: "", + } + + mw := NewMockWherer(s.T()) + mw.On("Where").Return("age = ?", []any{18}) + + sqlStr, args, err := GenerateSQL(obj, + WithTable("test_table"), + WithWhere(mw), + ) + s.NoError(err) + s.Equal("UPDATE test_table\nSET id = ?, name = ?\nWHERE (1=1)\nAND (\nage = ?\n)", sqlStr) + s.Equal([]any{int64(1), "", 18}, args) + + mw.AssertExpectations(s.T()) +} + func (s *generateSQLSuite) TestGenerateSQL_Success_multipleWhere() { type testObj struct { Id *int `db:"id"` @@ -916,6 +1000,54 @@ func (s *NewDiffSQLPatchSuite) TestNewDiffSQLPatch_Success() { s.Equal([]any{int64(2), "test2"}, patch.args) } +func (s *NewDiffSQLPatchSuite) TestNewDiffSQLPatch_Success_StructOpt_IncludeNilFields() { + type testObj struct { + Id *int `db:"id"` + Name *string `db:"name" patcher:"nil"` + } + + obj := testObj{ + Id: ptr(1), + Name: ptr("test"), + } + + obj2 := testObj{ + Id: ptr(2), + Name: nil, + } + + patch, err := NewDiffSQLPatch(&obj, &obj2) + s.NoError(err) + + s.NotNil(patch) + s.Equal([]string{"id = ?", "name = ?"}, patch.fields) + s.Equal([]any{int64(2), nil}, patch.args) +} + +func (s *NewDiffSQLPatchSuite) TestNewDiffSQLPatch_Success_StructOpt_IncludeZeroFields() { + type testObj struct { + Id int `db:"id"` + Name string `db:"name" patcher:"zero"` + } + + obj := testObj{ + Id: 1, + Name: "test", + } + + obj2 := testObj{ + Id: 2, + Name: "", + } + + patch, err := NewDiffSQLPatch(&obj, &obj2) + s.NoError(err) + + s.NotNil(patch) + s.Equal([]string{"id = ?", "name = ?"}, patch.fields) + s.Equal([]any{int64(2), ""}, patch.args) +} + func (s *NewDiffSQLPatchSuite) TestNewDiffSQLPatch_Success_singleFieldUpdated() { type testObj struct { Id *int `db:"id"`