From 47ac097f61c0185cd5a0674528020a65917b7b90 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 20 Apr 2023 16:06:28 +0800 Subject: [PATCH] [SPARK-43207][CONNECT] Add helper functions to extract value from literal expression ### What changes were proposed in this pull request? Add helper functions for extract value from literal expression ### Why are the changes needed? some logic should be reused ### Does this PR introduce _any_ user-facing change? no, dev-only ### How was this patch tested? existing UTs Closes #40863 from zhengruifeng/connect_helper. Lead-authored-by: Ruifeng Zheng Co-authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../connect/planner/SparkConnectPlanner.scala | 123 ++++++------------ 1 file changed, 37 insertions(+), 86 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5f39fcd17f78e..e4522cea74735 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1240,83 +1240,49 @@ class SparkConnectPlanner(val session: SparkSession) { private def transformUnregisteredFunction( fun: proto.Expression.UnresolvedFunction): Option[Expression] = { fun.getFunctionName match { - case "product" => - if (fun.getArgumentsCount != 1) { - throw InvalidPlanInput("Product requires single child expression") - } + case "product" if fun.getArgumentsCount == 1 => Some( aggregate .Product(transformExpression(fun.getArgumentsList.asScala.head)) .toAggregateExpression()) - case "when" => - if (fun.getArgumentsCount == 0) { - throw InvalidPlanInput("CaseWhen requires at least one child expression") - } + case "when" if fun.getArgumentsCount > 0 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) Some(CaseWhen.createFromParser(children)) - case "in" => - if (fun.getArgumentsCount == 0) { - throw InvalidPlanInput("In requires at least one child expression") - } + case "in" if fun.getArgumentsCount > 0 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) Some(In(children.head, children.tail)) case "nth_value" if fun.getArgumentsCount == 3 => // NthValue does not have a constructor which accepts Expression typed 'ignoreNulls' val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - val ignoreNulls = children.last match { - case Literal(bool: Boolean, BooleanType) => bool - case other => - throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, but got $other") - } + val ignoreNulls = extractBoolean(children(2), "ignoreNulls") Some(NthValue(children(0), children(1), ignoreNulls)) case "lag" if fun.getArgumentsCount == 4 => // Lag does not have a constructor which accepts Expression typed 'ignoreNulls' val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - val ignoreNulls = children.last match { - case Literal(bool: Boolean, BooleanType) => bool - case other => - throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, but got $other") - } + val ignoreNulls = extractBoolean(children(3), "ignoreNulls") Some(Lag(children.head, children(1), children(2), ignoreNulls)) case "lead" if fun.getArgumentsCount == 4 => // Lead does not have a constructor which accepts Expression typed 'ignoreNulls' val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - val ignoreNulls = children.last match { - case Literal(bool: Boolean, BooleanType) => bool - case other => - throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, but got $other") - } + val ignoreNulls = extractBoolean(children(3), "ignoreNulls") Some(Lead(children.head, children(1), children(2), ignoreNulls)) - case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 4 => + case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) val timeCol = children.head - val args = children.tail.map { - case Literal(s, StringType) if s != null => s.toString - case other => - throw InvalidPlanInput( - s"windowDuration,slideDuration,startTime should be literal strings, but got $other") + val windowDuration = extractString(children(1), "windowDuration") + var slideDuration = windowDuration + if (fun.getArgumentsCount >= 3) { + slideDuration = extractString(children(2), "slideDuration") } - var windowDuration: String = null - var slideDuration: String = null - var startTime: String = null - if (args.length == 3) { - windowDuration = args(0) - slideDuration = args(1) - startTime = args(2) - } else if (args.length == 2) { - windowDuration = args(0) - slideDuration = args(1) - startTime = "0 second" - } else { - windowDuration = args(0) - slideDuration = args(0) - startTime = "0 second" + var startTime = "0 second" + if (fun.getArgumentsCount == 4) { + startTime = extractString(children(3), "startTime") } Some( Alias(TimeWindow(timeCol, windowDuration, slideDuration, startTime), "window")( @@ -1373,20 +1339,10 @@ class SparkConnectPlanner(val session: SparkSession) { } if (schema != null) { - val options = if (children.length == 3) { - // ExprUtils.convertToMapData requires the options to be resolved CreateMap, - // but the options here is not resolved yet: UnresolvedFunction("map", ...) - children(2) match { - case UnresolvedFunction(Seq("map"), arguments, _, _, _) => - ExprUtils.convertToMapData(CreateMap(arguments)) - case other => - throw InvalidPlanInput( - s"Options in from_json should be created by map, but got $other") - } - } else { - Map.empty[String, String] + var options = Map.empty[String, String] + if (children.length == 3) { + options = extractMapData(children(2), "Options") } - Some( JsonToStructs( schema = CharVarcharUtils.failIfHasCharVarchar(schema), @@ -1399,21 +1355,10 @@ class SparkConnectPlanner(val session: SparkSession) { // Avro-specific functions case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - val jsonFormatSchema = children(1) match { - case Literal(s, StringType) if s != null => s.toString - case other => - throw InvalidPlanInput( - s"jsonFormatSchema in from_avro should be a literal string, but got $other") - } + val jsonFormatSchema = extractString(children(1), "jsonFormatSchema") var options = Map.empty[String, String] if (fun.getArgumentsCount == 3) { - children(2) match { - case UnresolvedFunction(Seq("map"), arguments, _, _, _) => - options = ExprUtils.convertToMapData(CreateMap(arguments)) - case other => - throw InvalidPlanInput( - s"Options in from_json should be created by map, but got $other") - } + options = extractMapData(children(2), "Options") } Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options)) @@ -1421,12 +1366,7 @@ class SparkConnectPlanner(val session: SparkSession) { val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) var jsonFormatSchema = Option.empty[String] if (fun.getArgumentsCount == 2) { - children(1) match { - case Literal(s, StringType) if s != null => jsonFormatSchema = Some(s.toString) - case other => - throw InvalidPlanInput( - s"jsonFormatSchema in to_avro should be a literal string, but got $other") - } + jsonFormatSchema = Some(extractString(children(1), "jsonFormatSchema")) } Some(CatalystDataToAvro(children.head, jsonFormatSchema)) @@ -1437,12 +1377,7 @@ class SparkConnectPlanner(val session: SparkSession) { // ML-specific functions case "vector_to_array" if fun.getArgumentsCount == 2 => val expr = transformExpression(fun.getArguments(0)) - val dtype = transformExpression(fun.getArguments(1)) match { - case Literal(s, StringType) if s != null => s.toString - case other => - throw InvalidPlanInput( - s"dtype in vector_to_array should be a literal string, but got $other") - } + val dtype = extractString(transformExpression(fun.getArguments(1)), "dtype") dtype match { case "float64" => Some(transformUnregisteredUDF(MLFunctions.vectorToArrayUdf, Seq(expr))) @@ -1479,6 +1414,22 @@ class SparkConnectPlanner(val session: SparkSession) { udfDeterministic = f.deterministic) } + private def extractBoolean(expr: Expression, field: String): Boolean = expr match { + case Literal(bool: Boolean, BooleanType) => bool + case other => throw InvalidPlanInput(s"$field should be a literal boolean, but got $other") + } + + private def extractString(expr: Expression, field: String): String = expr match { + case Literal(s, StringType) if s != null => s.toString + case other => throw InvalidPlanInput(s"$field should be a literal string, but got $other") + } + + private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match { + case map: CreateMap => ExprUtils.convertToMapData(map) + case UnresolvedFunction(Seq("map"), args, _, _, _) => extractMapData(CreateMap(args), field) + case other => throw InvalidPlanInput(s"$field should be created by map, but got $other") + } + private def transformAlias(alias: proto.Expression.Alias): NamedExpression = { if (alias.getNameCount == 1) { val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {