diff --git a/internal/codec.go b/internal/codec.go index 9c46060..4562d20 100644 --- a/internal/codec.go +++ b/internal/codec.go @@ -3,20 +3,21 @@ package internal type ValueType string const ( - IntValueType ValueType = "int64" - StringValueType ValueType = "string" - BytesValueType ValueType = "bytes" - FloatValueType ValueType = "float" - NumericValueType ValueType = "numeric" - BoolValueType ValueType = "bool" - JsonValueType ValueType = "json" - ArrayValueType ValueType = "array" - StructValueType ValueType = "struct" - DateValueType ValueType = "date" - DatetimeValueType ValueType = "datetime" - TimeValueType ValueType = "time" - TimestampValueType ValueType = "timestamp" - IntervalValueType ValueType = "interval" + IntValueType ValueType = "int64" + StringValueType ValueType = "string" + BytesValueType ValueType = "bytes" + FloatValueType ValueType = "float" + NumericValueType ValueType = "numeric" + BigNumericValueType ValueType = "bignumeric" + BoolValueType ValueType = "bool" + JsonValueType ValueType = "json" + ArrayValueType ValueType = "array" + StructValueType ValueType = "struct" + DateValueType ValueType = "date" + DatetimeValueType ValueType = "datetime" + TimeValueType ValueType = "time" + TimestampValueType ValueType = "timestamp" + IntervalValueType ValueType = "interval" ) type ValueLayout struct { diff --git a/internal/decoder.go b/internal/decoder.go index 14e91f4..1495b8c 100644 --- a/internal/decoder.go +++ b/internal/decoder.go @@ -49,7 +49,11 @@ func decodeFromValueLayout(layout *ValueLayout) (Value, error) { case NumericValueType: r := new(big.Rat) r.SetString(layout.Body) - return (*NumericValue)(r), nil + return &NumericValue{Rat: r}, nil + case BigNumericValueType: + r := new(big.Rat) + r.SetString(layout.Body) + return &NumericValue{Rat: r, isBigNumeric: true}, nil case DateValueType: t, err := parseDate(layout.Body) if err != nil { diff --git a/internal/encoder.go b/internal/encoder.go index aac9eda..80fa46b 100644 --- a/internal/encoder.go +++ b/internal/encoder.go @@ -292,7 +292,10 @@ func numericValueFromLiteral(lit string) (*NumericValue, error) { numericLit := matches[0][1] r := new(big.Rat) r.SetString(numericLit) - return (*NumericValue)(r), nil + if strings.Contains("BIGNUMERIC", lit) { + return &NumericValue{Rat: r, isBigNumeric: true}, nil + } + return &NumericValue{Rat: r}, nil } func jsonValueFromLiteral(lit string) (JsonValue, error) { @@ -471,12 +474,18 @@ func CastValue(t types.Type, v Value) (Value, error) { ret.m[key] = casted } return ret, nil - case types.NUMERIC, types.BIG_NUMERIC: + case types.NUMERIC: r, err := v.ToRat() if err != nil { return nil, err } - return (*NumericValue)(r), nil + return &NumericValue{Rat: r}, nil + case types.BIG_NUMERIC: + r, err := v.ToRat() + if err != nil { + return nil, err + } + return &NumericValue{Rat: r, isBigNumeric: true}, nil case types.JSON: j, err := v.ToJSON() if err != nil { @@ -591,10 +600,16 @@ func valueLayoutFromValue(v Value) (*ValueLayout, error) { Body: base64.StdEncoding.EncodeToString([]byte(vv)), }, nil case *NumericValue: - b, err := (*big.Rat)(vv).MarshalText() + b, err := vv.Rat.MarshalText() if err != nil { return nil, err } + if vv.isBigNumeric { + return &ValueLayout{ + Header: BigNumericValueType, + Body: string(b), + }, nil + } return &ValueLayout{ Header: NumericValueType, Body: string(b), diff --git a/internal/function_bind.go b/internal/function_bind.go index fdaf3a2..3417407 100644 --- a/internal/function_bind.go +++ b/internal/function_bind.go @@ -584,7 +584,7 @@ func bindParseBigNumeric(args ...Value) (Value, error) { if err != nil { return nil, err } - return PARSE_NUMERIC(numeric) + return PARSE_BIGNUMERIC(numeric) } func bindFarmFingerprint(args ...Value) (Value, error) { diff --git a/internal/function_numeric.go b/internal/function_numeric.go index f14be0f..cc73315 100644 --- a/internal/function_numeric.go +++ b/internal/function_numeric.go @@ -10,5 +10,13 @@ func PARSE_NUMERIC(numeric string) (Value, error) { if _, ok := r.SetString(numeric); !ok { return nil, fmt.Errorf("unexpected numeric literal: %s", numeric) } - return (*NumericValue)(r), nil + return &NumericValue{Rat: r}, nil +} + +func PARSE_BIGNUMERIC(numeric string) (Value, error) { + r := new(big.Rat) + if _, ok := r.SetString(numeric); !ok { + return nil, fmt.Errorf("unexpected numeric literal: %s", numeric) + } + return &NumericValue{Rat: r, isBigNumeric: true}, nil } diff --git a/internal/rows.go b/internal/rows.go index 73cf3db..1e25b37 100644 --- a/internal/rows.go +++ b/internal/rows.go @@ -5,9 +5,7 @@ import ( "database/sql/driver" "fmt" "io" - "math/big" "reflect" - "strings" "time" "github.com/goccy/go-json" @@ -221,27 +219,13 @@ func (r *Rows) assignInterfaceValue(src Value, dst reflect.Value, typ *Type) err if err != nil { return err } - r := new(big.Rat) - if _, ok := r.SetString(s); !ok { - return fmt.Errorf("unexpected numeric value: %s", s) - } - f := r.FloatString(9) - f = strings.TrimRight(f, "0") - f = strings.TrimRight(f, ".") - dst.Set(reflect.ValueOf(f)) + dst.Set(reflect.ValueOf(s)) case types.BIG_NUMERIC: s, err := src.ToString() if err != nil { return err } - r := new(big.Rat) - if _, ok := r.SetString(s); !ok { - return fmt.Errorf("unexpected bignumeric value: %s", s) - } - f := r.FloatString(38) - f = strings.TrimRight(f, "0") - f = strings.TrimRight(f, ".") - dst.Set(reflect.ValueOf(f)) + dst.Set(reflect.ValueOf(s)) case types.DATE: date, err := src.ToJSON() if err != nil { diff --git a/internal/value.go b/internal/value.go index 1b5220e..6104a88 100644 --- a/internal/value.go +++ b/internal/value.go @@ -587,37 +587,43 @@ func (fv FloatValue) Interface() interface{} { return float64(fv) } -type NumericValue big.Rat +type NumericValue struct { + *big.Rat + isBigNumeric bool +} func (nv *NumericValue) Add(v Value) (Value, error) { z := new(big.Rat) - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return nil, err } - return (*NumericValue)(z.Add(x, y)), nil + nv.Rat = z.Add(x, y) + return nv, nil } func (nv *NumericValue) Sub(v Value) (Value, error) { z := new(big.Rat) - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return nil, err } zy := new(big.Rat) - return (*NumericValue)(z.Add(x, zy.Neg(y))), nil + nv.Rat = z.Add(x, zy.Neg(y)) + return nv, nil } func (nv *NumericValue) Mul(v Value) (Value, error) { z := new(big.Rat) - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return nil, err } - return (*NumericValue)(z.Mul(x, y)), nil + nv.Rat = z.Mul(x, y) + return nv, nil } func (nv *NumericValue) Div(v Value) (ret Value, e error) { @@ -627,17 +633,18 @@ func (nv *NumericValue) Div(v Value) (ret Value, e error) { } }() z := new(big.Rat) - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return nil, err } zy := new(big.Rat) - return (*NumericValue)(z.Mul(x, zy.Inv(y))), nil + nv.Rat = z.Mul(x, zy.Inv(y)) + return nv, nil } func (nv *NumericValue) EQ(v Value) (bool, error) { - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return false, err @@ -646,7 +653,7 @@ func (nv *NumericValue) EQ(v Value) (bool, error) { } func (nv *NumericValue) GT(v Value) (bool, error) { - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return false, err @@ -655,7 +662,7 @@ func (nv *NumericValue) GT(v Value) (bool, error) { } func (nv *NumericValue) GTE(v Value) (bool, error) { - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return false, err @@ -664,7 +671,7 @@ func (nv *NumericValue) GTE(v Value) (bool, error) { } func (nv *NumericValue) LT(v Value) (bool, error) { - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return false, err @@ -673,7 +680,7 @@ func (nv *NumericValue) LT(v Value) (bool, error) { } func (nv *NumericValue) LTE(v Value) (bool, error) { - x := (*big.Rat)(nv) + x := nv.Rat y, err := v.ToRat() if err != nil { return false, err @@ -682,24 +689,36 @@ func (nv *NumericValue) LTE(v Value) (bool, error) { } func (nv *NumericValue) ToInt64() (int64, error) { - return (*big.Rat)(nv).Num().Int64(), nil + return nv.Rat.Num().Int64(), nil +} + +func (nv *NumericValue) toString() string { + var v string + if nv.isBigNumeric { + v = nv.Rat.FloatString(38) + } else { + v = nv.Rat.FloatString(9) + } + v = strings.TrimRight(v, "0") + v = strings.TrimRight(v, ".") + return v } func (nv *NumericValue) ToString() (string, error) { - return (*big.Rat)(nv).RatString(), nil + return nv.toString(), nil } func (nv *NumericValue) ToBytes() ([]byte, error) { - return []byte((*big.Rat)(nv).RatString()), nil + return []byte(nv.toString()), nil } func (nv *NumericValue) ToFloat64() (float64, error) { - f, _ := (*big.Rat)(nv).Float64() + f, _ := nv.Rat.Float64() return f, nil } func (nv *NumericValue) ToBool() (bool, error) { - v := (*big.Rat)(nv).Num().Int64() + v := nv.Rat.Num().Int64() if v == 1 { return true, nil } else if v == 0 { @@ -717,7 +736,7 @@ func (nv *NumericValue) ToStruct() (*StructValue, error) { } func (nv *NumericValue) ToJSON() (string, error) { - return (*big.Rat)(nv).RatString(), nil + return nv.toString(), nil } func (nv *NumericValue) ToTime() (time.Time, error) { @@ -725,15 +744,15 @@ func (nv *NumericValue) ToTime() (time.Time, error) { } func (nv *NumericValue) ToRat() (*big.Rat, error) { - return (*big.Rat)(nv), nil + return nv.Rat, nil } func (nv *NumericValue) Format(verb rune) string { - return (*big.Rat)(nv).RatString() + return nv.toString() } func (nv *NumericValue) Interface() interface{} { - f, _ := (*big.Rat)(nv).Float64() + f, _ := nv.Rat.Float64() return f } diff --git a/query_test.go b/query_test.go index 6679e29..daff25d 100644 --- a/query_test.go +++ b/query_test.go @@ -3087,6 +3087,11 @@ SELECT query: `SELECT PARSE_BIGNUMERIC("123.45"), PARSE_BIGNUMERIC("123.456E37"), PARSE_BIGNUMERIC("1.123456789012345678901234567890123456789")`, expectedRows: [][]interface{}{{"123.45", "1234560000000000000000000000000000000000", "1.12345678901234567890123456789012345679"}}, }, + { + name: "cast numeric and bignumeric to string", + query: `SELECT cast(PARSE_NUMERIC("123.456") as STRING), cast(PARSE_BIGNUMERIC("123.456") as STRING)`, + expectedRows: [][]interface{}{{"123.456", "123.456"}}, + }, // uuid functions {