Skip to content

Commit 108bb43

Browse files
committed
feat(istmus): add ddl support to sql->substrait
# Conflicts: # isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
1 parent ac80ee5 commit 108bb43

File tree

3 files changed

+223
-5
lines changed

3 files changed

+223
-5
lines changed

isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package io.substrait.isthmus;
22

33
import com.google.common.annotations.VisibleForTesting;
4+
import io.substrait.isthmus.expression.DdlRelBuilder;
45
import io.substrait.isthmus.sql.SubstraitSqlValidator;
56
import io.substrait.plan.ImmutablePlan.Builder;
67
import io.substrait.plan.Plan;
78
import io.substrait.plan.Plan.Version;
89
import io.substrait.plan.PlanProtoConverter;
10+
import java.util.ArrayList;
911
import java.util.List;
1012
import org.apache.calcite.plan.hep.HepPlanner;
1113
import org.apache.calcite.plan.hep.HepProgram;
@@ -64,10 +66,8 @@ public Plan convert(String sql, Prepare.CatalogReader catalogReader) throws SqlP
6466
Builder builder = io.substrait.plan.Plan.builder();
6567
builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build());
6668

67-
// TODO: consider case in which one sql passes conversion while others don't
68-
sqlToRelNode(sql, catalogReader).stream()
69-
.map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
70-
.forEach(root -> builder.addRoots(root));
69+
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
70+
sqlToPlanRoots(sql, validator, catalogReader, builder);
7171

7272
return builder.build();
7373
}
@@ -86,6 +86,49 @@ List<RelRoot> sqlToRelNode(String sql, Prepare.CatalogReader catalogReader)
8686
return roots;
8787
}
8888

