Skip to content

Commit cc63007

Browse files
committed
rebase & support sql.Scanner iface
1 parent 48ccba7 commit cc63007

File tree

3 files changed

+61
-11
lines changed

3 files changed

+61
-11
lines changed

internal/value/cast.go

+22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
package value
22

3+
import (
4+
"database/sql"
5+
"database/sql/driver"
6+
7+
"github.com/google/uuid"
8+
)
9+
310
func CastTo(v Value, dst interface{}) error {
411
if dst == nil {
512
return errNilDestination
@@ -10,6 +17,21 @@ func CastTo(v Value, dst interface{}) error {
1017
return nil
1118
}
1219

20+
if _, ok := dst.(*uuid.UUID); ok {
21+
return v.castTo(dst)
22+
}
23+
24+
if scanner, has := dst.(sql.Scanner); has {
25+
dv := new(driver.Value)
26+
27+
err := v.castTo(dv)
28+
if err != nil {
29+
return err
30+
}
31+
32+
return scanner.Scan(*dv)
33+
}
34+
1335
if scanner, has := dst.(Scanner); has {
1436
return scanner.UnmarshalYDBValue(v)
1537
}

internal/value/cast_test.go

+33-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package value
22

33
import (
44
"database/sql/driver"
5+
"errors"
56
"reflect"
67
"testing"
78
"time"
@@ -32,12 +33,32 @@ func loadLocation(t *testing.T, name string) *time.Location {
3233
return loc
3334
}
3435

35-
type testStringValueScanner struct {
36-
field string
37-
}
36+
type testStringValueScanner string
3837

3938
func (s *testStringValueScanner) UnmarshalYDBValue(v Value) error {
40-
return CastTo(v, &s.field)
39+
var tmp string
40+
41+
err := CastTo(v, &tmp)
42+
if err != nil {
43+
return err
44+
}
45+
46+
*s = testStringValueScanner(tmp)
47+
48+
return nil
49+
}
50+
51+
type testStringSQLScanner string
52+
53+
func (s *testStringSQLScanner) Scan(value any) error {
54+
ts, ok := value.(string)
55+
if !ok {
56+
return errors.New("can't cast from " + reflect.TypeOf(value).String() + " to string")
57+
}
58+
59+
*s = testStringSQLScanner(ts)
60+
61+
return nil
4162
}
4263

4364
func TestCastTo(t *testing.T) {
@@ -440,7 +461,14 @@ func TestCastTo(t *testing.T) {
440461
name: xtest.CurrentFileLine(),
441462
value: TextValue("text-string"),
442463
dst: ptr[testStringValueScanner](),
443-
exp: testStringValueScanner{field: "text-string"},
464+
exp: testStringValueScanner("text-string"),
465+
err: nil,
466+
},
467+
{
468+
name: xtest.CurrentFileLine(),
469+
value: TextValue("text-string"),
470+
dst: ptr[testStringSQLScanner](),
471+
exp: testStringSQLScanner("text-string"),
444472
err: nil,
445473
},
446474
}

internal/value/value.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,7 @@ func (v *listValue) castTo(dst any) error {
13031303
inner.Set(newSlice)
13041304

13051305
for i, item := range v.ListItems() {
1306-
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
1306+
if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil {
13071307
return xerrors.WithStackTrace(fmt.Errorf(
13081308
"%w '%s(%+v)' to '%T' destination",
13091309
ErrCannotCast, v.Type().Yql(), v, dstValue,
@@ -1437,7 +1437,7 @@ func (v *setValue) castTo(dst any) error {
14371437
inner.Set(newSlice)
14381438

14391439
for i, item := range v.items {
1440-
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
1440+
if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil {
14411441
return xerrors.WithStackTrace(fmt.Errorf(
14421442
"%w '%s(%+v)' to '%T' destination",
14431443
ErrCannotCast, v.Type().Yql(), v, dstValue,
@@ -1545,7 +1545,7 @@ func (v *optionalValue) castTo(dst any) error {
15451545
return nil
15461546
}
15471547

1548-
if err := v.value.castTo(ptr.Interface()); err != nil {
1548+
if err := CastTo(v.value, (ptr.Interface())); err != nil {
15491549
return xerrors.WithStackTrace(err)
15501550
}
15511551

@@ -1560,7 +1560,7 @@ func (v *optionalValue) castTo(dst any) error {
15601560

15611561
inner.Set(reflect.New(inner.Type().Elem()))
15621562

1563-
if err := v.value.castTo(inner.Interface()); err != nil {
1563+
if err := CastTo(v.value, inner.Interface()); err != nil {
15641564
return xerrors.WithStackTrace(err)
15651565
}
15661566

@@ -1641,7 +1641,7 @@ func (v *structValue) castTo(dst any) error {
16411641
}
16421642

16431643
for i, field := range v.fields {
1644-
if err := field.V.castTo(inner.Field(i).Addr().Interface()); err != nil {
1644+
if err := CastTo(field.V, inner.Field(i).Addr().Interface()); err != nil {
16451645
return xerrors.WithStackTrace(fmt.Errorf(
16461646
"scan error on struct field name '%s': %w",
16471647
field.Name, err,
@@ -1768,7 +1768,7 @@ func (v *tupleValue) TupleItems() []Value {
17681768

17691769
func (v *tupleValue) castTo(dst any) error {
17701770
if len(v.items) == 1 {
1771-
return v.items[0].castTo(dst)
1771+
return CastTo(v.items[0], dst)
17721772
}
17731773

17741774
switch dstValue := dst.(type) {

0 commit comments

Comments
 (0)