Skip to content

Commit

Permalink
chore(simplification): Simplifying the patches (#111)
Browse files Browse the repository at this point in the history
* simplifying the patch gen

* correcting the validation for what is a valid type and not

* code simplification 2

* patch gen simple

* loader simple
  • Loading branch information
Jacobbrewer1 authored Feb 10, 2025
1 parent 0c9719f commit 0337e43
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 269 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci-code-approval.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ jobs:
uses: golangci/golangci-lint-action@v6
with:
version: latest
only-new-issues: true
args: --verbose --timeout 5m

code-approval:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ you.
* You can pass a struct that implements the `WhereTyper` interface to use `OR` in the where clause. Patcher will
default to `AND` if the `WhereTyper` interface is not implemented.
* `WithJoin(joinClause Joiner)`: Add join clauses to the SQL query.
* `includeZeroValues`: Set to true to include zero values in the diff. (Only for NewDiffSQLPatch)
* `includeNilValues`: Set to true to include nil values in the diff. (Only for NewDiffSQLPatch)
* `includeZeroValues`: Set to true to include zero values in the Patch.
* `includeNilValues`: Set to true to include nil values in the Patch.

### Basic Examples

Expand Down
66 changes: 17 additions & 49 deletions loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package patcher
import (
"errors"
"reflect"
"slices"
"strings"
)

var (
Expand Down Expand Up @@ -36,7 +34,7 @@ func (s *SQLPatch) loadDiff(old, newT any) error {
oElem := reflect.ValueOf(old).Elem()
nElem := reflect.ValueOf(newT).Elem()

for i := 0; i < oElem.NumField(); i++ {
for i := range oElem.NumField() {
oField := oElem.Field(i)
nField := nElem.Field(i)

Expand All @@ -45,22 +43,11 @@ func (s *SQLPatch) loadDiff(old, newT any) error {
continue
}

// Handle embedded structs (Anonymous fields)
if oElem.Type().Field(i).Anonymous {
// If the embedded field is a pointer, dereference it
if oField.Kind() == reflect.Ptr {
if !oField.IsNil() && !nField.IsNil() { // If both are not nil, we need to recursively call LoadDiff
if err := s.loadDiff(oField.Interface(), nField.Interface()); err != nil {
return err
}
} else if nElem.Field(i).IsValid() && !nField.IsNil() {
oField.Set(nField)
}

continue
}
oldField := oElem.Type().Field(i)

if err := s.loadDiff(oField.Addr().Interface(), nField.Addr().Interface()); err != nil {
// Handle embedded structs (Anonymous fields)
if oldField.Anonymous {
if err := s.handleEmbeddedStruct(oField, nField); err != nil {
return err
}
continue
Expand All @@ -74,8 +61,6 @@ func (s *SQLPatch) loadDiff(old, newT any) error {
continue
}

oldField := oElem.Type().Field(i)

// See if the field should be ignored.
if s.checkSkipField(&oldField) {
continue
Expand All @@ -86,43 +71,26 @@ func (s *SQLPatch) loadDiff(old, newT any) error {
// Compare the old and new fields.
//
// New fields take priority over old fields if they are provided based on the configuration.
if nElem.Field(i).Kind() != reflect.Ptr && (!nField.IsZero() || s.shouldIncludeZero(patcherOptsTag)) {
oElem.Field(i).Set(nElem.Field(i))
} else if nElem.Field(i).Kind() == reflect.Ptr && (!nField.IsNil() || s.shouldIncludeNil(patcherOptsTag)) {
oField.Set(nElem.Field(i))
if nField.Kind() != reflect.Ptr && (!nField.IsZero() || s.shouldIncludeZero(patcherOptsTag)) {
oField.Set(nField)
} else if nField.Kind() == reflect.Ptr && (!nField.IsNil() || s.shouldIncludeNil(patcherOptsTag)) {
oField.Set(nField)
}
}

return nil
}

func (s *SQLPatch) checkSkipField(field *reflect.StructField) bool {
// The ignore fields tag takes precedence over the ignore fields list
if s.checkSkipTag(field) {
return true
func (s *SQLPatch) handleEmbeddedStruct(oField, nField reflect.Value) error {
if oField.Kind() != reflect.Ptr {
return s.loadDiff(oField.Addr().Interface(), nField.Addr().Interface())
}

return s.ignoredFieldsCheck(field)
}

func (s *SQLPatch) checkSkipTag(field *reflect.StructField) bool {
val, ok := field.Tag.Lookup(TagOptsName)
if !ok {
return false
if !oField.IsNil() && !nField.IsNil() {
return s.loadDiff(oField.Interface(), nField.Interface())
} else if nField.IsValid() && !nField.IsNil() {
oField.Set(nField)
}

tags := strings.Split(val, TagOptSeparator)
return slices.Contains(tags, TagOptSkip)
}

func (s *SQLPatch) ignoredFieldsCheck(field *reflect.StructField) bool {
return s.checkIgnoredFields(field.Name) || s.checkIgnoreFunc(field)
}

func (s *SQLPatch) checkIgnoreFunc(field *reflect.StructField) bool {
return s.ignoreFieldsFunc != nil && s.ignoreFieldsFunc(field)
}

func (s *SQLPatch) checkIgnoredFields(field string) bool {
return len(s.ignoreFields) > 0 && slices.Contains(s.ignoreFields, field)
return nil
}
52 changes: 52 additions & 0 deletions patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,55 @@ func (s *SQLPatch) shouldOmitEmpty(tag string) bool {

return false
}

func (s *SQLPatch) shouldSkipField(fType *reflect.StructField, fVal reflect.Value) bool {
if !fType.IsExported() || !isValidType(fVal) || s.checkSkipField(fType) {
return true
}

patcherOptsTag := fType.Tag.Get(TagOptsName)
if fVal.Kind() == reflect.Ptr && (fVal.IsNil() && !s.shouldIncludeNil(patcherOptsTag)) {
return true
}
if fVal.Kind() != reflect.Ptr && (fVal.IsZero() && !s.shouldIncludeZero(patcherOptsTag)) {
return true
}
if patcherOptsTag != "" {
patcherOpts := strings.Split(patcherOptsTag, TagOptSeparator)
if slices.Contains(patcherOpts, TagOptSkip) {
return true
}
}
return false
}

func (s *SQLPatch) checkSkipField(field *reflect.StructField) bool {
// The ignore fields tag takes precedence over the ignore fields list
if s.checkSkipTag(field) {
return true
}

return s.ignoredFieldsCheck(field)
}

func (s *SQLPatch) checkSkipTag(field *reflect.StructField) bool {
val, ok := field.Tag.Lookup(TagOptsName)
if !ok {
return false
}

tags := strings.Split(val, TagOptSeparator)
return slices.Contains(tags, TagOptSkip)
}

func (s *SQLPatch) ignoredFieldsCheck(field *reflect.StructField) bool {
return s.checkIgnoredFields(field.Name) || s.checkIgnoreFunc(field)
}

func (s *SQLPatch) checkIgnoreFunc(field *reflect.StructField) bool {
return s.ignoreFieldsFunc != nil && s.ignoreFieldsFunc(field)
}

func (s *SQLPatch) checkIgnoredFields(field string) bool {
return len(s.ignoreFields) > 0 && slices.Contains(s.ignoreFields, field)
}
93 changes: 13 additions & 80 deletions sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"reflect"
"slices"
"strings"
)

Expand All @@ -32,16 +31,8 @@ func NewSQLPatch(resource any, opts ...PatchOpt) *SQLPatch {
// It processes the fields of the struct, applying the necessary tags and options,
// and prepares the SQL update statement components (fields and arguments).
func (s *SQLPatch) patchGen(resource any) {
// If the resource is a pointer, we need to dereference it to get the value
if reflect.TypeOf(resource).Kind() == reflect.Ptr {
resource = reflect.ValueOf(resource).Elem().Interface()
}

// Ensure that the resource is a struct
if reflect.TypeOf(resource).Kind() != reflect.Struct {
// This is intentionally a panic as this is a programming error and should be fixed by the developer
panic("resource is not a struct")
}
resource = dereferenceIfPointer(resource)
ensureStruct(resource)

rType := reflect.TypeOf(resource)
rVal := reflect.ValueOf(resource)
Expand All @@ -50,85 +41,27 @@ func (s *SQLPatch) patchGen(resource any) {
s.fields = make([]string, 0, n)
s.args = make([]any, 0, n)

for i := 0; i < n; i++ {
for i := range n {
fType := rType.Field(i)
fVal := rVal.Field(i)
tag := fType.Tag.Get(s.tagName)

// Skip unexported fields
if !fType.IsExported() {
continue
}
tag := getTag(&fType, s.tagName)
optsTag := fType.Tag.Get(TagOptsName)

tags := strings.Split(tag, TagOptSeparator)
if len(tags) > 1 {
tag = tags[0]
}

patcherOptsTag := fType.Tag.Get(TagOptsName)

// Skip fields that are to be ignored
switch {
case s.checkSkipField(&fType):
continue
case fVal.Kind() == reflect.Ptr && (fVal.IsNil() && !s.shouldIncludeNil(patcherOptsTag)):
continue
case fVal.Kind() != reflect.Ptr && (fVal.IsZero() && !s.shouldIncludeZero(patcherOptsTag)):
if s.shouldSkipField(&fType, fVal) {
continue
}

if patcherOptsTag != "" {
patcherOpts := strings.Split(patcherOptsTag, TagOptSeparator)
if slices.Contains(patcherOpts, TagOptSkip) {
var arg any = nil
if fVal.Kind() == reflect.Ptr && fVal.IsNil() {
if !s.shouldIncludeNil(optsTag) {
continue
}
} else {
arg = getValue(fVal)
}

// If no tag is set, use the field name
if tag == "" {
tag = fType.Name
}

addField := func() {
s.fields = append(s.fields, tag+" = ?")
}

if fVal.Kind() == reflect.Ptr && fVal.IsNil() && s.shouldIncludeNil(patcherOptsTag) {
s.args = append(s.args, nil)
addField()
continue
} else if fVal.Kind() == reflect.Ptr && fVal.IsNil() {
continue
}

addField()

val := fVal
if fVal.Kind() == reflect.Ptr {
val = fVal.Elem()
}

switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s.args = append(s.args, val.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
s.args = append(s.args, val.Uint())
case reflect.Float32, reflect.Float64:
s.args = append(s.args, val.Float())
case reflect.Complex64, reflect.Complex128:
s.args = append(s.args, val.Complex())
case reflect.String:
s.args = append(s.args, val.String())
case reflect.Bool:
boolArg := 0
if val.Bool() {
boolArg = 1
}
s.args = append(s.args, boolArg)
default:
// This is intentionally a panic as this is a programming error and should be fixed by the developer
panic(fmt.Sprintf("unsupported type: %s", val.Kind()))
}
s.fields = append(s.fields, tag+" = ?")
s.args = append(s.args, arg)
}
}

Expand Down
Loading

0 comments on commit 0337e43

Please sign in to comment.