Skip to content

Commit b232a95

Browse files
committed
feat(isthmus): udf support for substrait<->calcite
1 parent b3b5af9 commit b232a95

19 files changed

+626
-32
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package io.substrait.isthmus;
2+
3+
import io.substrait.extension.SimpleExtension;
4+
import io.substrait.isthmus.calcite.SubstraitOperatorTable;
5+
import java.util.List;
6+
import java.util.Locale;
7+
import java.util.Set;
8+
import java.util.stream.Collectors;
9+
10+
public class ExtensionUtils {
11+
12+
public static SimpleExtension.ExtensionCollection getDynamicExtensions(
13+
SimpleExtension.ExtensionCollection extensions) {
14+
Set<String> knownFunctionNames =
15+
SubstraitOperatorTable.INSTANCE.getOperatorList().stream()
16+
.map(op -> op.getName().toLowerCase(Locale.ROOT))
17+
.collect(Collectors.toSet());
18+
19+
List<SimpleExtension.ScalarFunctionVariant> customFunctions =
20+
extensions.scalarFunctions().stream()
21+
.filter(f -> !knownFunctionNames.contains(f.name().toLowerCase(Locale.ROOT)))
22+
.collect(Collectors.toList());
23+
24+
return SimpleExtension.ExtensionCollection.builder()
25+
.scalarFunctions(customFunctions)
26+
// TODO: handle aggregates and other functions
27+
.build();
28+
}
29+
30+
public static SimpleExtension.ExtensionCollection loadExtensions(List<String> yamlFunctionFiles) {
31+
SimpleExtension.ExtensionCollection allExtensions = SimpleExtension.loadDefaults();
32+
if (yamlFunctionFiles != null && !yamlFunctionFiles.isEmpty()) {
33+
allExtensions = allExtensions.merge(SimpleExtension.load(yamlFunctionFiles));
34+
}
35+
return allExtensions;
36+
}
37+
}
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
package io.substrait.isthmus;
2+
3+
import io.substrait.extension.SimpleExtension;
4+
import io.substrait.function.ParameterizedType;
5+
import io.substrait.function.ParameterizedTypeVisitor;
6+
import io.substrait.function.TypeExpression;
7+
import io.substrait.type.Type;
8+
import io.substrait.type.TypeExpressionEvaluator;
9+
import java.util.List;
10+
import java.util.stream.Collectors;
11+
import java.util.stream.Stream;
12+
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
13+
import org.apache.calcite.rel.type.RelDataType;
14+
import org.apache.calcite.rel.type.RelDataTypeFactory;
15+
import org.apache.calcite.sql.SqlFunction;
16+
import org.apache.calcite.sql.SqlFunctionCategory;
17+
import org.apache.calcite.sql.SqlKind;
18+
import org.apache.calcite.sql.SqlOperator;
19+
import org.apache.calcite.sql.SqlOperatorBinding;
20+
import org.apache.calcite.sql.type.OperandTypes;
21+
import org.apache.calcite.sql.type.SqlReturnTypeInference;
22+
import org.apache.calcite.sql.type.SqlTypeFamily;
23+
import org.apache.calcite.sql.type.SqlTypeName;
24+
25+
public final class SimpleExtensionToSqlOperator {
26+
27+
private static final RelDataTypeFactory DEFAULT_TYPE_FACTORY =
28+
new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);
29+
30+
private SimpleExtensionToSqlOperator() {}
31+
32+
public static List<SqlOperator> from(SimpleExtension.ExtensionCollection collection) {
33+
return from(collection, DEFAULT_TYPE_FACTORY);
34+
}
35+
36+
public static List<SqlOperator> from(
37+
SimpleExtension.ExtensionCollection collection, RelDataTypeFactory typeFactory) {
38+
TypeConverter typeConverter = TypeConverter.DEFAULT;
39+
return Stream.concat(
40+
collection.scalarFunctions().stream(), collection.aggregateFunctions().stream())
41+
.map(function -> toSqlFunction(function, typeFactory, typeConverter))
42+
.collect(Collectors.toList());
43+
}
44+
45+
private static SqlFunction toSqlFunction(
46+
SimpleExtension.Function function,
47+
RelDataTypeFactory typeFactory,
48+
TypeConverter typeConverter) {
49+
List<SimpleExtension.ValueArgument> requiredArgs =
50+
function.args().stream()
51+
.filter(SimpleExtension.Argument::required)
52+
.filter(t -> t instanceof SimpleExtension.ValueArgument)
53+
.map(t -> (SimpleExtension.ValueArgument) t)
54+
.collect(Collectors.toList());
55+
56+
List<SqlTypeFamily> argFamilies =
57+
requiredArgs.stream()
58+
.map(arg -> arg.value().accept(new CalciteTypeVisitor()).getFamily())
59+
.collect(Collectors.toList());
60+
61+
SqlReturnTypeInference returnTypeInference =
62+
new SubstraitReturnTypeInference(function, typeFactory, typeConverter);
63+
64+
return new SqlFunction(
65+
function.name(),
66+
SqlKind.OTHER_FUNCTION,
67+
returnTypeInference,
68+
null,
69+
OperandTypes.family(argFamilies),
70+
SqlFunctionCategory.USER_DEFINED_FUNCTION);
71+
}
72+
73+
private static class SubstraitReturnTypeInference implements SqlReturnTypeInference {
74+
75+
private final SimpleExtension.Function function;
76+
private final RelDataTypeFactory typeFactory;
77+
private final TypeConverter typeConverter;
78+
79+
private SubstraitReturnTypeInference(
80+
SimpleExtension.Function function,
81+
RelDataTypeFactory typeFactory,
82+
TypeConverter typeConverter) {
83+
this.function = function;
84+
this.typeFactory = typeFactory;
85+
this.typeConverter = typeConverter;
86+
}
87+
88+
@Override
89+
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
90+
List<Type> substraitArgTypes =
91+
opBinding.collectOperandTypes().stream()
92+
.map(typeConverter::toSubstrait)
93+
.collect(Collectors.toList());
94+
95+
TypeExpression returnExpression = function.returnType();
96+
Type resolvedSubstraitType =
97+
TypeExpressionEvaluator.evaluateExpression(
98+
returnExpression, function.args(), substraitArgTypes);
99+
100+
return typeConverter.toCalcite(typeFactory, resolvedSubstraitType);
101+
}
102+
}
103+
104+
private static class CalciteTypeVisitor
105+
extends ParameterizedTypeVisitor.ParameterizedTypeThrowsVisitor<
106+
SqlTypeName, RuntimeException> {
107+
108+
private CalciteTypeVisitor() {
109+
super("Type not supported for Calcite conversion.");
110+
}
111+
112+
@Override
113+
public SqlTypeName visit(Type.Bool expr) {
114+
return SqlTypeName.BOOLEAN;
115+
}
116+
117+
@Override
118+
public SqlTypeName visit(Type.I8 expr) {
119+
return SqlTypeName.TINYINT;
120+
}
121+
122+
@Override
123+
public SqlTypeName visit(Type.I16 expr) {
124+
return SqlTypeName.SMALLINT;
125+
}
126+
127+
@Override
128+
public SqlTypeName visit(Type.I32 expr) {
129+
return SqlTypeName.INTEGER;
130+
}
131+
132+
@Override
133+
public SqlTypeName visit(Type.I64 expr) {
134+
return SqlTypeName.BIGINT;
135+
}
136+
137+
@Override
138+
public SqlTypeName visit(Type.FP32 expr) {
139+
return SqlTypeName.FLOAT;
140+
}
141+
142+
@Override
143+
public SqlTypeName visit(Type.FP64 expr) {
144+
return SqlTypeName.DOUBLE;
145+
}
146+
147+
@Override
148+
public SqlTypeName visit(Type.Str expr) {
149+
return SqlTypeName.VARCHAR;
150+
}
151+
152+
@Override
153+
public SqlTypeName visit(Type.Binary expr) {
154+
return SqlTypeName.VARBINARY;
155+
}
156+
157+
@Override
158+
public SqlTypeName visit(Type.Date expr) {
159+
return SqlTypeName.DATE;
160+
}
161+
162+
@Override
163+
public SqlTypeName visit(Type.Time expr) {
164+
return SqlTypeName.TIME;
165+
}
166+
167+
@Override
168+
public SqlTypeName visit(Type.TimestampTZ expr) {
169+
return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
170+
}
171+
172+
@Override
173+
public SqlTypeName visit(Type.Timestamp expr) {
174+
return SqlTypeName.TIMESTAMP;
175+
}
176+
177+
@Override
178+
public SqlTypeName visit(Type.IntervalYear year) {
179+
return SqlTypeName.INTERVAL_YEAR_MONTH;
180+
}
181+
182+
@Override
183+
public SqlTypeName visit(Type.IntervalDay day) {
184+
return SqlTypeName.INTERVAL_DAY;
185+
}
186+
187+
@Override
188+
public SqlTypeName visit(Type.UUID expr) {
189+
return SqlTypeName.VARCHAR;
190+
}
191+
192+
@Override
193+
public SqlTypeName visit(Type.Struct struct) {
194+
return SqlTypeName.ROW;
195+
}
196+
197+
@Override
198+
public SqlTypeName visit(Type.ListType listType) {
199+
return SqlTypeName.ARRAY;
200+
}
201+
202+
@Override
203+
public SqlTypeName visit(Type.Map map) {
204+
return SqlTypeName.MAP;
205+
}
206+
207+
@Override
208+
public SqlTypeName visit(ParameterizedType.FixedChar expr) {
209+
return SqlTypeName.CHAR;
210+
}
211+
212+
@Override
213+
public SqlTypeName visit(ParameterizedType.VarChar expr) {
214+
return SqlTypeName.VARCHAR;
215+
}
216+
217+
@Override
218+
public SqlTypeName visit(ParameterizedType.FixedBinary expr) {
219+
return SqlTypeName.BINARY;
220+
}
221+
222+
@Override
223+
public SqlTypeName visit(ParameterizedType.Decimal expr) {
224+
return SqlTypeName.DECIMAL;
225+
}
226+
227+
@Override
228+
public SqlTypeName visit(ParameterizedType.Struct expr) {
229+
return SqlTypeName.ROW;
230+
}
231+
232+
@Override
233+
public SqlTypeName visit(ParameterizedType.ListType expr) {
234+
return SqlTypeName.ARRAY;
235+
}
236+
237+
@Override
238+
public SqlTypeName visit(ParameterizedType.Map expr) {
239+
return SqlTypeName.MAP;
240+
}
241+
242+
@Override
243+
public SqlTypeName visit(ParameterizedType.PrecisionTimestamp expr) {
244+
return SqlTypeName.TIMESTAMP;
245+
}
246+
247+
@Override
248+
public SqlTypeName visit(ParameterizedType.PrecisionTimestampTZ expr) {
249+
return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
250+
}
251+
252+
@Override
253+
public SqlTypeName visit(ParameterizedType.PrecisionTime expr) {
254+
return SqlTypeName.TIME;
255+
}
256+
257+
@Override
258+
public SqlTypeName visit(ParameterizedType.IntervalDay expr) {
259+
return SqlTypeName.INTERVAL_DAY;
260+
}
261+
262+
@Override
263+
public SqlTypeName visit(ParameterizedType.IntervalCompound expr) {
264+
// TODO: double check
265+
return SqlTypeName.INTERVAL_DAY_HOUR;
266+
}
267+
268+
@Override
269+
public SqlTypeName visit(ParameterizedType.StringLiteral expr) {
270+
String type = expr.value().toUpperCase();
271+
272+
if (type.startsWith("ANY")) {
273+
return SqlTypeName.ANY;
274+
}
275+
276+
switch (type) {
277+
case "BOOLEAN":
278+
return SqlTypeName.BOOLEAN;
279+
case "I8":
280+
return SqlTypeName.TINYINT;
281+
case "I16":
282+
return SqlTypeName.SMALLINT;
283+
case "I32":
284+
return SqlTypeName.INTEGER;
285+
case "I64":
286+
return SqlTypeName.BIGINT;
287+
case "FP32":
288+
return SqlTypeName.FLOAT;
289+
case "FP64":
290+
return SqlTypeName.DOUBLE;
291+
case "STRING":
292+
return SqlTypeName.VARCHAR;
293+
case "BINARY":
294+
return SqlTypeName.VARBINARY;
295+
case "TIMESTAMP":
296+
return SqlTypeName.TIMESTAMP;
297+
case "TIMESTAMP_TZ":
298+
return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
299+
case "DATE":
300+
return SqlTypeName.DATE;
301+
case "TIME":
302+
return SqlTypeName.TIME;
303+
case "UUID":
304+
return SqlTypeName.VARCHAR;
305+
default:
306+
if (type.startsWith("DECIMAL")) {
307+
return SqlTypeName.DECIMAL;
308+
}
309+
if (type.startsWith("STRUCT")) {
310+
return SqlTypeName.ROW;
311+
}
312+
if (type.startsWith("LIST")) {
313+
return SqlTypeName.ARRAY;
314+
}
315+
return super.visit(expr);
316+
}
317+
}
318+
}
319+
}

isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
import org.apache.calcite.sql2rel.SqlToRelConverter;
2020

2121
class SqlConverterBase {
22-
protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
23-
SimpleExtension.loadDefaults();
24-
22+
protected final SimpleExtension.ExtensionCollection extensionCollection;
2523
final RelDataTypeFactory factory;
2624
final RelOptCluster relOptCluster;
2725
final CalciteConnectionConfig config;
@@ -32,7 +30,8 @@ class SqlConverterBase {
3230
protected static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
3331
final FeatureBoard featureBoard;
3432

35-
protected SqlConverterBase(FeatureBoard features) {
33+
protected SqlConverterBase(
34+
FeatureBoard features, SimpleExtension.ExtensionCollection extensionCollection) {
3635
this.factory = new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);
3736
this.config =
3837
CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false");
@@ -51,5 +50,11 @@ protected SqlConverterBase(FeatureBoard features) {
5150
.withUnquotedCasing(featureBoard.unquotedCasing())
5251
.withParserFactory(SqlDdlParserImpl.FACTORY)
5352
.withConformance(SqlConformanceEnum.LENIENT);
53+
54+
this.extensionCollection = extensionCollection;
55+
}
56+
57+
// Convenience constructor: delegates to the two-argument constructor, passing the
// extension collection returned by SimpleExtension.loadDefaults().
protected SqlConverterBase(FeatureBoard features) {
58+
this(features, SimpleExtension.loadDefaults());
5459
}
5560
}

0 commit comments

Comments
 (0)