Skip to content

Commit

Permalink
Merge pull request #7357 from dolthub/zachmu/prepare2
Browse files Browse the repository at this point in the history
[no-release-notes] refactoring BinaryExpression
  • Loading branch information
zachmu authored Jan 18, 2024
2 parents edde46c + 9c1ee17 commit 3cbb73c
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 69 deletions.
2 changes: 1 addition & 1 deletion go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ require (
github.com/cespare/xxhash v1.1.0
github.com/creasty/defaults v1.6.0
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df
github.com/dolthub/swiss v0.1.0
github.com/goccy/go-json v0.10.2
github.com/google/go-github/v57 v57.0.0
Expand Down
4 changes: 2 additions & 2 deletions go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168=
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1 h1:CPdkEWpNyz6H1380wwR+pkxXpBQF7vRTjZ7fb/UCqWs=
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw=
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df h1:OmR6U3UvCMEguh1UaXCiK4qasA/tHH3+Ls2NRiEQfjU=
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto=
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ=
Expand Down
18 changes: 9 additions & 9 deletions go/libraries/doltcore/sqle/dfunctions/dolt_merge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,28 @@ import (
const DoltMergeBaseFuncName = "dolt_merge_base"

type MergeBase struct {
expression.BinaryExpression
expression.BinaryExpressionStub
}

// NewMergeBase returns a MergeBase sql function.
func NewMergeBase(left, right sql.Expression) sql.Expression {
return &MergeBase{expression.BinaryExpression{Left: left, Right: right}}
return &MergeBase{expression.BinaryExpressionStub{LeftChild: left, RightChild: right}}
}

// Eval implements the sql.Expression interface.
func (d MergeBase) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
if _, ok := d.Left.Type().(sql.StringType); !ok {
return nil, sql.ErrInvalidType.New(d.Left.Type())
if _, ok := d.Left().Type().(sql.StringType); !ok {
return nil, sql.ErrInvalidType.New(d.Left().Type())
}
if _, ok := d.Right.Type().(sql.StringType); !ok {
return nil, sql.ErrInvalidType.New(d.Right.Type())
if _, ok := d.Right().Type().(sql.StringType); !ok {
return nil, sql.ErrInvalidType.New(d.Right().Type())
}

leftSpec, err := d.Left.Eval(ctx, row)
leftSpec, err := d.Left().Eval(ctx, row)
if err != nil {
return nil, err
}
rightSpec, err := d.Right.Eval(ctx, row)
rightSpec, err := d.Right().Eval(ctx, row)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -113,7 +113,7 @@ func resolveRefSpecs(ctx *sql.Context, leftSpec, rightSpec string) (left, right

// String implements the sql.Expression interface.
func (d MergeBase) String() string {
return fmt.Sprintf("DOLT_MERGE_BASE(%s,%s)", d.Left.String(), d.Right.String())
return fmt.Sprintf("DOLT_MERGE_BASE(%s,%s)", d.Left().String(), d.Right().String())
}

// Type implements the sql.Expression interface.
Expand Down
2 changes: 1 addition & 1 deletion go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ func TestDoltDiffQueryPlans(t *testing.T) {
defer e.Close()

for _, tt := range DoltDiffPlanTests {
enginetest.TestQueryPlan(t, harness, e, tt.Query, tt.ExpectedPlan, false)
enginetest.TestQueryPlan(t, harness, e, tt.Query, tt.ExpectedPlan, sql.DescribeOptions{})
}
}

Expand Down
24 changes: 12 additions & 12 deletions go/libraries/doltcore/sqle/expreval/expression_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,53 +64,53 @@ func ExpressionFuncFromSQLExpressions(vr types.ValueReader, sch schema.Schema, e
func getExpFunc(vr types.ValueReader, sch schema.Schema, exp sql.Expression) (ExpressionFunc, error) {
switch typedExpr := exp.(type) {
case *expression.Equals:
return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch)
return newComparisonFunc(EqualsOp{}, typedExpr, sch)
case *expression.GreaterThan:
return newComparisonFunc(GreaterOp{vr}, typedExpr.BinaryExpression, sch)
return newComparisonFunc(GreaterOp{vr}, typedExpr, sch)
case *expression.GreaterThanOrEqual:
return newComparisonFunc(GreaterEqualOp{vr}, typedExpr.BinaryExpression, sch)
return newComparisonFunc(GreaterEqualOp{vr}, typedExpr, sch)
case *expression.LessThan:
return newComparisonFunc(LessOp{vr}, typedExpr.BinaryExpression, sch)
return newComparisonFunc(LessOp{vr}, typedExpr, sch)
case *expression.LessThanOrEqual:
return newComparisonFunc(LessEqualOp{vr}, typedExpr.BinaryExpression, sch)
return newComparisonFunc(LessEqualOp{vr}, typedExpr, sch)
case *expression.Or:
leftFunc, err := getExpFunc(vr, sch, typedExpr.Left)
leftFunc, err := getExpFunc(vr, sch, typedExpr.Left())

if err != nil {
return nil, err
}

rightFunc, err := getExpFunc(vr, sch, typedExpr.Right)
rightFunc, err := getExpFunc(vr, sch, typedExpr.Right())

if err != nil {
return nil, err
}

return newOrFunc(leftFunc, rightFunc), nil
case *expression.And:
leftFunc, err := getExpFunc(vr, sch, typedExpr.Left)
leftFunc, err := getExpFunc(vr, sch, typedExpr.Left())

if err != nil {
return nil, err
}

rightFunc, err := getExpFunc(vr, sch, typedExpr.Right)
rightFunc, err := getExpFunc(vr, sch, typedExpr.Right())

if err != nil {
return nil, err
}

return newAndFunc(leftFunc, rightFunc), nil
case *expression.InTuple:
return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch)
return newComparisonFunc(EqualsOp{}, typedExpr, sch)
case *expression.Not:
expFunc, err := getExpFunc(vr, sch, typedExpr.Child)
if err != nil {
return nil, err
}
return newNotFunc(expFunc), nil
case *expression.IsNull:
return newComparisonFunc(EqualsOp{}, expression.BinaryExpression{Left: typedExpr.Child, Right: expression.NewLiteral(nil, gmstypes.Null)}, sch)
return newComparisonFunc(EqualsOp{}, expression.NewNullSafeEquals(typedExpr.Child, expression.NewLiteral(nil, gmstypes.Null)), sch)
}

return nil, errNotImplemented.New(exp.Type().String())
Expand Down Expand Up @@ -175,7 +175,7 @@ func GetComparisonType(be expression.BinaryExpression) ([]*expression.GetField,
var variables []*expression.GetField
var consts []*expression.Literal

for _, curr := range []sql.Expression{be.Left, be.Right} {
for _, curr := range []sql.Expression{be.Left(), be.Right()} {
// need to remove this and handle properly
if conv, ok := curr.(*expression.Convert); ok {
curr = conv.Child
Expand Down
88 changes: 44 additions & 44 deletions go/libraries/doltcore/sqle/expreval/expression_evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,31 @@ func TestGetComparisonType(t *testing.T) {
}{
{
"id = 1",
expression.NewEquals(getId, litOne).BinaryExpression,
expression.NewEquals(getId, litOne),
1,
1,
VariableConstCompare,
false,
},
{
"1 = 1",
expression.NewEquals(litOne, litOne).BinaryExpression,
expression.NewEquals(litOne, litOne),
0,
2,
ConstConstCompare,
false,
},
{
"average > float(median)",
expression.NewGreaterThan(getAverage, expression.NewConvert(getMedian, "float")).BinaryExpression,
expression.NewGreaterThan(getAverage, expression.NewConvert(getMedian, "float")),
2,
0,
VariableVariableCompare,
false,
},
{
" > float(median)",
expression.NewInTuple(getId, expression.NewTuple(litOne, litTwo, litThree)).BinaryExpression,
expression.NewInTuple(getId, expression.NewTuple(litOne, litTwo, litThree)),
1,
3,
VariableInLiteralList,
Expand Down Expand Up @@ -245,10 +245,10 @@ func TestNewComparisonFunc(t *testing.T) {
{
name: "compare int literals -1 and -1",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewLiteral(int8(-1), gmstypes.Int8),
Right: expression.NewLiteral(int64(-1), gmstypes.Int64),
},
be: expression.NewEquals(
expression.NewLiteral(int8(-1), gmstypes.Int8),
expression.NewLiteral(int64(-1), gmstypes.Int64),
),
expectNewErr: false,
testVals: []funcTestVal{
{
Expand All @@ -270,10 +270,10 @@ func TestNewComparisonFunc(t *testing.T) {
{
name: "compare int literals -5 and 5",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewLiteral(int8(-5), gmstypes.Int8),
Right: expression.NewLiteral(uint8(5), gmstypes.Uint8),
},
be: expression.NewEquals(
expression.NewLiteral(int8(-5), gmstypes.Int8),
expression.NewLiteral(uint8(5), gmstypes.Uint8),
),
expectNewErr: false,
testVals: []funcTestVal{
{
Expand All @@ -295,10 +295,10 @@ func TestNewComparisonFunc(t *testing.T) {
{
name: "compare string literals b and a",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewLiteral("b", gmstypes.Text),
Right: expression.NewLiteral("a", gmstypes.Text),
},
be: expression.NewEquals(
expression.NewLiteral("b", gmstypes.Text),
expression.NewLiteral("a", gmstypes.Text),
),
expectNewErr: false,
testVals: []funcTestVal{
{
Expand All @@ -320,10 +320,10 @@ func TestNewComparisonFunc(t *testing.T) {
{
name: "compare int value to numeric string literals",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewGetField(0, gmstypes.Int64, "col0", false),
Right: expression.NewLiteral("1", gmstypes.Text),
},
be: expression.NewEquals(
expression.NewGetField(0, gmstypes.Int64, "col0", false),
expression.NewLiteral("1", gmstypes.Text),
),
expectNewErr: false,
testVals: []funcTestVal{
{
Expand Down Expand Up @@ -352,10 +352,10 @@ func TestNewComparisonFunc(t *testing.T) {
{
name: "compare date value to date string literals",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewGetField(2, gmstypes.Datetime, "date", false),
Right: expression.NewLiteral("2000-01-01", gmstypes.Text),
},
be: expression.NewEquals(
expression.NewGetField(2, gmstypes.Datetime, "date", false),
expression.NewLiteral("2000-01-01", gmstypes.Text),
),
expectNewErr: false,
testVals: []funcTestVal{
{
Expand Down Expand Up @@ -396,10 +396,10 @@ func TestNewComparisonFunc(t *testing.T) {
{
name: "compare col1 and col0",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewGetField(1, gmstypes.Int64, "col1", false),
Right: expression.NewGetField(0, gmstypes.Int64, "col0", false),
},
be: expression.NewEquals(
expression.NewGetField(1, gmstypes.Int64, "col1", false),
expression.NewGetField(0, gmstypes.Int64, "col0", false),
),
expectNewErr: false,
testVals: []funcTestVal{
{
Expand Down Expand Up @@ -446,40 +446,40 @@ func TestNewComparisonFunc(t *testing.T) {
{
name: "compare const and unknown column variable",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewGetField(0, gmstypes.Int64, "unknown", false),
Right: expression.NewLiteral("1", gmstypes.Text),
},
be: expression.NewEquals(
expression.NewGetField(0, gmstypes.Int64, "unknown", false),
expression.NewLiteral("1", gmstypes.Text),
),
expectNewErr: true,
testVals: []funcTestVal{},
},
{
name: "compare variables with first unknown",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewGetField(0, gmstypes.Int64, "unknown", false),
Right: expression.NewGetField(1, gmstypes.Int64, "col1", false),
},
be: expression.NewEquals(
expression.NewGetField(0, gmstypes.Int64, "unknown", false),
expression.NewGetField(1, gmstypes.Int64, "col1", false),
),
expectNewErr: true,
testVals: []funcTestVal{},
},
{
name: "compare variables with second unknown",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewGetField(1, gmstypes.Int64, "col1", false),
Right: expression.NewGetField(0, gmstypes.Int64, "unknown", false),
},
be: expression.NewEquals(
expression.NewGetField(1, gmstypes.Int64, "col1", false),
expression.NewGetField(0, gmstypes.Int64, "unknown", false),
),
expectNewErr: true,
testVals: []funcTestVal{},
},
{
name: "variable with literal that can't be converted",
sch: testSch,
be: expression.BinaryExpression{
Left: expression.NewGetField(0, gmstypes.Int64, "col0", false),
Right: expression.NewLiteral("not a number", gmstypes.Text),
},
be: expression.NewEquals(
expression.NewGetField(0, gmstypes.Int64, "col0", false),
expression.NewLiteral("not a number", gmstypes.Text),
),
expectNewErr: true,
testVals: []funcTestVal{},
},
Expand Down

0 comments on commit 3cbb73c

Please sign in to comment.