diff --git a/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java new file mode 100644 index 000000000..e9852a7ce --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java @@ -0,0 +1,37 @@ +package io.substrait.isthmus; + +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.SubstraitOperatorTable; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; + +public class ExtensionUtils { + + public static SimpleExtension.ExtensionCollection getDynamicExtensions( + SimpleExtension.ExtensionCollection extensions) { + Set knownFunctionNames = + SubstraitOperatorTable.INSTANCE.getOperatorList().stream() + .map(op -> op.getName().toLowerCase(Locale.ROOT)) + .collect(Collectors.toSet()); + + List customFunctions = + extensions.scalarFunctions().stream() + .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase(Locale.ROOT))) + .collect(Collectors.toList()); + + return SimpleExtension.ExtensionCollection.builder() + .scalarFunctions(customFunctions) + // TODO: handle aggregates and other functions + .build(); + } + + public static SimpleExtension.ExtensionCollection loadExtensions(List yamlFunctionFiles) { + SimpleExtension.ExtensionCollection allExtensions = SimpleExtension.loadDefaults(); + if (yamlFunctionFiles != null && !yamlFunctionFiles.isEmpty()) { + allExtensions = allExtensions.merge(SimpleExtension.load(yamlFunctionFiles)); + } + return allExtensions; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java b/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java new file mode 100644 index 000000000..4ee3a419e --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java @@ -0,0 +1,342 @@ +package io.substrait.isthmus; + +import io.substrait.extension.SimpleExtension; +import io.substrait.function.ParameterizedType; +import io.substrait.function.ParameterizedTypeVisitor; +import io.substrait.function.TypeExpression; +import io.substrait.type.Type; +import io.substrait.type.TypeExpressionEvaluator; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; + +public final class SimpleExtensionToSqlOperator { + + private static final RelDataTypeFactory DEFAULT_TYPE_FACTORY = + new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM); + + private SimpleExtensionToSqlOperator() {} + + public static List from(SimpleExtension.ExtensionCollection collection) { + return from(collection, DEFAULT_TYPE_FACTORY); + } + + public static List from( + SimpleExtension.ExtensionCollection collection, RelDataTypeFactory typeFactory) { + TypeConverter typeConverter = TypeConverter.DEFAULT; + return Stream.concat( + collection.scalarFunctions().stream(), collection.aggregateFunctions().stream()) + .map(function -> toSqlFunction(function, typeFactory, typeConverter)) + .collect(Collectors.toList()); + } + + private static SqlFunction toSqlFunction( + SimpleExtension.Function function, + RelDataTypeFactory typeFactory, + TypeConverter typeConverter) { + + List argFamilies = new ArrayList<>(); + + for (SimpleExtension.Argument arg : function.requiredArguments()) { + if (arg instanceof SimpleExtension.ValueArgument) { + SimpleExtension.ValueArgument valueArg = (SimpleExtension.ValueArgument) arg; + SqlTypeName typeName = valueArg.value().accept(new CalciteTypeVisitor()); + argFamilies.add(typeName.getFamily()); + } else if (arg instanceof SimpleExtension.EnumArgument) { + // Treat an EnumArgument as a required string literal. + argFamilies.add(SqlTypeFamily.STRING); + } + } + + SqlReturnTypeInference returnTypeInference = + new SubstraitReturnTypeInference(function, typeFactory, typeConverter); + + return new SqlFunction( + function.name(), + SqlKind.OTHER_FUNCTION, + returnTypeInference, + null, + OperandTypes.family(argFamilies), + SqlFunctionCategory.USER_DEFINED_FUNCTION); + } + + private static class SubstraitReturnTypeInference implements SqlReturnTypeInference { + + private final SimpleExtension.Function function; + private final RelDataTypeFactory typeFactory; + private final TypeConverter typeConverter; + + private SubstraitReturnTypeInference( + SimpleExtension.Function function, + RelDataTypeFactory typeFactory, + TypeConverter typeConverter) { + this.function = function; + this.typeFactory = typeFactory; + this.typeConverter = typeConverter; + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + List substraitArgTypes = + opBinding.collectOperandTypes().stream() + .map(typeConverter::toSubstrait) + .collect(Collectors.toList()); + + TypeExpression returnExpression = function.returnType(); + Type resolvedSubstraitType = + TypeExpressionEvaluator.evaluateExpression( + returnExpression, function.args(), substraitArgTypes); + + boolean finalIsNullable; + switch (function.nullability()) { + case MIRROR: + // If any input is nullable, the output is nullable. + finalIsNullable = + opBinding.collectOperandTypes().stream().anyMatch(RelDataType::isNullable); + break; + case DISCRETE: + // The function can return null even if inputs are not null. + finalIsNullable = true; + break; + case DECLARED_OUTPUT: + default: + // Use the nullability declared on the resolved Substrait type. + finalIsNullable = resolvedSubstraitType.nullable(); + break; + } + + RelDataType baseCalciteType = typeConverter.toCalcite(typeFactory, resolvedSubstraitType); + + return typeFactory.createTypeWithNullability(baseCalciteType, finalIsNullable); + } + } + + private static class CalciteTypeVisitor + extends ParameterizedTypeVisitor.ParameterizedTypeThrowsVisitor< + SqlTypeName, RuntimeException> { + + private CalciteTypeVisitor() { + super("Type not supported for Calcite conversion."); + } + + @Override + public SqlTypeName visit(Type.Bool expr) { + return SqlTypeName.BOOLEAN; + } + + @Override + public SqlTypeName visit(Type.I8 expr) { + return SqlTypeName.TINYINT; + } + + @Override + public SqlTypeName visit(Type.I16 expr) { + return SqlTypeName.SMALLINT; + } + + @Override + public SqlTypeName visit(Type.I32 expr) { + return SqlTypeName.INTEGER; + } + + @Override + public SqlTypeName visit(Type.I64 expr) { + return SqlTypeName.BIGINT; + } + + @Override + public SqlTypeName visit(Type.FP32 expr) { + return SqlTypeName.FLOAT; + } + + @Override + public SqlTypeName visit(Type.FP64 expr) { + return SqlTypeName.DOUBLE; + } + + @Override + public SqlTypeName visit(Type.Str expr) { + return SqlTypeName.VARCHAR; + } + + @Override + public SqlTypeName visit(Type.Binary expr) { + return SqlTypeName.VARBINARY; + } + + @Override + public SqlTypeName visit(Type.Date expr) { + return SqlTypeName.DATE; + } + + @Override + public SqlTypeName visit(Type.Time expr) { + return SqlTypeName.TIME; + } + + @Override + public SqlTypeName visit(Type.TimestampTZ expr) { + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; + } + + @Override + public SqlTypeName visit(Type.Timestamp expr) { + return SqlTypeName.TIMESTAMP; + } + + @Override + public SqlTypeName visit(Type.IntervalYear year) { + return SqlTypeName.INTERVAL_YEAR_MONTH; + } + + @Override + public SqlTypeName visit(Type.IntervalDay day) { + return SqlTypeName.INTERVAL_DAY; + } + + @Override + public SqlTypeName visit(Type.UUID expr) { + return SqlTypeName.VARCHAR; + } + + @Override + public SqlTypeName visit(Type.Struct struct) { + return SqlTypeName.ROW; + } + + @Override + public SqlTypeName visit(Type.ListType listType) { + return SqlTypeName.ARRAY; + } + + @Override + public SqlTypeName visit(Type.Map map) { + return SqlTypeName.MAP; + } + + @Override + public SqlTypeName visit(ParameterizedType.FixedChar expr) { + return SqlTypeName.CHAR; + } + + @Override + public SqlTypeName visit(ParameterizedType.VarChar expr) { + return SqlTypeName.VARCHAR; + } + + @Override + public SqlTypeName visit(ParameterizedType.FixedBinary expr) { + return SqlTypeName.BINARY; + } + + @Override + public SqlTypeName visit(ParameterizedType.Decimal expr) { + return SqlTypeName.DECIMAL; + } + + @Override + public SqlTypeName visit(ParameterizedType.Struct expr) { + return SqlTypeName.ROW; + } + + @Override + public SqlTypeName visit(ParameterizedType.ListType expr) { + return SqlTypeName.ARRAY; + } + + @Override + public SqlTypeName visit(ParameterizedType.Map expr) { + return SqlTypeName.MAP; + } + + @Override + public SqlTypeName visit(ParameterizedType.PrecisionTimestamp expr) { + return SqlTypeName.TIMESTAMP; + } + + @Override + public SqlTypeName visit(ParameterizedType.PrecisionTimestampTZ expr) { + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; + } + + @Override + public SqlTypeName visit(ParameterizedType.PrecisionTime expr) { + return SqlTypeName.TIME; + } + + @Override + public SqlTypeName visit(ParameterizedType.IntervalDay expr) { + return SqlTypeName.INTERVAL_DAY; + } + + @Override + public SqlTypeName visit(ParameterizedType.IntervalCompound expr) { + // TODO: double check + return SqlTypeName.INTERVAL_DAY_HOUR; + } + + @Override + public SqlTypeName visit(ParameterizedType.StringLiteral expr) { + String type = expr.value().toUpperCase(); + + if (type.startsWith("ANY")) { + return SqlTypeName.ANY; + } + + switch (type) { + case "BOOLEAN": + return SqlTypeName.BOOLEAN; + case "I8": + return SqlTypeName.TINYINT; + case "I16": + return SqlTypeName.SMALLINT; + case "I32": + return SqlTypeName.INTEGER; + case "I64": + return SqlTypeName.BIGINT; + case "FP32": + return SqlTypeName.FLOAT; + case "FP64": + return SqlTypeName.DOUBLE; + case "STRING": + return SqlTypeName.VARCHAR; + case "BINARY": + return SqlTypeName.VARBINARY; + case "TIMESTAMP": + return SqlTypeName.TIMESTAMP; + case "TIMESTAMP_TZ": + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; + case "DATE": + return SqlTypeName.DATE; + case "TIME": + return SqlTypeName.TIME; + case "UUID": + return SqlTypeName.VARCHAR; + default: + if (type.startsWith("DECIMAL")) { + return SqlTypeName.DECIMAL; + } + if (type.startsWith("STRUCT")) { + return SqlTypeName.ROW; + } + if (type.startsWith("LIST")) { + return SqlTypeName.ARRAY; + } + return super.visit(expr); + } + } + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index e60df0b68..2ed056e6d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -19,9 +19,7 @@ import org.apache.calcite.sql2rel.SqlToRelConverter; class SqlConverterBase { - protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = - SimpleExtension.loadDefaults(); - + protected final SimpleExtension.ExtensionCollection extensionCollection; final RelDataTypeFactory factory; final RelOptCluster relOptCluster; final CalciteConnectionConfig config; @@ -32,7 +30,8 @@ class SqlConverterBase { protected static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); final FeatureBoard featureBoard; - protected SqlConverterBase(FeatureBoard features) { + protected SqlConverterBase( + FeatureBoard features, SimpleExtension.ExtensionCollection extensionCollection) { this.factory = new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM); this.config = CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false"); @@ -51,5 +50,11 @@ protected SqlConverterBase(FeatureBoard features) { .withUnquotedCasing(featureBoard.unquotedCasing()) .withParserFactory(SqlDdlParserImpl.FACTORY) .withConformance(SqlConformanceEnum.LENIENT); + + this.extensionCollection = extensionCollection; + } + + protected SqlConverterBase(FeatureBoard features) { + this(features, SimpleExtension.loadDefaults()); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index c32fab07c..fc144b090 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -34,12 +34,12 @@ public class SqlExpressionToSubstrait extends SqlConverterBase { protected final RexExpressionConverter rexConverter; public SqlExpressionToSubstrait() { - this(FEATURES_DEFAULT, EXTENSION_COLLECTION); + this(FEATURES_DEFAULT, SimpleExtension.loadDefaults()); } public SqlExpressionToSubstrait( FeatureBoard features, SimpleExtension.ExtensionCollection extensions) { - super(features); + super(features, extensions); ScalarFunctionConverter scalarFunctionConverter = new ScalarFunctionConverter(extensions.scalarFunctions(), factory); this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 3e19ca58c..9116d8ce4 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,22 +1,42 @@ package io.substrait.isthmus; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.ImmutablePlan.Builder; import io.substrait.plan.Plan; import io.substrait.plan.Plan.Version; import io.substrait.plan.PlanProtoConverter; +import java.util.List; import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.util.SqlOperatorTables; /** Take a SQL statement and a set of table definitions and return a substrait plan. */ public class SqlToSubstrait extends SqlConverterBase { + private final SqlOperatorTable operatorTable; public SqlToSubstrait() { - this(null); + this(SimpleExtension.loadDefaults(), null); } public SqlToSubstrait(FeatureBoard features) { - super(features); + this(SimpleExtension.loadDefaults(), features); + } + + public SqlToSubstrait(SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { + super(features, extensions); + + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + ExtensionUtils.getDynamicExtensions(extensions); + List generatedDynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, this.factory); + + this.operatorTable = + SqlOperatorTables.chain( + SubstraitOperatorTable.INSTANCE, SqlOperatorTables.of(generatedDynamicOperators)); } /** @@ -53,8 +73,8 @@ public Plan convert(String sqlStatements, Prepare.CatalogReader catalogReader) builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build()); // TODO: consider case in which one sql passes conversion while others don't - SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader).stream() - .map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard)) + SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader, operatorTable).stream() + .map(root -> SubstraitRelVisitor.convert(root, extensionCollection, featureBoard)) .forEach(root -> builder.addRoots(root)); return builder.build(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 801110c5f..101468a3e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -1,7 +1,5 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; - import com.google.common.collect.ImmutableList; import com.google.common.collect.Range; import com.google.common.collect.RangeMap; @@ -12,6 +10,7 @@ import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; +import io.substrait.isthmus.expression.FunctionMappings; import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.relation.AbstractRelVisitor; @@ -97,7 +96,7 @@ public SubstraitRelNodeConverter( this( typeFactory, relBuilder, - new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory), + createScalarFunctionConverter(extensions, typeFactory), new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory), new WindowFunctionConverter(extensions.windowFunctions(), typeFactory), TypeConverter.DEFAULT); @@ -139,11 +138,45 @@ public SubstraitRelNodeConverter( this.expressionRexConverter.setRelNodeConverter(this); } + private static ScalarFunctionConverter createScalarFunctionConverter( + SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { + + java.util.Set knownFunctionNames = + FunctionMappings.SCALAR_SIGS.stream() + .map(FunctionMappings.Sig::name) + .collect(Collectors.toSet()); + + List dynamicFunctions = + extensions.scalarFunctions().stream() + .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase())) + .collect(Collectors.toList()); + + List additionalSignatures; + if (dynamicFunctions.isEmpty()) { + additionalSignatures = Collections.emptyList(); + } else { + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + SimpleExtension.ExtensionCollection.builder().scalarFunctions(dynamicFunctions).build(); + + List dynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + + additionalSignatures = + dynamicOperators.stream() + .map(op -> FunctionMappings.s(op, op.getName())) + .collect(Collectors.toList()); + } + + return new ScalarFunctionConverter( + extensions.scalarFunctions(), additionalSignatures, typeFactory, TypeConverter.DEFAULT); + } + public static RelNode convert( Rel relRoot, RelOptCluster relOptCluster, Prepare.CatalogReader catalogReader, - SqlParser.Config parserConfig) { + SqlParser.Config parserConfig, + SimpleExtension.ExtensionCollection extensions) { RelBuilder relBuilder = RelBuilder.create( Frameworks.newConfigBuilder() @@ -154,8 +187,7 @@ public static RelNode convert( .build()); return relRoot.accept( - new SubstraitRelNodeConverter( - EXTENSION_COLLECTION, relOptCluster.getTypeFactory(), relBuilder), + new SubstraitRelNodeConverter(extensions, relOptCluster.getTypeFactory(), relBuilder), Context.newContext()); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index cc41e84a0..4bc7e3e9e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -8,6 +8,7 @@ import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.CallConverters; +import io.substrait.isthmus.expression.FunctionMappings; import io.substrait.isthmus.expression.LiteralConverter; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; @@ -53,6 +54,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; @@ -78,10 +80,25 @@ public SubstraitRelVisitor( RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { + + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + ExtensionUtils.getDynamicExtensions(extensions); + List dynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + + List additionalSignatures = + dynamicOperators.stream() + .map(op -> FunctionMappings.s(op, op.getName())) + .collect(Collectors.toList()); this.typeConverter = TypeConverter.DEFAULT; - ArrayList converters = new ArrayList(); + ArrayList converters = new ArrayList<>(); converters.addAll(CallConverters.defaults(typeConverter)); - converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)); + converters.add( + new ScalarFunctionConverter( + extensions.scalarFunctions(), + additionalSignatures, + typeFactory, + TypeConverter.DEFAULT)); converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); this.aggregateFunctionConverter = new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index 421b45317..e327ab007 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import io.substrait.extension.SimpleExtension; import io.substrait.relation.Rel; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; @@ -10,7 +11,12 @@ public SubstraitToSql() { super(FEATURES_DEFAULT); } + public SubstraitToSql(SimpleExtension.ExtensionCollection extensions) { + super(FEATURES_DEFAULT, extensions); + } + public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) { - return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, catalog, parserConfig); + return SubstraitRelNodeConverter.convert( + relRoot, relOptCluster, catalog, parserConfig, extensionCollection); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 8f0ccb000..691d45fa7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -24,7 +24,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.IdentityHashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -85,8 +84,7 @@ public FunctionConverter( .collect( Multimaps.toMultimap( FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create)); - IdentityHashMap matcherMap = - new IdentityHashMap(); + Map matcherMap = new HashMap<>(); for (String key : alm.keySet()) { Collection sigs = calciteOperators.get(key); if (sigs.isEmpty()) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java index 499f33f1d..24a2fb0f4 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java @@ -15,6 +15,7 @@ import org.apache.calcite.rel.rules.CoreRules; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql2rel.SqlToRelConverter; @@ -41,6 +42,23 @@ public static RelRoot convertQuery(String sqlStatement, Prepare.CatalogReader ca return convertQuery(sqlStatement, catalogReader, validator, createDefaultRelOptCluster()); } + /** + * Converts a SQL statement to a Calcite {@link RelRoot}. + * + * @param sqlStatement a SQL statement string + * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in + * the SQL statement + * @param operatorTable the {@link SqlOperatorTable} for dynamic operators + * @return a {@link RelRoot} corresponding to the given SQL statement + * @throws SqlParseException if there is an error while parsing the SQL statement + */ + public static RelRoot convertQuery( + String sqlStatement, Prepare.CatalogReader catalogReader, SqlOperatorTable operatorTable) + throws SqlParseException { + SqlValidator validator = new SubstraitSqlValidator(catalogReader, operatorTable); + return convertQuery(sqlStatement, catalogReader, validator, createDefaultRelOptCluster()); + } + /** * Converts a SQL statement to a Calcite {@link RelRoot}. * @@ -72,6 +90,24 @@ public static RelRoot convertQuery( return relRoots.get(0); } + /** + * Converts one or more SQL statements to a List of {@link RelRoot}, with one {@link RelRoot} per + * statement. + * + * @param sqlStatements a string containing one or more SQL statements + * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in + * the SQL statements + * @param operatorTable the {@link SqlOperatorTable} for dynamic operators + * @return a list of {@link RelRoot}s corresponding to the given SQL statements + * @throws SqlParseException if there is an error while parsing the SQL statements + */ + public static List convertQueries( + String sqlStatements, Prepare.CatalogReader catalogReader, SqlOperatorTable operatorTable) + throws SqlParseException { + SqlValidator validator = new SubstraitSqlValidator(catalogReader, operatorTable); + return convertQueries(sqlStatements, catalogReader, validator, createDefaultRelOptCluster()); + } + /** * Converts one or more SQL statements to a List of {@link RelRoot}, with one {@link RelRoot} per * statement. diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java index eddcb1d0f..a4fe518b9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java @@ -2,6 +2,7 @@ import io.substrait.isthmus.calcite.SubstraitOperatorTable; import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorImpl; @@ -12,4 +13,8 @@ public class SubstraitSqlValidator extends SqlValidatorImpl { public SubstraitSqlValidator(Prepare.CatalogReader catalogReader) { super(SubstraitOperatorTable.INSTANCE, catalogReader, catalogReader.getTypeFactory(), CONFIG); } + + public SubstraitSqlValidator(Prepare.CatalogReader catalogReader, SqlOperatorTable opTable) { + super(opTable, catalogReader, catalogReader.getTypeFactory(), CONFIG); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java index 448b332df..8136d6544 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java @@ -1,6 +1,5 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static io.substrait.isthmus.SubstraitTypeSystem.YEAR_MONTH_INTERVAL; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -11,6 +10,7 @@ import io.substrait.expression.Expression.Literal; import io.substrait.expression.Expression.TimestampLiteral; import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.SubstraitRelNodeConverter.Context; import io.substrait.isthmus.expression.ExpressionRexConverter; import io.substrait.isthmus.expression.RexExpressionConverter; @@ -35,6 +35,8 @@ import org.junit.jupiter.api.Test; public class CalciteLiteralTest extends CalciteObjs { + private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + SimpleExtension.loadDefaults(); private final ScalarFunctionConverter scalarFunctionConverter = new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), type); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java index 553b20395..64308c3ef 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java @@ -1,12 +1,12 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.ImmutableAggregateFunctionInvocation; +import io.substrait.extension.SimpleExtension; import io.substrait.relation.Aggregate; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; @@ -17,6 +17,8 @@ import org.junit.jupiter.api.Test; public class ComplexAggregateTest extends PlanTestBase { + private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + SimpleExtension.loadDefaults(); final TypeCreator R = TypeCreator.of(false); SubstraitBuilder b = new SubstraitBuilder(extensions); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java index 948fcfc23..ee548a525 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java @@ -1,10 +1,10 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; +import io.substrait.extension.SimpleExtension; import io.substrait.relation.Rel; import io.substrait.type.TypeCreator; import java.io.PrintWriter; @@ -20,6 +20,9 @@ public class ComplexSortTest extends PlanTestBase { + private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + SimpleExtension.loadDefaults(); + final TypeCreator R = TypeCreator.of(false); SubstraitBuilder b = new SubstraitBuilder(extensions); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index d462ea05d..bd45dcff1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -1,8 +1,8 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.Plan; @@ -13,6 +13,9 @@ public class NameRoundtripTest extends PlanTestBase { + private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + SimpleExtension.loadDefaults(); + @Test void preserveNamesFromSql() throws Exception { String createStatement = "CREATE TABLE foo(a BIGINT, b BIGINT)"; diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index fd9572b2e..ae7005baf 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -1,8 +1,8 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import java.io.IOException; import org.apache.calcite.plan.hep.HepPlanner; @@ -16,6 +16,9 @@ public class OptimizerIntegrationTest extends PlanTestBase { + private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = + SimpleExtension.loadDefaults(); + @Test void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException { String query = diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 56173f1ea..4d3292ba5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -41,12 +41,12 @@ import org.apache.calcite.tools.RelBuilder; public class PlanTestBase { - protected final SimpleExtension.ExtensionCollection extensions = - SqlConverterBase.EXTENSION_COLLECTION; + protected final SimpleExtension.ExtensionCollection extensions; + protected final RelCreator creator = new RelCreator(); protected final RelBuilder builder = creator.createRelBuilder(); protected final RelDataTypeFactory typeFactory = creator.typeFactory(); - protected final SubstraitBuilder substraitBuilder = new SubstraitBuilder(extensions); + protected final SubstraitBuilder substraitBuilder; protected static final TypeCreator R = TypeCreator.of(false); protected static final TypeCreator N = TypeCreator.of(true); @@ -66,6 +66,15 @@ public class PlanTestBase { protected static CalciteCatalogReader TPCDS_CATALOG = PlanTestBase.schemaToCatalog("tpcds", TPCDS_SCHEMA); + protected PlanTestBase() { + this(SimpleExtension.loadDefaults()); + } + + protected PlanTestBase(SimpleExtension.ExtensionCollection extensions) { + this.extensions = extensions; + this.substraitBuilder = new SubstraitBuilder(extensions); + } + public static String asString(String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); } @@ -147,6 +156,39 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( return relRoot2; } + protected RelRoot assertSqlSubstraitRelRoundTripWorkaroundOptimizer( + String query, Prepare.CatalogReader catalogReader) throws Exception { + // sql <--> substrait round trip test. + // Assert (sql -> calcite -> substrait) and (sql -> substrait -> calcite -> substrait) are same. + // Return list of sql -> Substrait rel -> Calcite rel. + + SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + + SqlToSubstrait s = new SqlToSubstrait(extensions, null); + + // 1. SQL -> Calcite RelRoot + Plan plan1 = s.convert(query, catalogReader); + + // 2. Calcite RelRoot -> Substrait Rel + Plan.Root pojo1 = plan1.getRoots().get(0); + + // 3. Substrait Rel -> Calcite RelNode + RelRoot relRoot2 = substraitToCalcite.convert(pojo1); + + // 4. Calcite RelNode -> Substrait Rel + Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, extensions); + + // Here pojo1 and pojo2 can be different because of different default optimization + // rules between SqlNode->RelRoot conversion (Sql->Substrait) and + // RelBuilder/RexBuilder (Substrait->Sql). + // Therefore, substrait plans passed through conversion to calcite should be compared + RelRoot relRoot3 = substraitToCalcite.convert(pojo2); + Plan.Root pojo3 = SubstraitRelVisitor.convert(relRoot3, extensions); + + assertEquals(pojo2, pojo3); + return relRoot2; + } + @Beta protected void assertFullRoundTrip(String query) throws SqlParseException { assertFullRoundTrip(query, TPCH_CATALOG); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java new file mode 100644 index 000000000..e34e945f4 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java @@ -0,0 +1,46 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.extension.SimpleExtension; +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import org.apache.calcite.sql.SqlOperandCountRange; +import org.apache.calcite.sql.SqlOperator; +import org.junit.jupiter.api.Test; + +public class SimpleExtensionToSqlOperatorTest { + + @Test + void test() throws IOException { + String customFunctionPath = "/extensions/scalar_functions_custom.yaml"; + + SimpleExtension.ExtensionCollection customExtensions = + SimpleExtension.load( + customFunctionPath, + SimpleExtensionToSqlOperatorTest.class.getResourceAsStream(customFunctionPath)); + + List operators = SimpleExtensionToSqlOperator.from(customExtensions); + + Optional function = + operators.stream() + .filter(op -> op.getName().equalsIgnoreCase("REGEXP_EXTRACT")) + .findFirst(); + + assertTrue(function.isPresent(), "The REGEXP_EXTRACT function should be present."); + + SqlOperator op = function.get(); + System.out.println("Successfully found and verified Custom UDF:"); + System.out.printf(" - Name: %s%n", op.getName()); + + SqlOperandCountRange operandCountRange = op.getOperandCountRange(); + assertEquals(2, operandCountRange.getMin(), "Function should require 2 arguments."); + assertEquals(2, operandCountRange.getMax(), "Function should require 2 arguments."); + System.out.printf(" - Argument Count: %d%n", operandCountRange.getMin()); + + assertNotNull(op.getOperandTypeChecker(), "Operand type checker should not be null."); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java new file mode 100644 index 000000000..4b4905a35 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java @@ -0,0 +1,44 @@ +package io.substrait.isthmus; + +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import java.util.List; +import org.apache.calcite.prepare.Prepare; +import org.junit.jupiter.api.Test; + +public class UdfSqlSubstraitTest extends PlanTestBase { + + private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml"; + + UdfSqlSubstraitTest() { + super(loadExtensions(List.of(CUSTOM_FUNCTION_PATH))); + } + + @Test + public void customUdfTest() throws Exception { + + final Prepare.CatalogReader catalogReader = + SubstraitCreateStatementParser.processCreateStatementsToCatalog( + "CREATE TABLE t(x VARCHAR NOT NULL)"); + + assertSqlSubstraitRelRoundTripWorkaroundOptimizer( + "SELECT regexp_extract(x, 'ab') from t", catalogReader); + assertSqlSubstraitRelRoundTripWorkaroundOptimizer( + "SELECT format_text('UPPER', x) FROM t", catalogReader); + assertSqlSubstraitRelRoundTripWorkaroundOptimizer( + "SELECT system_property_get(x) FROM t", catalogReader); + assertSqlSubstraitRelRoundTripWorkaroundOptimizer( + "SELECT safe_divide(10,0) FROM t", catalogReader); + } + + private static SimpleExtension.ExtensionCollection loadExtensions( + List yamlFunctionFiles) { + SimpleExtension.ExtensionCollection extensions = SimpleExtension.loadDefaults(); + if (yamlFunctionFiles != null && !yamlFunctionFiles.isEmpty()) { + SimpleExtension.ExtensionCollection customExtensions = + SimpleExtension.load(yamlFunctionFiles); + extensions = extensions.merge(customExtensions); + } + return extensions; + } +} diff --git a/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml b/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml new file mode 100644 index 000000000..8c152df71 --- /dev/null +++ b/isthmus/src/test/resources/extensions/scalar_functions_custom.yaml @@ -0,0 +1,44 @@ +%YAML 1.2 +--- +scalar_functions: + - name: "regexp_extract" + impls: + - args: + - name: "text" + value: string + - name: "pattern" + value: string + return: string + + - name: "format_text" + description: "Formats text based on a mode. The output is nullable if the input is." + impls: + - args: + - name: "mode" +# options: ["UPPER", "LOWER"] + value: string + - name: "input_text" +# options: ["UPPER", "LOWER"] + value: string + return: string + nullability: MIRROR + + - name: "system_property_get" + description: "Safely gets a system property. Always returns a nullable string." + impls: + - args: + - name: "property_name" + value: string + return: string? + nullability: DECLARED_OUTPUT + + - name: "safe_divide" + description: "Performs division, returning NULL if the denominator is zero." + impls: + - args: + - name: "numerator" + value: i32 + - name: "denominator" + value: i32 + return: fp32? + nullability: DISCRETE