-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloader.go
128 lines (105 loc) · 4.04 KB
/
loader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package patcher
import (
"errors"
"reflect"
"slices"
"strings"
)
var (
	// ErrInvalidType is returned when the provided type is not a pointer to a struct.
	ErrInvalidType = errors.New("invalid type: must be pointer to struct")
)
// LoadDiff merges the fields of newT into old, updating old in place.
//
// By default only the non-zero fields of newT are copied, so a field cannot be reset to zero, the
// empty string, etc. Enable the includeZeroValues option to copy zero values as well, or the
// includeNilValues option to copy nil pointers. See the PatchOpt options for further configuration.
//
// This is useful when a patch has to be applied to an existing object and the fully updated object
// is needed afterwards. ErrInvalidType is returned when T is not a struct type.
func LoadDiff[T any](old, newT *T, opts ...PatchOpt) error {
	patch := newPatchDefaults(opts...)
	return patch.loadDiff(old, newT)
}
// loadDiff merges the fields of newT into old. Both arguments must be pointers to structs of the
// same type; ErrInvalidType is returned when either one is not a pointer to a struct.
//
// By default only non-zero values are copied, meaning you cannot set any field to zero, the empty
// string, etc. This is configurable by setting the includeZeroValues option (non-pointer fields) or
// the includeNilValues option (pointer fields).
//
// Struct-typed fields and embedded (anonymous) structs are merged recursively, field by field, as
// long as their type declares at least one exported field. Struct types whose fields are all
// unexported (time.Time, for example) cannot be walked through reflection. They are therefore
// treated as single values and copied whole, like any other field. Embedded types that are not
// mergeable structs (interfaces, named scalars, opaque structs) are handled the same way.
//
// A field carrying the skip tag option is never modified, including struct-typed and embedded
// fields. All other leaf fields are additionally filtered through the configured ignore list and
// ignore function before being copied.
func (s *SQLPatch) loadDiff(old, newT any) error {
	if !isPointerToStruct(old) || !isPointerToStruct(newT) {
		return ErrInvalidType
	}

	oElem := reflect.ValueOf(old).Elem()
	nElem := reflect.ValueOf(newT).Elem()
	oType := oElem.Type()

	for i := 0; i < oElem.NumField(); i++ {
		oField := oElem.Field(i)
		nField := nElem.Field(i)
		field := oType.Field(i)

		// Include only exported fields; reflection cannot set unexported ones.
		if !oField.CanSet() || !nField.CanSet() {
			continue
		}

		// An explicit skip tag always wins, so it is honoured before any recursion.
		if s.checkSkipTag(&field) {
			continue
		}

		if field.Anonymous && oField.Kind() == reflect.Ptr {
			// Embedded pointer: merge the pointees when both are set and walkable,
			// otherwise adopt the new pointer if one was provided.
			if !oField.IsNil() && !nField.IsNil() && isMergeableStruct(oField.Type().Elem()) {
				if err := s.loadDiff(oField.Interface(), nField.Interface()); err != nil {
					return err
				}
			} else if !nField.IsNil() {
				oField.Set(nField)
			}
			continue
		}

		// Embedded and named struct fields with exported fields are merged field by field.
		// Anything else falls through to plain value handling instead of failing or being
		// silently dropped.
		if isMergeableStruct(oField.Type()) {
			if err := s.loadDiff(oField.Addr().Interface(), nField.Addr().Interface()); err != nil {
				return err
			}
			continue
		}

		// See if the field is excluded by the ignore list or ignore function.
		if s.ignoredFieldsCheck(&field) {
			continue
		}

		patcherOptsTag := field.Tag.Get(TagOptsName)

		// New values take priority over old ones when they are provided, based on the configuration.
		if nField.Kind() == reflect.Ptr {
			if !nField.IsNil() || s.shouldIncludeNil(patcherOptsTag) {
				oField.Set(nField)
			}
			continue
		}
		if !nField.IsZero() || s.shouldIncludeZero(patcherOptsTag) {
			oField.Set(nField)
		}
	}

	return nil
}

// isMergeableStruct reports whether t is a struct type with at least one exported field, i.e. a
// struct that loadDiff can usefully walk field by field through reflection.
func isMergeableStruct(t reflect.Type) bool {
	if t.Kind() != reflect.Struct {
		return false
	}
	for i := 0; i < t.NumField(); i++ {
		if t.Field(i).IsExported() {
			return true
		}
	}
	return false
}
// checkSkipField reports whether the field should be left untouched. The patcher skip tag is
// consulted first and takes precedence over the ignore list and the ignore function.
func (s *SQLPatch) checkSkipField(field *reflect.StructField) bool {
	return s.checkSkipTag(field) || s.ignoredFieldsCheck(field)
}
// checkSkipTag reports whether the field's patcher options tag lists the skip option among its
// separator-delimited entries. A field without a patcher options tag is never skipped here.
func (s *SQLPatch) checkSkipTag(field *reflect.StructField) bool {
	opts, found := field.Tag.Lookup(TagOptsName)
	if !found {
		return false
	}

	for _, opt := range strings.Split(opts, TagOptSeparator) {
		if opt == TagOptSkip {
			return true
		}
	}
	return false
}
// ignoredFieldsCheck reports whether the field is excluded either by the ignore list (matched on
// the field name) or by the configured ignore function. The list is consulted first.
func (s *SQLPatch) ignoredFieldsCheck(field *reflect.StructField) bool {
	if s.checkIgnoredFields(field.Name) {
		return true
	}
	return s.checkIgnoreFunc(field)
}
// checkIgnoreFunc reports whether the user-supplied ignore function, when one is configured,
// asks for the field to be skipped. Without a function nothing is ignored.
func (s *SQLPatch) checkIgnoreFunc(field *reflect.StructField) bool {
	if s.ignoreFieldsFunc == nil {
		return false
	}
	return s.ignoreFieldsFunc(field)
}
// checkIgnoredFields reports whether the field name appears in the configured ignore list.
// An empty or nil list ignores nothing.
func (s *SQLPatch) checkIgnoredFields(field string) bool {
	return slices.Contains(s.ignoreFields, field)
}