Skip to content

Commit 5c7eae8

Browse files
committed
rebase & support sql.Scanner iface
1 parent 220ddf5 commit 5c7eae8

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

internal/value/cast.go

+20-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
package value
22

3+
import (
4+
"database/sql"
5+
"database/sql/driver"
6+
7+
"github.com/google/uuid"
8+
)
9+
310
func CastTo(v Value, dst interface{}) error {
411
if dst == nil {
512
return errNilDestination
@@ -10,13 +17,20 @@ func CastTo(v Value, dst interface{}) error {
1017
return nil
1118
}
1219

13-
if scanner, has := dst.(Scanner); has {
14-
return scanner.UnmarshalYDBValue(v)
20+
if _, ok := dst.(*uuid.UUID); ok {
21+
return v.castTo(dst)
1522
}
1623

17-
return v.castTo(dst)
18-
}
24+
if scanner, has := dst.(sql.Scanner); has {
25+
dv := new(driver.Value)
1926

20-
type Scanner interface {
21-
UnmarshalYDBValue(value Value) error
27+
err := v.castTo(dv)
28+
if err != nil {
29+
return err
30+
}
31+
32+
return scanner.Scan(*dv)
33+
}
34+
35+
return v.castTo(dst)
2236
}

internal/value/cast_test.go

+13-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package value
22

33
import (
44
"database/sql/driver"
5+
"errors"
56
"reflect"
67
"testing"
78
"time"
@@ -32,12 +33,17 @@ func loadLocation(t *testing.T, name string) *time.Location {
3233
return loc
3334
}
3435

35-
type testStringValueScanner struct {
36-
field string
37-
}
36+
type testStringSQLScanner string
37+
38+
func (s *testStringSQLScanner) Scan(value any) error {
39+
ts, ok := value.(string)
40+
if !ok {
41+
return errors.New("can't cast from " + reflect.TypeOf(value).String() + " to string")
42+
}
43+
44+
*s = testStringSQLScanner(ts)
3845

39-
func (s *testStringValueScanner) UnmarshalYDBValue(v Value) error {
40-
return CastTo(v, &s.field)
46+
return nil
4147
}
4248

4349
func TestCastTo(t *testing.T) {
@@ -439,8 +445,8 @@ func TestCastTo(t *testing.T) {
439445
{
440446
name: xtest.CurrentFileLine(),
441447
value: TextValue("text-string"),
442-
dst: ptr[testStringValueScanner](),
443-
exp: testStringValueScanner{field: "text-string"},
448+
dst: ptr[testStringSQLScanner](),
449+
exp: testStringSQLScanner("text-string"),
444450
err: nil,
445451
},
446452
}

internal/value/value.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,7 @@ func (v *listValue) castTo(dst any) error {
13031303
inner.Set(newSlice)
13041304

13051305
for i, item := range v.ListItems() {
1306-
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
1306+
if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil {
13071307
return xerrors.WithStackTrace(fmt.Errorf(
13081308
"%w '%s(%+v)' to '%T' destination",
13091309
ErrCannotCast, v.Type().Yql(), v, dstValue,
@@ -1437,7 +1437,7 @@ func (v *setValue) castTo(dst any) error {
14371437
inner.Set(newSlice)
14381438

14391439
for i, item := range v.items {
1440-
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
1440+
if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil {
14411441
return xerrors.WithStackTrace(fmt.Errorf(
14421442
"%w '%s(%+v)' to '%T' destination",
14431443
ErrCannotCast, v.Type().Yql(), v, dstValue,
@@ -1545,7 +1545,7 @@ func (v *optionalValue) castTo(dst any) error {
15451545
return nil
15461546
}
15471547

1548-
if err := v.value.castTo(ptr.Interface()); err != nil {
1548+
if err := CastTo(v.value, (ptr.Interface())); err != nil {
15491549
return xerrors.WithStackTrace(err)
15501550
}
15511551

@@ -1560,7 +1560,7 @@ func (v *optionalValue) castTo(dst any) error {
15601560

15611561
inner.Set(reflect.New(inner.Type().Elem()))
15621562

1563-
if err := v.value.castTo(inner.Interface()); err != nil {
1563+
if err := CastTo(v.value, inner.Interface()); err != nil {
15641564
return xerrors.WithStackTrace(err)
15651565
}
15661566

@@ -1641,7 +1641,7 @@ func (v *structValue) castTo(dst any) error {
16411641
}
16421642

16431643
for i, field := range v.fields {
1644-
if err := field.V.castTo(inner.Field(i).Addr().Interface()); err != nil {
1644+
if err := CastTo(field.V, inner.Field(i).Addr().Interface()); err != nil {
16451645
return xerrors.WithStackTrace(fmt.Errorf(
16461646
"scan error on struct field name '%s': %w",
16471647
field.Name, err,
@@ -1768,7 +1768,7 @@ func (v *tupleValue) TupleItems() []Value {
17681768

17691769
func (v *tupleValue) castTo(dst any) error {
17701770
if len(v.items) == 1 {
1771-
return v.items[0].castTo(dst)
1771+
return CastTo(v.items[0], dst)
17721772
}
17731773

17741774
switch dstValue := dst.(type) {

0 commit comments

Comments
 (0)