Skip to content

Commit a586031

Browse files
committed
feat(istmus): add ddl support to sql->substrait
1 parent b3b5af9 commit a586031

File tree

3 files changed

+232
-9
lines changed

3 files changed

+232
-9
lines changed

isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package io.substrait.isthmus;
22

33
import com.google.common.annotations.VisibleForTesting;
4+
import io.substrait.isthmus.expression.DdlRelBuilder;
45
import io.substrait.isthmus.sql.SubstraitSqlValidator;
6+
import io.substrait.plan.ImmutablePlan;
57
import io.substrait.plan.ImmutablePlan.Builder;
68
import io.substrait.plan.Plan.Version;
79
import io.substrait.plan.PlanProtoConverter;
810
import io.substrait.proto.Plan;
11+
import java.util.ArrayList;
912
import java.util.List;
1013
import org.apache.calcite.plan.hep.HepPlanner;
1114
import org.apache.calcite.plan.hep.HepProgram;
@@ -48,26 +51,66 @@ private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogRea
4851
builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build());
4952

5053
// TODO: consider case in which one sql passes conversion while others don't
51-
sqlToRelNode(sql, validator, catalogReader).stream()
52-
.map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
53-
.forEach(root -> builder.addRoots(root));
54-
54+
sqlToPlanRoots(sql, validator, catalogReader, builder);
5555
PlanProtoConverter planToProto = new PlanProtoConverter();
5656

5757
return planToProto.toProto(builder.build());
5858
}
5959

60+
void sqlToPlanRoots(
61+
String sql,
62+
SqlValidator validator,
63+
Prepare.CatalogReader catalogReader,
64+
ImmutablePlan.Builder builder)
65+
throws SqlParseException {
66+
67+
SqlParser parser = SqlParser.create(sql, parserConfig);
68+
SqlNodeList parsedList = parser.parseStmtList();
69+
if (parsedList.isEmpty()) {
70+
return;
71+
}
72+
73+
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
74+
DdlRelBuilder ddlRelBuilder =
75+
new DdlRelBuilder(
76+
converter, SqlToSubstrait::getBestExpRelRoot, EXTENSION_COLLECTION, featureBoard);
77+
78+
List<SqlNode> nonDdlNodes = new ArrayList<>();
79+
80+
for (SqlNode sqlNode : parsedList) {
81+
final io.substrait.plan.Plan.Root ddlRoot = sqlNode.accept(ddlRelBuilder);
82+
if (ddlRoot != null) {
83+
builder.addRoots(ddlRoot);
84+
} else {
85+
nonDdlNodes.add(sqlNode);
86+
}
87+
}
88+
89+
if (!nonDdlNodes.isEmpty()) {
90+
91+
SqlNodeList dmlNodes = new SqlNodeList(nonDdlNodes, parsedList.getParserPosition());
92+
93+
List<RelRoot> relRoots = sqlNodesToRelNode(dmlNodes, converter);
94+
relRoots.stream()
95+
.map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
96+
.forEach(builder::addRoots);
97+
}
98+
}
99+
100+
private List<RelRoot> sqlNodesToRelNode(
101+
final SqlNodeList parsedList, final SqlToRelConverter converter) {
102+
return parsedList.stream()
103+
.map(parsed -> getBestExpRelRoot(converter, parsed))
104+
.collect(java.util.stream.Collectors.toList());
105+
}
106+
60107
private List<RelRoot> sqlToRelNode(
61108
String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
62109
throws SqlParseException {
63110
SqlParser parser = SqlParser.create(sql, parserConfig);
64111
SqlNodeList parsedList = parser.parseStmtList();
65112
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
66-
List<RelRoot> roots =
67-
parsedList.stream()
68-
.map(parsed -> getBestExpRelRoot(converter, parsed))
69-
.collect(java.util.stream.Collectors.toList());
70-
return roots;
113+
return sqlNodesToRelNode(parsedList, converter);
71114
}
72115

73116
@VisibleForTesting
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package io.substrait.isthmus.expression;
2+
3+
import io.substrait.expression.Expression;
4+
import io.substrait.expression.ExpressionCreator;
5+
import io.substrait.extension.SimpleExtension;
6+
import io.substrait.isthmus.FeatureBoard;
7+
import io.substrait.isthmus.SubstraitRelVisitor;
8+
import io.substrait.isthmus.TypeConverter;
9+
import io.substrait.plan.Plan;
10+
import io.substrait.relation.AbstractDdlRel;
11+
import io.substrait.relation.AbstractWriteRel;
12+
import io.substrait.relation.NamedDdl;
13+
import io.substrait.relation.NamedWrite;
14+
import io.substrait.type.NamedStruct;
15+
import java.util.Map;
16+
import java.util.concurrent.ConcurrentHashMap;
17+
import java.util.function.BiFunction;
18+
import java.util.function.Function;
19+
import org.apache.calcite.rel.RelRoot;
20+
import org.apache.calcite.rel.type.RelDataType;
21+
import org.apache.calcite.sql.SqlCall;
22+
import org.apache.calcite.sql.SqlNode;
23+
import org.apache.calcite.sql.ddl.SqlCreateTable;
24+
import org.apache.calcite.sql.ddl.SqlCreateView;
25+
import org.apache.calcite.sql.util.SqlBasicVisitor;
26+
import org.apache.calcite.sql2rel.SqlToRelConverter;
27+
28+
public class DdlRelBuilder extends SqlBasicVisitor<Plan.Root> {
29+
protected final Map<Class<? extends SqlCall>, Function<SqlCall, Plan.Root>> createHandlers =
30+
new ConcurrentHashMap<>();
31+
32+
private final SqlToRelConverter converter;
33+
private final BiFunction<SqlToRelConverter, SqlNode, RelRoot> bestExpRelRootGetter;
34+
private final SimpleExtension.ExtensionCollection extensionCollection;
35+
private final FeatureBoard featureBoard;
36+
37+
public DdlRelBuilder(
38+
final SqlToRelConverter converter,
39+
final BiFunction<SqlToRelConverter, SqlNode, RelRoot> bestExpRelRootGetter,
40+
final SimpleExtension.ExtensionCollection extensionCollection,
41+
final FeatureBoard featureBoard) {
42+
super();
43+
this.converter = converter;
44+
this.bestExpRelRootGetter = bestExpRelRootGetter;
45+
this.extensionCollection = extensionCollection;
46+
this.featureBoard = featureBoard;
47+
48+
createHandlers.put(
49+
SqlCreateTable.class, sqlCall -> handleCreateTable((SqlCreateTable) sqlCall));
50+
createHandlers.put(SqlCreateView.class, sqlCall -> handleCreateView((SqlCreateView) sqlCall));
51+
}
52+
53+
private Function<SqlCall, Plan.Root> findCreateHandler(final SqlCall call) {
54+
Class<?> currentClass = call.getClass();
55+
while (SqlCall.class.isAssignableFrom(currentClass)) {
56+
final Function<SqlCall, Plan.Root> found = createHandlers.get(currentClass);
57+
if (found != null) {
58+
return found;
59+
}
60+
currentClass = currentClass.getSuperclass();
61+
}
62+
return null;
63+
}
64+
65+
@Override
66+
public Plan.Root visit(final SqlCall sqlCall) {
67+
Function<SqlCall, Plan.Root> createHandler = findCreateHandler(sqlCall);
68+
if (createHandler == null) {
69+
return null;
70+
}
71+
72+
return createHandler.apply(sqlCall);
73+
}
74+
75+
private NamedStruct getSchema(final RelRoot queryRelRoot) {
76+
final RelDataType rowType = queryRelRoot.rel.getRowType();
77+
78+
final TypeConverter typeConverter = TypeConverter.DEFAULT;
79+
return typeConverter.toNamedStruct(rowType);
80+
}
81+
82+
protected Plan.Root handleCreateTable(final SqlCreateTable sqlCreateTable) {
83+
if (sqlCreateTable.query == null) {
84+
throw new IllegalArgumentException("Only create table as select statements are supported");
85+
}
86+
87+
final RelRoot queryRelRoot = bestExpRelRootGetter.apply(converter, sqlCreateTable.query);
88+
89+
NamedStruct schema = getSchema(queryRelRoot);
90+
91+
Plan.Root rel = SubstraitRelVisitor.convert(queryRelRoot, extensionCollection, featureBoard);
92+
NamedWrite namedWrite =
93+
NamedWrite.builder()
94+
.input(rel.getInput())
95+
.tableSchema(schema)
96+
.operation(AbstractWriteRel.WriteOp.CTAS)
97+
.createMode(AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS)
98+
.outputMode(AbstractWriteRel.OutputMode.NO_OUTPUT)
99+
.names(sqlCreateTable.name.names)
100+
.build();
101+
102+
return Plan.Root.builder().input(namedWrite).build();
103+
}
104+
105+
protected Plan.Root handleCreateView(final SqlCreateView sqlCreateView) {
106+
107+
final RelRoot queryRelRoot = bestExpRelRootGetter.apply(converter, sqlCreateView.query);
108+
Plan.Root rel = SubstraitRelVisitor.convert(queryRelRoot, extensionCollection, featureBoard);
109+
final Expression.StructLiteral defaults = ExpressionCreator.struct(false);
110+
111+
final NamedDdl namedDdl =
112+
NamedDdl.builder()
113+
.viewDefinition(rel.getInput())
114+
.tableSchema(getSchema(queryRelRoot))
115+
.tableDefaults(defaults)
116+
.operation(AbstractDdlRel.DdlOp.CREATE)
117+
.object(AbstractDdlRel.DdlObject.VIEW)
118+
.names(sqlCreateView.name.names)
119+
.build();
120+
121+
return Plan.Root.builder().input(namedDdl).build();
122+
}
123+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package io.substrait.isthmus;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
6+
import io.substrait.isthmus.sql.SubstraitSqlValidator;
7+
import io.substrait.plan.ImmutablePlan;
8+
import io.substrait.plan.Plan;
9+
import io.substrait.plan.PlanProtoConverter;
10+
import io.substrait.plan.ProtoPlanConverter;
11+
import org.apache.calcite.prepare.Prepare;
12+
import org.apache.calcite.sql.parser.SqlParseException;
13+
import org.apache.calcite.sql.validate.SqlValidator;
14+
import org.junit.jupiter.api.Test;
15+
16+
class DdlRoundtripTest extends PlanTestBase {
17+
final Prepare.CatalogReader catalogReader =
18+
SubstraitCreateStatementParser.processCreateStatementsToCatalog(
19+
"create table src1 (intcol int, charcol varchar(10))",
20+
"create table src2 (intcol int, charcol varchar(10))");
21+
22+
public DdlRoundtripTest() throws SqlParseException {}
23+
24+
void testSqlToSubstrait(String sqlStatement) throws SqlParseException {
25+
SqlToSubstrait sqlToSubstrait = new SqlToSubstrait();
26+
io.substrait.proto.Plan protoPlan = sqlToSubstrait.execute(sqlStatement, catalogReader);
27+
Plan plan = new ProtoPlanConverter().from(protoPlan);
28+
io.substrait.proto.Plan protoPlan1 = new PlanProtoConverter().toProto(plan);
29+
assertEquals(protoPlan, protoPlan1);
30+
}
31+
32+
void testPlanRoundTrip(String sqlStatement) throws SqlParseException {
33+
SqlToSubstrait sql2subst = new SqlToSubstrait();
34+
ImmutablePlan.Builder builder = io.substrait.plan.Plan.builder();
35+
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
36+
37+
sql2subst.sqlToPlanRoots(sqlStatement, validator, catalogReader, builder);
38+
39+
final Plan plan = builder.build();
40+
assertPlanRoundtrip(plan);
41+
}
42+
43+
@Test
44+
void testCreateTable() throws SqlParseException {
45+
String sql = "create table dst1 as select * from src1";
46+
testSqlToSubstrait(sql);
47+
// TBD: full roundtrip is not possible because there is no relational algebra for DDL
48+
testPlanRoundTrip(sql);
49+
}
50+
51+
@Test
52+
void testCreateView() throws SqlParseException {
53+
String sql = "create view dst1 as select * from src1";
54+
testSqlToSubstrait(sql);
55+
testPlanRoundTrip(sql);
56+
}
57+
}

0 commit comments

Comments
 (0)