diff --git a/CHANGELOG.md b/CHANGELOG.md index 994697977..7a8229627 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Added support of custom types to row.ScanStruct using sql.Scanner interface + ## v3.104.5 * Added query client session pool metrics: create_in_progress, in_use, waiters_queue * Added pool item closing for not-alived item diff --git a/internal/value/cast.go b/internal/value/cast.go index 7106b62fd..6d382a151 100644 --- a/internal/value/cast.go +++ b/internal/value/cast.go @@ -1,5 +1,12 @@ package value +import ( + "database/sql" + "database/sql/driver" + + "github.com/google/uuid" +) + func CastTo(v Value, dst interface{}) error { if dst == nil { return errNilDestination @@ -10,5 +17,20 @@ func CastTo(v Value, dst interface{}) error { return nil } + if _, ok := dst.(*uuid.UUID); ok { + return v.castTo(dst) + } + + if scanner, has := dst.(sql.Scanner); has { + dv := new(driver.Value) + + err := v.castTo(dv) + if err != nil { + return err + } + + return scanner.Scan(*dv) + } + return v.castTo(dst) } diff --git a/internal/value/cast_test.go b/internal/value/cast_test.go index 863587f59..00b519f3c 100644 --- a/internal/value/cast_test.go +++ b/internal/value/cast_test.go @@ -2,6 +2,7 @@ package value import ( "database/sql/driver" + "errors" "reflect" "testing" "time" @@ -32,6 +33,19 @@ func loadLocation(t *testing.T, name string) *time.Location { return loc } +type testStringSQLScanner string + +func (s *testStringSQLScanner) Scan(value any) error { + ts, ok := value.(string) + if !ok { + return errors.New("can't cast from " + reflect.TypeOf(value).String() + " to string") + } + + *s = testStringSQLScanner(ts) + + return nil +} + func TestCastTo(t *testing.T) { testsCases := []struct { name string @@ -428,6 +442,13 @@ func TestCastTo(t *testing.T) { exp: DateValueFromTime(time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)), err: nil, }, + { + name: xtest.CurrentFileLine(), + value: TextValue("text-string"), + dst: ptr[testStringSQLScanner](), + exp: testStringSQLScanner("text-string"), + err: nil, + }, } for _, tt := range testsCases { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/value/value.go b/internal/value/value.go index 03d0adab8..870f81025 100644 --- a/internal/value/value.go +++ b/internal/value/value.go @@ -1515,7 +1515,7 @@ func (v *listValue) castTo(dst any) error { inner.Set(newSlice) for i, item := range v.ListItems() { - if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil { + if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "%w '%s(%+v)' to '%T' destination", ErrCannotCast, v.Type().Yql(), v, dstValue, @@ -1649,7 +1649,7 @@ func (v *setValue) castTo(dst any) error { inner.Set(newSlice) for i, item := range v.items { - if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil { + if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "%w '%s(%+v)' to '%T' destination", ErrCannotCast, v.Type().Yql(), v, dstValue, @@ -1757,7 +1757,7 @@ func (v *optionalValue) castTo(dst any) error { return nil } - if err := v.value.castTo(ptr.Interface()); err != nil { + if err := CastTo(v.value, (ptr.Interface())); err != nil { return xerrors.WithStackTrace(err) } @@ -1772,7 +1772,7 @@ func (v *optionalValue) castTo(dst any) error { inner.Set(reflect.New(inner.Type().Elem())) - if err := v.value.castTo(inner.Interface()); err != nil { + if err := CastTo(v.value, inner.Interface()); err != nil { return xerrors.WithStackTrace(err) } @@ -1853,7 +1853,7 @@ func (v *structValue) castTo(dst any) error { } for i, field := range v.fields { - if err := field.V.castTo(inner.Field(i).Addr().Interface()); err != nil { + if err := CastTo(field.V, inner.Field(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "scan error on struct field name '%s': %w", field.Name, err, @@ -2031,7 +2031,7 @@ func (v *tupleValue) TupleItems() []Value { func (v *tupleValue) castTo(dst any) error { if len(v.items) == 1 { - return v.items[0].castTo(dst) + return CastTo(v.items[0], dst) } switch dstValue := dst.(type) {