Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
db.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
if s, err := schema.ParseWithCaseInsensitivity(value, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
Expand Down
7 changes: 7 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,12 @@ go 1.18
require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.5
github.com/stretchr/testify v1.10.0
golang.org/x/text v0.20.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
37 changes: 22 additions & 15 deletions gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ type Config struct {
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// CaseInsensitiveSchemaFields enabling case insensitivity for schema fields
CaseInsensitiveSchemaFields bool

// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
Expand Down Expand Up @@ -111,21 +113,22 @@ type DB struct {

// Session session config when create session with Session() method
type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
CaseInsensitiveSchemaFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
}

// Open initialize db session based on dialector
Expand Down Expand Up @@ -277,6 +280,10 @@ func (db *DB) Session(config *Session) *DB {
txConfig.PropagateUnscoped = true
}

if config.CaseInsensitiveSchemaFields {
txConfig.CaseInsensitiveSchemaFields = true
}

if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
Expand Down
2 changes: 1 addition & 1 deletion scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {

if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
sch, _ = schema.ParseWithCaseInsensitivity(db.Statement.Dest, db.cacheStore, db.NamingStrategy, db.CaseInsensitiveSchemaFields)
}

if len(columns) == 1 {
Expand Down
2 changes: 1 addition & 1 deletion schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {

cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}, schema.FieldsCaseInsensitive); err != nil {
schema.err = err
}

Expand Down
6 changes: 3 additions & 3 deletions schema/relationship.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
}
)

if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil {
if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer, schema.FieldsCaseInsensitive); err != nil {
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
return nil
}
Expand Down Expand Up @@ -360,8 +360,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Tag: `gorm:"-"`,
})

if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
if relation.JoinTable, err = ParseWithCaseInsensitivity(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer, schema.FieldsCaseInsensitive); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
Expand Down
60 changes: 44 additions & 16 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
FieldsCaseInsensitive bool
Relationships Relationships
CreateClauses []clause.Interface
QueryClauses []clause.Interface
Expand Down Expand Up @@ -79,9 +80,24 @@
if field, ok := schema.FieldsByDBName[name]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByDBName {
if strings.EqualFold(key, name) {
return field
}
}
}
Comment on lines +83 to +89

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[PerformanceOptimization]

The case insensitive field lookup implementation may cause performance degradation when iterating over potentially large field maps. Consider implementing a more efficient solution using lowercase keys in maps for case insensitive comparisons.

Suggested change
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByDBName {
if strings.EqualFold(key, name) {
return field
}
}
}
// Consider adding a separate map for case insensitive lookups during schema initialization
// For example: fieldsByLowerDBName map[string]*Field
// This would provide O(1) lookups instead of O(n) iteration

Committable suggestion

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

if field, ok := schema.FieldsByName[name]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByName {
if strings.EqualFold(key, name) {
return field
}
}
}

return nil
}

Expand All @@ -99,6 +115,13 @@
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
if schema.FieldsCaseInsensitive {
for key, field := range schema.FieldsByBindName {
if strings.EqualFold(key, find) {
return field
}
}
}
}
return nil
}
Expand All @@ -121,11 +144,15 @@

// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
return ParseWithCaseInsensitivity(dest, cacheStore, namer, false)
}

func ParseWithCaseInsensitivity(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, caseInsensitive, "")
}

// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool, specialTableName string) (*Schema, error) {

Check failure on line 155 in schema/schema.go

View workflow job for this annotation

GitHub Actions / lint

calculated cyclomatic complexity for function ParseWithSpecialTableName is 77, max is 10 (cyclop)
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
Expand Down Expand Up @@ -182,18 +209,19 @@
}

schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
DBNames: make([]string, 0, 10),
Fields: make([]*Field, 0, 10),
FieldsByName: make(map[string]*Field, 10),
FieldsByBindName: make(map[string]*Field, 10),
FieldsByDBName: make(map[string]*Field, 10),
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
DBNames: make([]string, 0, 10),
Fields: make([]*Field, 0, 10),
FieldsByName: make(map[string]*Field, 10),
FieldsByBindName: make(map[string]*Field, 10),
FieldsByDBName: make(map[string]*Field, 10),
FieldsCaseInsensitive: caseInsensitive,
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
Expand Down Expand Up @@ -379,7 +407,7 @@
return schema, schema.err
}

func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer, caseInsensitive bool) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()

if modelType.Kind() != reflect.Struct {
Expand All @@ -399,5 +427,5 @@
return v.(*Schema), nil
}

return Parse(dest, cacheStore, namer)
return ParseWithCaseInsensitivity(dest, cacheStore, namer, caseInsensitive)
}
60 changes: 60 additions & 0 deletions schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
Expand Down Expand Up @@ -350,3 +351,62 @@ func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
}
}

func TestLookupField(t *testing.T) {
type Product struct {
ProductID uint `gorm:"primaryKey;autoIncrement"`
Code string `gorm:"column:product_code"`
Name string
}
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
}
field := product.LookUpField("ProductID")
assert.NotNil(t, field)
field = product.LookUpField("productid")
assert.Nil(t, field)
field = product.LookUpField("product_code")
assert.NotNil(t, field)
field = product.LookUpField("PRODUCT_CODE")
assert.Nil(t, field)

// Check case insensitivity
product.FieldsCaseInsensitive = true
field = product.LookUpField("productid")
assert.NotNil(t, field)
field = product.LookUpField("PRODUCT_CODE")
assert.NotNil(t, field)
}

func TestLookupFieldByBindName(t *testing.T) {
type Product struct {
ID uint `gorm:"primaryKey;autoIncrement"`
}
type Sellable struct {
Name string
Product Product `gorm:"embedded;embeddedPrefix:product_"`
}

product, err := schema.Parse(&Sellable{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse Sellable struct with composite primary key, got error %v", err)
}
field := product.LookUpFieldByBindName([]string{"Product", "ID"}, "ID")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
assert.Nil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
assert.Nil(t, field)
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
assert.Nil(t, field)

// Check case insensitivity
product.FieldsCaseInsensitive = true
field = product.LookUpFieldByBindName([]string{"Product", "ID"}, "id")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"Product", "id"}, "id")
assert.NotNil(t, field)
field = product.LookUpFieldByBindName([]string{"product", "id"}, "id")
assert.NotNil(t, field)
}
4 changes: 2 additions & 2 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@
reflectValue = reflectValue.Elem()
}

if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
if s, err := schema.ParseWithCaseInsensitivity(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields); err == nil {

Check failure on line 411 in statement.go

View workflow job for this annotation

GitHub Actions / lint

QF1008: could remove embedded field "DB" from selector (staticcheck)
selectedColumns := map[string]bool{}
if idx == 0 {
for _, v := range args[1:] {
Expand Down Expand Up @@ -510,7 +510,7 @@
}

func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.CaseInsensitiveSchemaFields, specialTableName); err == nil && stmt.Table == "" {

Check failure on line 513 in statement.go

View workflow job for this annotation

GitHub Actions / lint

QF1008: could remove embedded field "DB" from selector (staticcheck)
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]
Expand Down
Loading