From 0f24429cfa23d2f7e124ed5c2a137bb963b2d2c5 Mon Sep 17 00:00:00 2001 From: deen Date: Wed, 10 Mar 2021 22:08:54 +0800 Subject: [PATCH] try to use sql.Scanner prior to handleConvert --- scanner/scanner.go | 9 +- scanner/scanner_test.go | 312 ++++++++++++++++++++++------------------ 2 files changed, 174 insertions(+), 147 deletions(-) diff --git a/scanner/scanner.go b/scanner/scanner.go index cccd165..c270617 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -360,16 +360,17 @@ func convert(mapValue interface{}, valuei reflect.Value, wrapErr convertErrWrapp valuei.Set(reflect.ValueOf(mapValue)) return nil } + + if scanner, ok := valuei.Addr().Interface().(sql.Scanner); ok { + return scanner.Scan(mapValue) + } + //time.Time to string switch assertT := mapValue.(type) { case time.Time: return handleConvertTime(assertT, mvt, vit, &valuei, wrapErr) } - if scanner, ok := valuei.Addr().Interface().(sql.Scanner); ok { - return scanner.Scan(mapValue) - } - //according to go-mysql-driver/mysql, driver.Value type can only be: //int64 or []byte(> maxInt64) //float32/float64 diff --git a/scanner/scanner_test.go b/scanner/scanner_test.go index 2529513..0ced576 100644 --- a/scanner/scanner_test.go +++ b/scanner/scanner_test.go @@ -5,13 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/stretchr/testify/require" "math" "reflect" "testing" "time" "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" ) func TestBindOne(t *testing.T) { @@ -27,10 +27,40 @@ func TestBindOne(t *testing.T) { "ag": age, } err := bind(mp, &p) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(name, p.Name) - ass.Equal(age, p.Age) + should := require.New(t) + should.NoError(err) + should.Equal(name, p.Name) + should.Equal(age, p.Age) +} + +func TestScanner_Time(t *testing.T) { + type Person struct { + Data myData `ddb:"qwe"` + } + var p Person + now := time.Now() + var mp = map[string]interface{}{ + "qwe": now, + } + err := bind(mp, &p) + should := require.New(t) + should.NoError(err) + should.Equal(now, p.Data.d) + mp["qwe"] = 10 + err = bind(mp, &p) + should.EqualError(err, "not time.Time type") +} + +type myData struct { + d time.Time +} + +func (m *myData) Scan(src interface{}) error { + if v, ok := src.(time.Time); ok { + m.d = v + return nil + } + return errors.New("not time.Time type") } func TestBindOne_byte_string(t *testing.T) { @@ -46,10 +76,10 @@ func TestBindOne_byte_string(t *testing.T) { "ag": age, } err := bind(mp, &p) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(string(name), p.Name) - ass.Equal(age, p.Age) + should := require.New(t) + should.NoError(err) + should.Equal(string(name), p.Name) + should.Equal(age, p.Age) } func TestBindOne_byte_uint8(t *testing.T) { @@ -65,10 +95,10 @@ func TestBindOne_byte_uint8(t *testing.T) { "ag": age, } err := bind(mp, &p) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(name, p.Name) - ass.Equal(age, p.Age) + should := require.New(t) + should.NoError(err) + should.Equal(name, p.Name) + should.Equal(age, p.Age) } func TestBindOne_byte_uint8_pointer(t *testing.T) { @@ -84,10 +114,10 @@ func TestBindOne_byte_uint8_pointer(t *testing.T) { "ag": age, } err := bind(mp, p) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(name, p.Name) - ass.Equal(age, p.Age) + should := require.New(t) + should.NoError(err) + should.Equal(name, p.Name) + should.Equal(age, p.Age) } func TestBindOne_uint8_byte(t *testing.T) { @@ -103,10 +133,10 @@ func TestBindOne_uint8_byte(t *testing.T) { "ag": age, } err := bind(mp, &p) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(name, p.Name) - ass.Equal(age, p.Age) + should := require.New(t) + should.NoError(err) + should.Equal(name, p.Name) + should.Equal(age, p.Age) } func TestBindOne_float(t *testing.T) { @@ -119,9 +149,9 @@ func TestBindOne_float(t *testing.T) { "sl": salary, } err := bind(mp, &p) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(salary, p.Salary) + should := require.New(t) + should.NoError(err) + should.Equal(salary, p.Salary) } func TestBindSlice(t *testing.T) { @@ -135,11 +165,11 @@ func TestBindSlice(t *testing.T) { data = append(data, map[string]interface{}{"age": v}) } err := bindSlice(data, &students) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(len(testCases), len(students)) + should := require.New(t) + should.NoError(err) + should.Equal(len(testCases), len(students)) for idx, p := range students { - ass.Equal(testCases[idx], p.Age) + should.Equal(testCases[idx], p.Age) } } func Test_Scan_PointerArr(t *testing.T) { @@ -164,12 +194,12 @@ func Test_Scan_PointerArr(t *testing.T) { }, ) err := bindSlice(data, &stus) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(len(data), len(stus)) + should := require.New(t) + should.NoError(err) + should.Equal(len(data), len(stus)) for i := 0; i < len(stus); i++ { - ass.Equal(data[i]["name"], stus[i].Name, "bind pointer name") - ass.Equal(data[i]["sala"], stus[i].Salary, "bind pointer sala") + should.Equal(data[i]["name"], stus[i].Name, "bind pointer name") + should.Equal(data[i]["sala"], stus[i].Salary, "bind pointer sala") } } @@ -181,9 +211,9 @@ func Test_Bind_Float32_2_Float64(t *testing.T) { err := bind(map[string]interface{}{ "num": float32(10.5), }, &a) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(float64(10.5), a.Num) + should := require.New(t) + should.NoError(err) + should.Equal(float64(10.5), a.Num) } func Test_Bind_Float64_2_Float32(t *testing.T) { @@ -194,9 +224,9 @@ func Test_Bind_Float64_2_Float32(t *testing.T) { err := bind(map[string]interface{}{ "num": float64(10.5), }, &a) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(float32(10.5), a.Num) + should := require.New(t) + should.NoError(err) + should.Equal(float32(10.5), a.Num) } func Test_Bind_int64_2_uint64(t *testing.T) { @@ -209,10 +239,10 @@ func Test_Bind_int64_2_uint64(t *testing.T) { "num": int64(10), "age": int64(20), }, &a) - ass := assert.New(t) - ass.NoError(err, `shouldn't be error when bind int64 to uint64`) - ass.Equal(uint64(10), a.Num) - ass.Equal(uint8(20), a.Age) + should := require.New(t) + should.NoError(err, `shouldn't be error when bind int64 to uint64`) + should.Equal(uint64(10), a.Num) + should.Equal(uint8(20), a.Age) } func Test_Ignore_Unexported_Field(t *testing.T) { @@ -226,10 +256,10 @@ func Test_Ignore_Unexported_Field(t *testing.T) { "age": int64(100), } err := bind(data, &Tom) - ass := assert.New(t) - ass.NoError(err) - ass.Equal(0, Tom.age) - ass.Equal("Tommmm", Tom.Name) + should := require.New(t) + should.NoError(err) + should.Equal(0, Tom.age) + should.Equal("Tommmm", Tom.Name) } func Test_Bind_Time_2_String(t *testing.T) { @@ -241,16 +271,16 @@ func Test_Bind_Time_2_String(t *testing.T) { "create_time": now, } var tObj Whatever - ass := assert.New(t) + should := require.New(t) err := bind(data, &tObj) - ass.NoError(err, "time.Time should transform to string and bind to string type") - ass.Equal(now.Format("2006-01-02 15:04:05"), tObj.When) + should.NoError(err, "time.Time should transform to string and bind to string type") + should.Equal(now.Format("2006-01-02 15:04:05"), tObj.When) type WillErr struct { When int `ddb:"create_time"` } var some WillErr err = bind(data, &some) - ass.Error(err, "time.Time could only bind to time.Time&string type %v", some) + should.Error(err, "time.Time could only bind to time.Time&string type %v", some) } func Test_Bind_Slice_2_Time(t *testing.T) { @@ -262,10 +292,10 @@ func Test_Bind_Slice_2_Time(t *testing.T) { "create_time": []uint8(now.Format(cTimeFormat)), } var tObj Whatever - ass := assert.New(t) + should := require.New(t) err := bind(data, &tObj) - ass.NoError(err, "[]uint8 should try to cast to time.Time") - ass.Equal(now.Unix(), tObj.When.Unix()) + should.NoError(err, "[]uint8 should try to cast to time.Time") + should.Equal(now.Unix(), tObj.When.Unix()) } func Test_ScanMap(t *testing.T) { @@ -313,18 +343,18 @@ func Test_ScanMap(t *testing.T) { }, }, } - ass := assert.New(t) + should := require.New(t) db, mock, err := sqlmock.New() - ass.NoError(err) + should.NoError(err) for _, tc := range testData { mock.ExpectQuery("select \\* from tb").WillReturnRows(tc.rows) rows, err := db.Query("select * from tb") - ass.NoError(err) - ass.NotNil(rows) - ass.NoError(mock.ExpectationsWereMet()) + should.NoError(err) + should.NotNil(rows) + should.NoError(mock.ExpectationsWereMet()) mpArr, err := ScanMap(rows) - ass.NoError(err) - ass.Equal(tc.out, mpArr) + should.NoError(err) + should.Equal(tc.out, mpArr) } } @@ -364,18 +394,18 @@ func Test_Slice_2_Int(t *testing.T) { }, } var u user - ass := assert.New(t) + should := require.New(t) for _, tc := range testData { mp := map[string]interface{}{ "age": tc.in, } err := bind(mp, &u) if tc.err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, u.Age) + should.Equal(tc.out, u.Age) } } @@ -414,18 +444,18 @@ func Test_Scan_Pointer(t *testing.T) { }, } var u user - ass := assert.New(t) + should := require.New(t) for idx, tc := range testData { mp := map[string]interface{}{ "age": tc.in, } err := bind(mp, &u) if err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, *u.Age, "case #%d fail", idx) + should.Equal(tc.out, *u.Age, "case #%d fail", idx) } } @@ -488,16 +518,16 @@ func Test_Scan_Multi_Pointer(t *testing.T) { }, }, } - ass := assert.New(t) + should := require.New(t) for idx, tc := range testData { var u user err := bind(tc.in, &u) if err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, u, "case #%d fail %+v", idx, u) + should.Equal(tc.out, u, "case #%d fail %+v", idx, u) } } @@ -552,18 +582,18 @@ func Test_Slice_2_UInt(t *testing.T) { }, } var u user - ass := assert.New(t) + should := require.New(t) for _, tc := range testData { mp := map[string]interface{}{ "age": tc.in, } err := bind(mp, &u) if tc.err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, u.Age) + should.Equal(tc.out, u.Age) } } @@ -608,18 +638,18 @@ func Test_Slice_2_Float(t *testing.T) { }, } var u user - ass := assert.New(t) + should := require.New(t) for _, tc := range testData { mp := map[string]interface{}{ "score": tc.in, } err := bind(mp, &u) if tc.err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.True(math.Abs(tc.out-u.Score) < 1e5) + should.True(math.Abs(tc.out-u.Score) < 1e5) } } @@ -674,16 +704,16 @@ func Test_int64_2_bool(t *testing.T) { }, }, } - ass := assert.New(t) + should := require.New(t) for _, tc := range testData { var u user err := bind(tc.in, &u) if tc.err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, u) + should.Equal(tc.out, u) } } @@ -758,16 +788,16 @@ func Test_int64_2_string(t *testing.T) { }, }, } - ass := assert.New(t) + should := require.New(t) for _, tc := range testData { var u user err := bind(tc.in, &u) if tc.err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, u) + should.Equal(tc.out, u) } } @@ -809,16 +839,16 @@ func Test_uint8_2_any(t *testing.T) { err: nil, }, } - ass := assert.New(t) + should := require.New(t) for _, tc := range testData { var u user err := bind(tc.in, &u) if tc.err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, u) + should.Equal(tc.out, u) } } @@ -854,7 +884,7 @@ func Test_sql_scanner(t *testing.T) { err: nil, }, } - ass := assert.New(t) + should := require.New(t) for _, tc := range testData { var u user mp := map[string]interface{}{ @@ -862,11 +892,11 @@ func Test_sql_scanner(t *testing.T) { } err := bind(mp, &u) if tc.err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, u.Name) + should.Equal(tc.out, u.Name) } } @@ -902,7 +932,7 @@ func Test_sql_scanner_with_pointer(t *testing.T) { err: nil, }, } - ass := assert.New(t) + should := require.New(t) for _, tc := range testData { var u user mp := map[string]interface{}{ @@ -910,21 +940,21 @@ func Test_sql_scanner_with_pointer(t *testing.T) { } err := bind(mp, &u) if tc.err == nil { - ass.NoError(err) + should.NoError(err) } else { - ass.Error(err) + should.Error(err) } - ass.Equal(tc.out, u.Name) + should.Equal(tc.out, u.Name) } } func TestTagSetOnlyOnce(t *testing.T) { userDefinedTagName = "a" SetTagName("foo") - assert.Equal(t, "a", userDefinedTagName) + require.Equal(t, "a", userDefinedTagName) userDefinedTagName = "" SetTagName("foo") - assert.Equal(t, "foo", userDefinedTagName) + require.Equal(t, "foo", userDefinedTagName) // restore default tag userDefinedTagName = DefaultTagName } @@ -969,19 +999,19 @@ func (r *fakeRows) Scan(dt ...interface{}) (err error) { } func TestScanNotSettable(t *testing.T) { - ass := assert.New(t) + should := require.New(t) err := Scan(&fakeRows{}, nil) - ass.Equal(ErrTargetNotSettable, err) + should.Equal(ErrTargetNotSettable, err) var rows Rows err = Scan(rows, nil) - ass.Equal(ErrTargetNotSettable, err) + should.Equal(ErrTargetNotSettable, err) } func TestScanMapClose(t *testing.T) { var rows Rows _, err := ScanMapClose(rows) - ass := assert.New(t) - ass.Equal(ErrNilRows, err) + should := require.New(t) + should.Equal(ErrNilRows, err) scannn := &fakeRows{ columns: []string{"foo", "bar"}, dataset: [][]interface{}{ @@ -990,22 +1020,18 @@ func TestScanMapClose(t *testing.T) { }, } result, err := ScanMapClose(scannn) - ass.Equal(2, len(result)) - ass.Equal(errCloseForTest.Error(), err.Error()) + should.Equal(2, len(result)) + should.Equal(errCloseForTest.Error(), err.Error()) v, ok := result[0]["foo"] - if !ass.True(ok) { - return - } - ass.Equal(1, v) + should.True(ok) + should.Equal(1, v) v, ok = result[1]["bar"] - if !ass.True(ok) { - return - } - ass.Equal(4, v) + should.True(ok) + should.Equal(4, v) } func TestScanMock(t *testing.T) { - ass := assert.New(t) + should := require.New(t) scannn := &fakeRows{ columns: []string{"name", "age"}, dataset: [][]interface{}{ @@ -1020,15 +1046,15 @@ func TestScanMock(t *testing.T) { var boys []curdBoy userDefinedTagName = "ddb" err := Scan(scannn, &boys) - ass.NoError(err) - ass.Equal("deen", boys[0].Name) - ass.Equal("caibirdme", boys[1].Name) - ass.Equal(23, boys[0].Age) - ass.Equal(24, boys[1].Age) + should.NoError(err) + should.Equal("deen", boys[0].Name) + should.Equal("caibirdme", boys[1].Name) + should.Equal(23, boys[0].Age) + should.Equal(24, boys[1].Age) } func TestScanEmpty(t *testing.T) { - ass := assert.New(t) + should := require.New(t) scannn := &fakeRows{} type curdBoy struct { Name string `ddb:"name"` @@ -1037,11 +1063,11 @@ func TestScanEmpty(t *testing.T) { var boys []curdBoy userDefinedTagName = "ddb" err := Scan(scannn, &boys) - ass.NoError(err) - ass.Equal(0, len(boys)) + should.NoError(err) + should.Equal(0, len(boys)) var boy curdBoy err = Scan(scannn, &boy) - ass.Equal(ErrEmptyResult, err) + should.Equal(ErrEmptyResult, err) } type human struct { @@ -1114,15 +1140,15 @@ func TestUnmarshalByte(t *testing.T) { err: nil, }, } - ass := assert.New(t) + should := require.New(t) for idx, tc := range testCase { var student human if idx >= 2 { student.Extra = &extraInfo{} } err := bind(tc.mapv, &student) - ass.Equal(tc.err, err, "idx:%d", idx) - ass.Equal(tc.expect, student, "idx:%d", idx) + should.Equal(tc.err, err, "idx:%d", idx) + should.Equal(tc.expect, student, "idx:%d", idx) } } @@ -1137,27 +1163,27 @@ func TestScanClose(t *testing.T) { Foo int `ddb:"foo"` Bar int `ddb:"bar"` }{} - ass := assert.New(t) + should := require.New(t) err := ScanClose(rows, &testObj) e, ok := err.(CloseErr) - ass.True(ok) - ass.Equal(errCloseForTest.Error(), e.Error()) - ass.Equal(1, testObj.Foo) - ass.Equal(2, testObj.Bar) + should.True(ok) + should.Equal(errCloseForTest.Error(), e.Error()) + should.Equal(1, testObj.Foo) + should.Equal(2, testObj.Bar) } func TestErrClose(t *testing.T) { - ass := assert.New(t) + should := require.New(t) err := newCloseErr(nil) - ass.Nil(err) + should.Nil(err) err = newCloseErr(errors.New("123")) - ass.NotPanics(func() { - ass.Equal("123", err.Error()) + should.NotPanics(func() { + should.Equal("123", err.Error()) }) } func TestScanMapDecode(t *testing.T) { - ass := assert.New(t) + should := require.New(t) var testCase = []struct { rows Rows expect []map[string]interface{} @@ -1194,7 +1220,7 @@ func TestScanMapDecode(t *testing.T) { } for idx, tc := range testCase { result, err := ScanMapDecode(tc.rows) - ass.Nil(err, "case #%d fail", idx) - ass.Equal(tc.expect, result, "case #%d fail", idx) + should.Nil(err, "case #%d fail", idx) + should.Equal(tc.expect, result, "case #%d fail", idx) } }