Skip to content

Commit

Permalink
[SPARK-43207][CONNECT] Add helper functions to extract value from lit…
Browse files Browse the repository at this point in the history
…eral expression

### What changes were proposed in this pull request?
Add helper functions for extract value from literal expression

### Why are the changes needed?
some logic should be reused

### Does this PR introduce _any_ user-facing change?
no, dev-only

### How was this patch tested?
existing UTs

Closes apache#40863 from zhengruifeng/connect_helper.

Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng and zhengruifeng committed Apr 20, 2023
1 parent 9e17731 commit 47ac097
Showing 1 changed file with 37 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1240,83 +1240,49 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformUnregisteredFunction(
fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
fun.getFunctionName match {
case "product" =>
if (fun.getArgumentsCount != 1) {
throw InvalidPlanInput("Product requires single child expression")
}
case "product" if fun.getArgumentsCount == 1 =>
Some(
aggregate
.Product(transformExpression(fun.getArgumentsList.asScala.head))
.toAggregateExpression())

case "when" =>
if (fun.getArgumentsCount == 0) {
throw InvalidPlanInput("CaseWhen requires at least one child expression")
}
case "when" if fun.getArgumentsCount > 0 =>
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
Some(CaseWhen.createFromParser(children))

case "in" =>
if (fun.getArgumentsCount == 0) {
throw InvalidPlanInput("In requires at least one child expression")
}
case "in" if fun.getArgumentsCount > 0 =>
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
Some(In(children.head, children.tail))

case "nth_value" if fun.getArgumentsCount == 3 =>
// NthValue does not have a constructor which accepts Expression typed 'ignoreNulls'
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
val ignoreNulls = children.last match {
case Literal(bool: Boolean, BooleanType) => bool
case other =>
throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, but got $other")
}
val ignoreNulls = extractBoolean(children(2), "ignoreNulls")
Some(NthValue(children(0), children(1), ignoreNulls))

case "lag" if fun.getArgumentsCount == 4 =>
// Lag does not have a constructor which accepts Expression typed 'ignoreNulls'
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
val ignoreNulls = children.last match {
case Literal(bool: Boolean, BooleanType) => bool
case other =>
throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, but got $other")
}
val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
Some(Lag(children.head, children(1), children(2), ignoreNulls))

case "lead" if fun.getArgumentsCount == 4 =>
// Lead does not have a constructor which accepts Expression typed 'ignoreNulls'
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
val ignoreNulls = children.last match {
case Literal(bool: Boolean, BooleanType) => bool
case other =>
throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, but got $other")
}
val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
Some(Lead(children.head, children(1), children(2), ignoreNulls))

case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 4 =>
case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
val timeCol = children.head
val args = children.tail.map {
case Literal(s, StringType) if s != null => s.toString
case other =>
throw InvalidPlanInput(
s"windowDuration,slideDuration,startTime should be literal strings, but got $other")
val windowDuration = extractString(children(1), "windowDuration")
var slideDuration = windowDuration
if (fun.getArgumentsCount >= 3) {
slideDuration = extractString(children(2), "slideDuration")
}
var windowDuration: String = null
var slideDuration: String = null
var startTime: String = null
if (args.length == 3) {
windowDuration = args(0)
slideDuration = args(1)
startTime = args(2)
} else if (args.length == 2) {
windowDuration = args(0)
slideDuration = args(1)
startTime = "0 second"
} else {
windowDuration = args(0)
slideDuration = args(0)
startTime = "0 second"
var startTime = "0 second"
if (fun.getArgumentsCount == 4) {
startTime = extractString(children(3), "startTime")
}
Some(
Alias(TimeWindow(timeCol, windowDuration, slideDuration, startTime), "window")(
Expand Down Expand Up @@ -1373,20 +1339,10 @@ class SparkConnectPlanner(val session: SparkSession) {
}

if (schema != null) {
val options = if (children.length == 3) {
// ExprUtils.convertToMapData requires the options to be resolved CreateMap,
// but the options here is not resolved yet: UnresolvedFunction("map", ...)
children(2) match {
case UnresolvedFunction(Seq("map"), arguments, _, _, _) =>
ExprUtils.convertToMapData(CreateMap(arguments))
case other =>
throw InvalidPlanInput(
s"Options in from_json should be created by map, but got $other")
}
} else {
Map.empty[String, String]
var options = Map.empty[String, String]
if (children.length == 3) {
options = extractMapData(children(2), "Options")
}

Some(
JsonToStructs(
schema = CharVarcharUtils.failIfHasCharVarchar(schema),
Expand All @@ -1399,34 +1355,18 @@ class SparkConnectPlanner(val session: SparkSession) {
// Avro-specific functions
case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
val jsonFormatSchema = children(1) match {
case Literal(s, StringType) if s != null => s.toString
case other =>
throw InvalidPlanInput(
s"jsonFormatSchema in from_avro should be a literal string, but got $other")
}
val jsonFormatSchema = extractString(children(1), "jsonFormatSchema")
var options = Map.empty[String, String]
if (fun.getArgumentsCount == 3) {
children(2) match {
case UnresolvedFunction(Seq("map"), arguments, _, _, _) =>
options = ExprUtils.convertToMapData(CreateMap(arguments))
case other =>
throw InvalidPlanInput(
s"Options in from_json should be created by map, but got $other")
}
options = extractMapData(children(2), "Options")
}
Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options))

case "to_avro" if Seq(1, 2).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
var jsonFormatSchema = Option.empty[String]
if (fun.getArgumentsCount == 2) {
children(1) match {
case Literal(s, StringType) if s != null => jsonFormatSchema = Some(s.toString)
case other =>
throw InvalidPlanInput(
s"jsonFormatSchema in to_avro should be a literal string, but got $other")
}
jsonFormatSchema = Some(extractString(children(1), "jsonFormatSchema"))
}
Some(CatalystDataToAvro(children.head, jsonFormatSchema))

Expand All @@ -1437,12 +1377,7 @@ class SparkConnectPlanner(val session: SparkSession) {
// ML-specific functions
case "vector_to_array" if fun.getArgumentsCount == 2 =>
val expr = transformExpression(fun.getArguments(0))
val dtype = transformExpression(fun.getArguments(1)) match {
case Literal(s, StringType) if s != null => s.toString
case other =>
throw InvalidPlanInput(
s"dtype in vector_to_array should be a literal string, but got $other")
}
val dtype = extractString(transformExpression(fun.getArguments(1)), "dtype")
dtype match {
case "float64" =>
Some(transformUnregisteredUDF(MLFunctions.vectorToArrayUdf, Seq(expr)))
Expand Down Expand Up @@ -1479,6 +1414,22 @@ class SparkConnectPlanner(val session: SparkSession) {
udfDeterministic = f.deterministic)
}

private def extractBoolean(expr: Expression, field: String): Boolean = expr match {
case Literal(bool: Boolean, BooleanType) => bool
case other => throw InvalidPlanInput(s"$field should be a literal boolean, but got $other")
}

private def extractString(expr: Expression, field: String): String = expr match {
case Literal(s, StringType) if s != null => s.toString
case other => throw InvalidPlanInput(s"$field should be a literal string, but got $other")
}

private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match {
case map: CreateMap => ExprUtils.convertToMapData(map)
case UnresolvedFunction(Seq("map"), args, _, _, _) => extractMapData(CreateMap(args), field)
case other => throw InvalidPlanInput(s"$field should be created by map, but got $other")
}

private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
if (alias.getNameCount == 1) {
val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {
Expand Down

0 comments on commit 47ac097

Please sign in to comment.