89+
protected void sqlToPlanRoots(
90+
String sql, SqlValidator validator, Prepare.CatalogReader catalogReader, Builder builder)
91+
throws SqlParseException {
92+
93+
SqlParser parser = SqlParser.create(sql, parserConfig);
94+
SqlNodeList parsedList = parser.parseStmtList();
95+
if (parsedList.isEmpty()) {
96+
return;
97+
}
98+
99+
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
100+
DdlRelBuilder ddlRelBuilder =
101+
new DdlRelBuilder(
102+
converter, SqlToSubstrait::getBestExpRelRoot, EXTENSION_COLLECTION, featureBoard);
103+
104+
List<SqlNode> nonDdlNodes = new ArrayList<>();
105+
106+
for (SqlNode sqlNode : parsedList) {
107+
final io.substrait.plan.Plan.Root ddlRoot = sqlNode.accept(ddlRelBuilder);
108+
if (ddlRoot != null) {
109+
builder.addRoots(ddlRoot);
110+
} else {
111+
nonDdlNodes.add(sqlNode);
112+
}
113+
}
114+
115+
if (!nonDdlNodes.isEmpty()) {
116+
SqlNodeList dmlNodes = new SqlNodeList(nonDdlNodes, parsedList.getParserPosition());
117+
118+
List<RelRoot> relRoots = sqlNodesToRelNode(dmlNodes, converter);
119+
relRoots.stream()
120+
.map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
121+
.forEach(builder::addRoots);
122+
}
123+
}
124+
125+
private List<RelRoot> sqlNodesToRelNode(
126+
final SqlNodeList parsedList, final SqlToRelConverter converter) {
127+
return parsedList.stream()
128+
.map(parsed -> getBestExpRelRoot(converter, parsed))
129+
.collect(java.util.stream.Collectors.toList());
130+
}
131+
89132
protected SqlToRelConverter createSqlToRelConverter(
90133
SqlValidator validator, Prepare.CatalogReader catalogReader) {
91134
SqlToRelConverter converter =
@@ -99,7 +142,7 @@ protected SqlToRelConverter createSqlToRelConverter(
99142
return converter;
100143
}
101144

102-
protected RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
145+
protected static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
103146
RelRoot root = converter.convertQuery(parsed, true, true);
104147
{
105148
// RelBuilder seems to implicitly use the rule below,
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package io.substrait.isthmus.expression;
2+
3+
import io.substrait.expression.Expression;
4+
import io.substrait.expression.ExpressionCreator;
5+
import io.substrait.extension.SimpleExtension;
6+
import io.substrait.isthmus.FeatureBoard;
7+
import io.substrait.isthmus.SubstraitRelVisitor;
8+
import io.substrait.isthmus.TypeConverter;
9+
import io.substrait.plan.Plan;
10+
import io.substrait.relation.AbstractDdlRel;
11+
import io.substrait.relation.AbstractWriteRel;
12+
import io.substrait.relation.NamedDdl;
13+
import io.substrait.relation.NamedWrite;
14+
import io.substrait.type.NamedStruct;
15+
import java.util.Map;
16+
import java.util.concurrent.ConcurrentHashMap;
17+
import java.util.function.BiFunction;
18+
import java.util.function.Function;
19+
import org.apache.calcite.rel.RelRoot;
20+
import org.apache.calcite.rel.type.RelDataType;
21+
import org.apache.calcite.sql.SqlCall;
22+
import org.apache.calcite.sql.SqlNode;
23+
import org.apache.calcite.sql.ddl.SqlCreateTable;
24+
import org.apache.calcite.sql.ddl.SqlCreateView;
25+
import org.apache.calcite.sql.util.SqlBasicVisitor;
26+
import org.apache.calcite.sql2rel.SqlToRelConverter;
27+
28+
public class DdlRelBuilder extends SqlBasicVisitor<Plan.Root> {
29+
protected final Map<Class<? extends SqlCall>, Function<SqlCall, Plan.Root>> createHandlers =
30+
new ConcurrentHashMap<>();
31+
32+
private final SqlToRelConverter converter;
33+
private final BiFunction<SqlToRelConverter, SqlNode, RelRoot> bestExpRelRootGetter;
34+
private final SimpleExtension.ExtensionCollection extensionCollection;
35+
private final FeatureBoard featureBoard;
36+
37+
public DdlRelBuilder(
38+
final SqlToRelConverter converter,
39+
final BiFunction<SqlToRelConverter, SqlNode, RelRoot> bestExpRelRootGetter,
40+
final SimpleExtension.ExtensionCollection extensionCollection,
41+
final FeatureBoard featureBoard) {
42+
super();
43+
this.converter = converter;
44+
this.bestExpRelRootGetter = bestExpRelRootGetter;
45+
this.extensionCollection = extensionCollection;
46+
this.featureBoard = featureBoard;
47+
48+
createHandlers.put(
49+
SqlCreateTable.class, sqlCall -> handleCreateTable((SqlCreateTable) sqlCall));
50+
createHandlers.put(SqlCreateView.class, sqlCall -> handleCreateView((SqlCreateView) sqlCall));
51+
}
52+
53+
private Function<SqlCall, Plan.Root> findCreateHandler(final SqlCall call) {
54+
Class<?> currentClass = call.getClass();
55+
while (SqlCall.class.isAssignableFrom(currentClass)) {
56+
final Function<SqlCall, Plan.Root> found = createHandlers.get(currentClass);
57+
if (found != null) {
58+
return found;
59+
}
60+
currentClass = currentClass.getSuperclass();
61+
}
62+
return null;
63+
}
64+
65+
@Override
66+
public Plan.Root visit(final SqlCall sqlCall) {
67+
Function<SqlCall, Plan.Root> createHandler = findCreateHandler(sqlCall);
68+
if (createHandler == null) {
69+
return null;
70+
}
71+
72+
return createHandler.apply(sqlCall);
73+
}
74+
75+
private NamedStruct getSchema(final RelRoot queryRelRoot) {
76+
final RelDataType rowType = queryRelRoot.rel.getRowType();
77+
78+
final TypeConverter typeConverter = TypeConverter.DEFAULT;
79+
return typeConverter.toNamedStruct(rowType);
80+
}
81+
82+
protected Plan.Root handleCreateTable(final SqlCreateTable sqlCreateTable) {
83+
if (sqlCreateTable.query == null) {
84+
throw new IllegalArgumentException("Only create table as select statements are supported");
85+
}
86+
87+
final RelRoot queryRelRoot = bestExpRelRootGetter.apply(converter, sqlCreateTable.query);
88+
89+
NamedStruct schema = getSchema(queryRelRoot);
90+
91+
Plan.Root rel = SubstraitRelVisitor.convert(queryRelRoot, extensionCollection, featureBoard);
92+
NamedWrite namedWrite =
93+
NamedWrite.builder()
94+
.input(rel.getInput())
95+
.tableSchema(schema)
96+
.operation(AbstractWriteRel.WriteOp.CTAS)
97+
.createMode(AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS)
98+
.outputMode(AbstractWriteRel.OutputMode.NO_OUTPUT)
99+
.names(sqlCreateTable.name.names)
100+
.build();
101+
102+
return Plan.Root.builder().input(namedWrite).build();
103+
}
104+
105+
protected Plan.Root handleCreateView(final SqlCreateView sqlCreateView) {
106+
107+
final RelRoot queryRelRoot = bestExpRelRootGetter.apply(converter, sqlCreateView.query);
108+
Plan.Root rel = SubstraitRelVisitor.convert(queryRelRoot, extensionCollection, featureBoard);
109+
final Expression.StructLiteral defaults = ExpressionCreator.struct(false);
110+
111+
final NamedDdl namedDdl =
112+
NamedDdl.builder()
113+
.viewDefinition(rel.getInput())
114+
.tableSchema(getSchema(queryRelRoot))
115+
.tableDefaults(defaults)
116+
.operation(AbstractDdlRel.DdlOp.CREATE)
117+
.object(AbstractDdlRel.DdlObject.VIEW)
118+
.names(sqlCreateView.name.names)
119+
.build();
120+
121+
return Plan.Root.builder().input(namedDdl).build();
122+
}
123+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package io.substrait.isthmus;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
6+
import io.substrait.plan.Plan;
7+
import io.substrait.plan.PlanProtoConverter;
8+
import io.substrait.plan.ProtoPlanConverter;
9+
import org.apache.calcite.prepare.Prepare;
10+
import org.apache.calcite.sql.parser.SqlParseException;
11+
import org.junit.jupiter.api.Test;
12+
13+
class DdlRoundtripTest extends PlanTestBase {
14+
final Prepare.CatalogReader catalogReader =
15+
SubstraitCreateStatementParser.processCreateStatementsToCatalog(
16+
"create table src1 (intcol int, charcol varchar(10))",
17+
"create table src2 (intcol int, charcol varchar(10))");
18+
19+
public DdlRoundtripTest() throws SqlParseException {
20+
super();
21+
}
22+
23+
void testSqlToSubstrait(String sqlStatement) throws SqlParseException {
24+
SqlToSubstrait sqlToSubstrait = new SqlToSubstrait();
25+
io.substrait.proto.Plan protoPlan = sqlToSubstrait.execute(sqlStatement, catalogReader);
26+
Plan plan = new ProtoPlanConverter().from(protoPlan);
27+
io.substrait.proto.Plan protoPlan1 = new PlanProtoConverter().toProto(plan);
28+
assertEquals(protoPlan, protoPlan1);
29+
}
30+
31+
void testPlanRoundTrip(String sqlStatement) throws SqlParseException {
32+
SqlToSubstrait sql2subst = new SqlToSubstrait();
33+
final Plan plan = sql2subst.convert(sqlStatement, catalogReader);
34+
35+
assertPlanRoundtrip(plan);
36+
}
37+
38+
@Test
39+
void testCreateTable() throws SqlParseException {
40+
String sql = "create table dst1 as select * from src1";
41+
testSqlToSubstrait(sql);
42+
// TBD: full roundtrip is not possible because there is no relational algebra for DDL
43+
testPlanRoundTrip(sql);
44+
}
45+
46+
@Test
47+
void testCreateView() throws SqlParseException {
48+
String sql = "create view dst1 as select * from src1";
49+
testSqlToSubstrait(sql);
50+
testPlanRoundTrip(sql);
51+
}
52+
}

0 commit comments

Comments
 (0)