Skip to content

Commit

Permalink
feat(opts): Allowing singluar field configuration for nillable and ze…
Browse files Browse the repository at this point in the history
…roable (#62)

* Adding tests and fixing misses

* Adding loader test
  • Loading branch information
Jacobbrewer1 authored Oct 31, 2024
1 parent 1ad2747 commit 4cf97e3
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 28 deletions.
29 changes: 17 additions & 12 deletions loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,35 +34,38 @@ func (s *SQLPatch) loadDiff(old, newT any) error {
nElem := reflect.ValueOf(newT).Elem()

for i := 0; i < oElem.NumField(); i++ {
oField := oElem.Field(i)
nField := nElem.Field(i)

// Include only exported fields
if !oElem.Field(i).CanSet() || !nElem.Field(i).CanSet() {
if !oField.CanSet() || !nField.CanSet() {
continue
}

// Handle embedded structs (Anonymous fields)
if oElem.Type().Field(i).Anonymous {
// If the embedded field is a pointer, dereference it
if oElem.Field(i).Kind() == reflect.Ptr {
if !oElem.Field(i).IsNil() && !nElem.Field(i).IsNil() { // If both are not nil, we need to recursively call LoadDiff
if err := s.loadDiff(oElem.Field(i).Interface(), nElem.Field(i).Interface()); err != nil {
if oField.Kind() == reflect.Ptr {
if !oField.IsNil() && !nField.IsNil() { // If both are not nil, we need to recursively call LoadDiff
if err := s.loadDiff(oField.Interface(), nField.Interface()); err != nil {
return err
}
} else if nElem.Field(i).IsValid() && !nElem.Field(i).IsNil() {
oElem.Field(i).Set(nElem.Field(i))
} else if nElem.Field(i).IsValid() && !nField.IsNil() {
oField.Set(nField)
}

continue
}

if err := s.loadDiff(oElem.Field(i).Addr().Interface(), nElem.Field(i).Addr().Interface()); err != nil {
if err := s.loadDiff(oField.Addr().Interface(), nField.Addr().Interface()); err != nil {
return err
}
continue
}

// If the field is a struct, we need to recursively call LoadDiff
if oElem.Field(i).Kind() == reflect.Struct {
if err := s.loadDiff(oElem.Field(i).Addr().Interface(), nElem.Field(i).Addr().Interface()); err != nil {
if oField.Kind() == reflect.Struct {
if err := s.loadDiff(oField.Addr().Interface(), nField.Addr().Interface()); err != nil {
return err
}
continue
Expand All @@ -73,13 +76,15 @@ func (s *SQLPatch) loadDiff(old, newT any) error {
continue
}

patcherOptsTag := oElem.Type().Field(i).Tag.Get(TagOptsName)

// Compare the old and new fields.
//
// New fields take priority over old fields if they are provided based on the configuration.
if nElem.Field(i).Kind() != reflect.Ptr && (!nElem.Field(i).IsZero() || s.includeZeroValues) {
oElem.Field(i).Set(nElem.Field(i))
} else if nElem.Field(i).Kind() == reflect.Ptr && (!nElem.Field(i).IsNil() || s.includeNilValues) {
if nElem.Field(i).Kind() != reflect.Ptr && (!nField.IsZero() || s.shouldIncludeZero(patcherOptsTag)) {
oElem.Field(i).Set(nElem.Field(i))
} else if nElem.Field(i).Kind() == reflect.Ptr && (!nField.IsNil() || s.shouldIncludeNil(patcherOptsTag)) {
oField.Set(nElem.Field(i))
}
}

Expand Down
84 changes: 74 additions & 10 deletions loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,92 @@ func (s *loadDiffSuite) TestLoadDiff_Success() {
s.Equal(26, old.Age)
}

func (s *loadDiffSuite) TestLoadDiff_Success_Pointed_Fields() {
s.patch.includeNilValues = true

func (s *loadDiffSuite) TestLoadDiff_Success_StructOpt_IncludeNilField() {
type testStruct struct {
Name *string
Age *int
Name string
Age *int `patcher:"nil"`
}

old := testStruct{
Name: ptr("John"),
Name: "John",
Age: ptr(25),
}

n := testStruct{
Name: ptr("John Smith"),
Age: ptr(26),
Name: "John Smith",
Age: nil,
}

err := s.patch.loadDiff(&old, &n)
s.NoError(err)
s.Equal("John Smith", old.Name)
s.Nil(old.Age)
}

func (s *loadDiffSuite) TestLoadDiff_Success_StructOpt_IncludeZeroField() {
type testStruct struct {
Name string
Age int `patcher:"zero"`
}

old := testStruct{
Name: "John",
Age: 25,
}

n := testStruct{
Name: "John Smith",
Age: 0,
}

err := s.patch.loadDiff(&old, &n)
s.NoError(err)
s.Equal("John Smith", *old.Name)
s.Equal(26, *old.Age)
s.Equal("John Smith", old.Name)
s.Equal(0, old.Age)
}

func (s *loadDiffSuite) TestLoadDiff_Success_NoStructOpts() {
type testStruct struct {
Name string
Age int
}

old := testStruct{
Name: "John",
Age: 25,
}

n := testStruct{
Name: "John Smith",
Age: 0,
}

err := s.patch.loadDiff(&old, &n)
s.NoError(err)
s.Equal("John Smith", old.Name)
s.Equal(25, old.Age)
}

func (s *loadDiffSuite) TestLoadDiff_Success_Pointed_Fields() {
type testStruct struct {
Name string
Age int
}

old := testStruct{
Name: "John",
Age: 25,
}

n := testStruct{
Name: "John Smith",
Age: 0,
}

err := s.patch.loadDiff(&old, &n)
s.NoError(err)
s.Equal("John Smith", old.Name)
s.Equal(25, old.Age)
}

func (s *loadDiffSuite) TestLoadDiff_Success_ZeroValue() {
Expand Down
31 changes: 31 additions & 0 deletions patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"reflect"
"slices"
"strings"
)

Expand Down Expand Up @@ -132,3 +133,33 @@ func (s *SQLPatch) validateSQLGen() error {

return nil
}

func (s *SQLPatch) shouldIncludeNil(tag string) bool {
if s.includeNilValues {
return true
}

if tag != "" {
tags := strings.Split(tag, TagOptSeparator)
if slices.Contains(tags, TagOptAllowNil) {
return true
}
}

return false
}

func (s *SQLPatch) shouldIncludeZero(tag string) bool {
if s.includeZeroValues {
return true
}

if tag != "" {
tagOpts := strings.Split(tag, TagOptSeparator)
if slices.Contains(tagOpts, TagOptAllowZero) {
return true
}
}

return false
}
2 changes: 2 additions & 0 deletions patch_opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ const (
TagOptsName = "patcher"
TagOptSeparator = ","
TagOptSkip = "-"
TagOptAllowNil = "nil"
TagOptAllowZero = "zero"
)

type PatchOpt func(*SQLPatch)
Expand Down
15 changes: 9 additions & 6 deletions sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,17 @@ func (s *SQLPatch) patchGen(resource any) {
continue
}

patcherOptsTag := fType.Tag.Get(TagOptsName)

// Skip fields that are to be ignored
if s.checkSkipField(fType) {
continue
} else if fVal.Kind() == reflect.Ptr && (fVal.IsNil() && !s.includeNilValues) {
} else if fVal.Kind() == reflect.Ptr && (fVal.IsNil() && !s.shouldIncludeNil(patcherOptsTag)) {
continue
} else if fVal.Kind() != reflect.Ptr && (fVal.IsZero() && !s.includeZeroValues) {
} else if fVal.Kind() != reflect.Ptr && (fVal.IsZero() && !s.shouldIncludeZero(patcherOptsTag)) {
continue
}

patcherOptsTag := fType.Tag.Get(TagOptsName)
if patcherOptsTag != "" {
patcherOpts := strings.Split(patcherOptsTag, TagOptSeparator)
if slices.Contains(patcherOpts, TagOptSkip) {
Expand All @@ -79,7 +80,7 @@ func (s *SQLPatch) patchGen(resource any) {
s.fields = append(s.fields, tag+" = ?")
}

if fVal.Kind() == reflect.Ptr && fVal.IsNil() && s.includeNilValues {
if fVal.Kind() == reflect.Ptr && fVal.IsNil() && s.shouldIncludeNil(patcherOptsTag) {
s.args = append(s.args, nil)
addField()
continue
Expand Down Expand Up @@ -215,9 +216,11 @@ func NewDiffSQLPatch[T any](old, newT *T, opts ...PatchOpt) (*SQLPatch, error) {
oldField := oldElem.Field(i)
copyField := oldCopyElem.Field(i)

if oldField.Kind() == reflect.Ptr && (oldField.IsNil() && copyField.IsNil() && !patch.includeZeroValues) {
patcherOptsTag := oldElem.Type().Field(i).Tag.Get(TagOptsName)

if oldField.Kind() == reflect.Ptr && (oldField.IsNil() && copyField.IsNil() && !patch.shouldIncludeNil(patcherOptsTag)) {
continue
} else if oldField.Kind() != reflect.Ptr && (oldField.IsZero() && copyField.IsZero() && !patch.includeZeroValues) {
} else if oldField.Kind() != reflect.Ptr && (oldField.IsZero() && copyField.IsZero() && !patch.shouldIncludeZero(patcherOptsTag)) {
continue
}

Expand Down
Loading

0 comments on commit 4cf97e3

Please sign in to comment.