diff --git a/velox/functions/sparksql/specialforms/SparkCastExpr.cpp b/velox/functions/sparksql/specialforms/SparkCastExpr.cpp index de1d6b30bfa..86c3edb5c92 100644 --- a/velox/functions/sparksql/specialforms/SparkCastExpr.cpp +++ b/velox/functions/sparksql/specialforms/SparkCastExpr.cpp @@ -16,6 +16,7 @@ #include "velox/functions/sparksql/specialforms/SparkCastExpr.h" +#include "velox/expression/SpecialFormRegistry.h" #include "velox/functions/sparksql/SparkQueryConfig.h" namespace facebook::velox::functions::sparksql { @@ -26,6 +27,70 @@ bool isIntegralType(const TypePtr& type) { type == BIGINT(); } +exec::ExprPtr makeSparkCastExpr( + const TypePtr& type, + exec::ExprPtr&& input, + bool trackCpuUsage, + bool isTryCast, + bool allowOverflow, + const core::QueryConfig& config) { + return std::make_shared( + type, + std::move(input), + trackCpuUsage, + isTryCast, + std::make_shared(config, allowOverflow)); +} + +class SparkAnsiCastCallToSpecialForm : public exec::CastCallToSpecialForm { + public: + exec::ExprPtr constructSpecialForm( + const TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage, + const core::QueryConfig& config) override { + VELOX_CHECK_EQ( + compiledChildren.size(), + 1, + "ANSI CAST statements expect exactly 1 argument, received {}.", + compiledChildren.size()); + + const auto& fromType = compiledChildren[0]->type(); + const bool isTryCast = + !SparkCastCallToSpecialForm::isAnsiSupported(fromType, type); + return makeSparkCastExpr( + type, + std::move(compiledChildren[0]), + trackCpuUsage, + isTryCast, + isTryCast, + config); + } +}; + +class SparkLegacyCastCallToSpecialForm : public exec::CastCallToSpecialForm { + public: + exec::ExprPtr constructSpecialForm( + const TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage, + const core::QueryConfig& config) override { + VELOX_CHECK_EQ( + compiledChildren.size(), + 1, + "LEGACY CAST statements expect exactly 1 argument, received {}.", + compiledChildren.size()); + + return makeSparkCastExpr( + type, + std::move(compiledChildren[0]), + trackCpuUsage, + true, + true, + config); + } +}; + } // namespace bool SparkCastCallToSpecialForm::isAnsiSupported( @@ -96,4 +161,13 @@ exec::ExprPtr SparkTryCastCallToSpecialForm::constructSpecialForm( std::make_shared(config, false)); } +void registerSparkCastModeSpecialForms( + const std::string& ansiCastName, + const std::string& legacyCastName) { + exec::registerFunctionCallToSpecialForm( + ansiCastName, std::make_unique()); + exec::registerFunctionCallToSpecialForm( + legacyCastName, std::make_unique()); +} + } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/specialforms/SparkCastExpr.h b/velox/functions/sparksql/specialforms/SparkCastExpr.h index aa3fee83f07..5df98fafee2 100644 --- a/velox/functions/sparksql/specialforms/SparkCastExpr.h +++ b/velox/functions/sparksql/specialforms/SparkCastExpr.h @@ -16,6 +16,8 @@ #pragma once +#include + #include "velox/expression/CastExpr.h" #include "velox/functions/sparksql/specialforms/SparkCastHooks.h" @@ -45,7 +47,6 @@ class SparkCastCallToSpecialForm : public exec::CastCallToSpecialForm { bool trackCpuUsage, const core::QueryConfig& config) override; - private: /// Determines if ANSI mode is supported for casting from fromType to toType. /// TODO: Remove this function once all cast operations support ANSI mode. /// @param fromType The source type of the cast @@ -62,4 +63,8 @@ class SparkTryCastCallToSpecialForm : public exec::TryCastCallToSpecialForm { bool trackCpuUsage, const core::QueryConfig& config) override; }; + +void registerSparkCastModeSpecialForms( + const std::string& ansiCastName, + const std::string& legacyCastName); } // namespace facebook::velox::functions::sparksql