[SPARK-43084][SS] Add applyInPandasWithState support for spark connect
### What changes were proposed in this pull request?

This change adds `applyInPandasWithState` support for Spark Connect.
Example (try with local mode `./bin/pyspark --remote "local[*]"`):

```
>>> from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
>>> from pyspark.sql.types import (
...     LongType,
...     StringType,
...     StructType,
...     StructField,
...     Row,
... )
>>> import pandas as pd
>>> output_type = StructType(
...     [StructField("key", StringType()), StructField("countAsString", StringType())]
... )
>>> state_type = StructType([StructField("c", LongType())])
>>> def func(key, pdf_iter, state):
...     total_len = 0
...     for pdf in pdf_iter:
...         total_len += len(pdf)
...     state.update((total_len,))
...     yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total_len)]})
...
>>>
>>> input_path = "/Users/peng.zhong/tmp/applyInPandasWithState"
>>> df = spark.readStream.format("text").load(input_path)
>>> q = (
...       df.groupBy(df["value"])
...       .applyInPandasWithState(
...           func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout
...       )
...       .writeStream.queryName("this_query")
...       .format("memory")
...       .outputMode("update")
...       .start()
...   )
>>>
>>> q.status
{'message': 'Processing new data', 'isDataAvailable': True, 'isTriggerActive': True}
>>>
>>> spark.sql("select * from this_query").show()
+-----+-------------+
|  key|countAsString|
+-----+-------------+
|hello|            1|
| this|            1|
+-----+-------------+
```
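
Once the output is verified, the running query can be stopped as usual:

```
>>> q.stop()
```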

### Why are the changes needed?

This change adds `applyInPandasWithState` API support for Spark Connect.

### Does this PR introduce _any_ user-facing change?

Yes, this change adds the `applyInPandasWithState` API for Spark Connect.

### How was this patch tested?

Manually tested.

Closes apache#40736 from pengzhon-db/connect_applyInPandasWithState.

Authored-by: pengzhon-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
pengzhon-db authored and HyukjinKwon committed Apr 18, 2023
1 parent 3941369 commit cbe94a1
Showing 12 changed files with 408 additions and 132 deletions.
@@ -66,6 +66,7 @@ message Relation {
    GroupMap group_map = 31;
    CoGroupMap co_group_map = 32;
    WithWatermark with_watermark = 33;
    ApplyInPandasWithState apply_in_pandas_with_state = 34;

    // NA functions
    NAFill fill_na = 90;
@@ -840,6 +841,29 @@ message CoGroupMap {
  CommonInlineUserDefinedFunction func = 5;
}

message ApplyInPandasWithState {
  // (Required) Input relation for applyInPandasWithState.
  Relation input = 1;

  // (Required) Expressions for grouping keys.
  repeated Expression grouping_expressions = 2;

  // (Required) Input user-defined function.
  CommonInlineUserDefinedFunction func = 3;

  // (Required) Schema for the output DataFrame.
  string output_schema = 4;

  // (Required) Schema for the state.
  string state_schema = 5;

  // (Required) The output mode of the function.
  string output_mode = 6;

  // (Required) Timeout configuration for groups that do not receive data for a while.
  string timeout_conf = 7;
}

// Collect arbitrary (named) metrics from a dataset.
message CollectMetrics {
  // (Required) The input relation.
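
For illustration, the new `ApplyInPandasWithState` message above can also be built directly with the generated Python classes — a minimal sketch (in practice the `ApplyInPandasWithState` plan node shown further below fills in all fields, including `input`, `grouping_expressions`, and `func`):

```python
from pyspark.sql.types import LongType, StringType, StructField, StructType
from pyspark.sql.connect.proto import relations_pb2 as pb

output_type = StructType(
    [StructField("key", StringType()), StructField("countAsString", StringType())]
)
state_type = StructType([StructField("c", LongType())])

# Schemas travel over the wire as strings; a StructType is serialized to JSON.
msg = pb.ApplyInPandasWithState(
    output_schema=output_type.json(),
    state_schema=state_type.json(),
    output_mode="Update",
    timeout_conf="NoTimeout",
)
```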
@@ -133,6 +133,8 @@ class SparkConnectPlanner(val session: SparkSession) {
        transformGroupMap(rel.getGroupMap)
      case proto.Relation.RelTypeCase.CO_GROUP_MAP =>
        transformCoGroupMap(rel.getCoGroupMap)
      case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
        transformApplyInPandasWithState(rel.getApplyInPandasWithState)
      case proto.Relation.RelTypeCase.COLLECT_METRICS =>
        transformCollectMetrics(rel.getCollectMetrics)
      case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -583,6 +585,27 @@ class SparkConnectPlanner(val session: SparkSession) {
    input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
  }

  private def transformApplyInPandasWithState(rel: proto.ApplyInPandasWithState): LogicalPlan = {
    val pythonUdf = transformPythonUDF(rel.getFunc)
    val cols =
      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))

    val outputSchema = parseSchema(rel.getOutputSchema)

    val stateSchema = parseSchema(rel.getStateSchema)

    Dataset
      .ofRows(session, transformRelation(rel.getInput))
      .groupBy(cols: _*)
      .applyInPandasWithState(
        pythonUdf,
        outputSchema,
        stateSchema,
        rel.getOutputMode,
        rel.getTimeoutConf)
      .logicalPlan
  }

  private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
    Dataset
      .ofRows(session, transformRelation(rel.getInput))
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -778,6 +778,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
"pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map",
"pyspark.sql.tests.connect.streaming.test_parity_streaming",
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",
# ml doctests
"pyspark.ml.connect.functions",
# ml unittests
5 changes: 5 additions & 0 deletions python/pyspark/sql/connect/_typing.py
@@ -32,6 +32,7 @@

from pyspark.sql.connect.column import Column
from pyspark.sql.connect.types import DataType
from pyspark.sql.streaming.state import GroupState


ColumnOrName = Union[Column, str]
@@ -63,6 +64,10 @@

PandasCogroupedMapFunction = Callable[[DataFrameLike, DataFrameLike], DataFrameLike]

PandasGroupedMapFunctionWithState = Callable[
    [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
]


class UserDefinedFunctionLike(Protocol):
    func: Callable[..., Any]
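
A function conforming to `PandasGroupedMapFunctionWithState` receives the grouping key, an iterator of pandas DataFrames for that key, and the mutable group state, and yields pandas DataFrames. A minimal sketch (the name `count_rows` is illustrative; `exists`, `get`, and `update` are the standard `GroupState` accessors):

```python
from typing import Any, Iterable

import pandas as pd
from pyspark.sql.streaming.state import GroupState


def count_rows(
    key: Any, pdf_iter: Iterable[pd.DataFrame], state: GroupState
) -> Iterable[pd.DataFrame]:
    # Resume from the previously stored count for this key, if any.
    running = state.get[0] if state.exists else 0
    for pdf in pdf_iter:
        running += len(pdf)
    state.update((running,))  # persist a tuple matching the state schema
    yield pd.DataFrame({"key": [key[0]], "count": [running]})
```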
46 changes: 44 additions & 2 deletions python/pyspark/sql/connect/group.py
@@ -36,6 +36,7 @@
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.types import NumericType
from pyspark.sql.types import StructType

import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.column import Column
@@ -47,6 +48,7 @@
    PandasGroupedMapFunction,
    GroupedMapPandasUserDefinedFunction,
    PandasCogroupedMapFunction,
    PandasGroupedMapFunctionWithState,
)
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.types import StructType
@@ -262,8 +264,48 @@ def applyInPandas(

    applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__

    def applyInPandasWithState(self, *args: Any, **kwargs: Any) -> None:
        raise NotImplementedError("applyInPandasWithState() is not implemented.")
    def applyInPandasWithState(
        self,
        func: "PandasGroupedMapFunctionWithState",
        outputStructType: Union[StructType, str],
        stateStructType: Union[StructType, str],
        outputMode: str,
        timeoutConf: str,
    ) -> "DataFrame":
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame

        udf_obj = UserDefinedFunction(
            func,
            returnType=outputStructType,
            evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
        )

        output_schema: str = (
            outputStructType.json()
            if isinstance(outputStructType, StructType)
            else outputStructType
        )

        state_schema: str = (
            stateStructType.json() if isinstance(stateStructType, StructType) else stateStructType
        )

        return DataFrame.withPlan(
            plan.ApplyInPandasWithState(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                output_schema=output_schema,
                state_schema=state_schema,
                output_mode=outputMode,
                timeout_conf=timeoutConf,
                cols=self._df.columns,
            ),
            session=self._df._session,
        )

    applyInPandasWithState.__doc__ = PySparkGroupedData.applyInPandasWithState.__doc__

    def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
        return PandasCogroupedOps(self, other)
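
Note that both schema parameters accept either a `StructType` (serialized to JSON above) or a plain string that is forwarded as-is, so the commit-message example could equally pass DDL strings — a sketch reusing `df` and `func` from that example, and assuming the server-side `parseSchema` accepts DDL as well as JSON:

```python
from pyspark.sql.streaming.state import GroupStateTimeout

result = df.groupBy(df["value"]).applyInPandasWithState(
    func,
    "key string, countAsString string",  # DDL string, forwarded unchanged
    "c long",
    "Update",
    GroupStateTimeout.NoTimeout,
)
```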
39 changes: 39 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -2053,6 +2053,45 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
        return plan


class ApplyInPandasWithState(LogicalPlan):
    """Logical plan object for applyInPandasWithState."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        grouping_cols: Sequence[Column],
        function: "UserDefinedFunction",
        output_schema: str,
        state_schema: str,
        output_mode: str,
        timeout_conf: str,
        cols: List[str],
    ):
        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)

        super().__init__(child)
        self._grouping_cols = grouping_cols
        self._func = function._build_common_inline_user_defined_function(*cols)
        self._output_schema = output_schema
        self._state_schema = state_schema
        self._output_mode = output_mode
        self._timeout_conf = timeout_conf

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        plan = self._create_proto_relation()
        plan.apply_in_pandas_with_state.input.CopyFrom(self._child.plan(session))
        plan.apply_in_pandas_with_state.grouping_expressions.extend(
            [c.to_plan(session) for c in self._grouping_cols]
        )
        plan.apply_in_pandas_with_state.func.CopyFrom(self._func.to_plan_udf(session))
        plan.apply_in_pandas_with_state.output_schema = self._output_schema
        plan.apply_in_pandas_with_state.state_schema = self._state_schema
        plan.apply_in_pandas_with_state.output_mode = self._output_mode
        plan.apply_in_pandas_with_state.timeout_conf = self._timeout_conf
        return plan


class CachedRelation(LogicalPlan):
    def __init__(self, plan: proto.Relation) -> None:
        super(CachedRelation, self).__init__(None)
258 changes: 136 additions & 122 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

80 changes: 80 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -95,6 +95,7 @@ class Relation(google.protobuf.message.Message):
    GROUP_MAP_FIELD_NUMBER: builtins.int
    CO_GROUP_MAP_FIELD_NUMBER: builtins.int
    WITH_WATERMARK_FIELD_NUMBER: builtins.int
    APPLY_IN_PANDAS_WITH_STATE_FIELD_NUMBER: builtins.int
    FILL_NA_FIELD_NUMBER: builtins.int
    DROP_NA_FIELD_NUMBER: builtins.int
    REPLACE_FIELD_NUMBER: builtins.int
@@ -176,6 +177,8 @@ class Relation(google.protobuf.message.Message):
    @property
    def with_watermark(self) -> global___WithWatermark: ...
    @property
    def apply_in_pandas_with_state(self) -> global___ApplyInPandasWithState: ...
    @property
    def fill_na(self) -> global___NAFill:
        """NA functions"""
    @property
@@ -245,6 +248,7 @@ class Relation(google.protobuf.message.Message):
        group_map: global___GroupMap | None = ...,
        co_group_map: global___CoGroupMap | None = ...,
        with_watermark: global___WithWatermark | None = ...,
        apply_in_pandas_with_state: global___ApplyInPandasWithState | None = ...,
        fill_na: global___NAFill | None = ...,
        drop_na: global___NADrop | None = ...,
        replace: global___NAReplace | None = ...,
@@ -265,6 +269,8 @@ class Relation(google.protobuf.message.Message):
        field_name: typing_extensions.Literal[
            "aggregate",
            b"aggregate",
            "apply_in_pandas_with_state",
            b"apply_in_pandas_with_state",
            "approx_quantile",
            b"approx_quantile",
            "catalog",
@@ -366,6 +372,8 @@ class Relation(google.protobuf.message.Message):
        field_name: typing_extensions.Literal[
            "aggregate",
            b"aggregate",
            "apply_in_pandas_with_state",
            b"apply_in_pandas_with_state",
            "approx_quantile",
            b"approx_quantile",
            "catalog",
@@ -497,6 +505,7 @@ class Relation(google.protobuf.message.Message):
        "group_map",
        "co_group_map",
        "with_watermark",
        "apply_in_pandas_with_state",
        "fill_na",
        "drop_na",
        "replace",
@@ -2980,6 +2989,77 @@ class CoGroupMap(google.protobuf.message.Message):

global___CoGroupMap = CoGroupMap

class ApplyInPandasWithState(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    INPUT_FIELD_NUMBER: builtins.int
    GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
    FUNC_FIELD_NUMBER: builtins.int
    OUTPUT_SCHEMA_FIELD_NUMBER: builtins.int
    STATE_SCHEMA_FIELD_NUMBER: builtins.int
    OUTPUT_MODE_FIELD_NUMBER: builtins.int
    TIMEOUT_CONF_FIELD_NUMBER: builtins.int
    @property
    def input(self) -> global___Relation:
        """(Required) Input relation for applyInPandasWithState."""
    @property
    def grouping_expressions(
        self,
    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
        pyspark.sql.connect.proto.expressions_pb2.Expression
    ]:
        """(Required) Expressions for grouping keys."""
    @property
    def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
        """(Required) Input user-defined function."""
    output_schema: builtins.str
    """(Required) Schema for the output DataFrame."""
    state_schema: builtins.str
    """(Required) Schema for the state."""
    output_mode: builtins.str
    """(Required) The output mode of the function."""
    timeout_conf: builtins.str
    """(Required) Timeout configuration for groups that do not receive data for a while."""
    def __init__(
        self,
        *,
        input: global___Relation | None = ...,
        grouping_expressions: collections.abc.Iterable[
            pyspark.sql.connect.proto.expressions_pb2.Expression
        ]
        | None = ...,
        func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
        | None = ...,
        output_schema: builtins.str = ...,
        state_schema: builtins.str = ...,
        output_mode: builtins.str = ...,
        timeout_conf: builtins.str = ...,
    ) -> None: ...
    def HasField(
        self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "func",
            b"func",
            "grouping_expressions",
            b"grouping_expressions",
            "input",
            b"input",
            "output_mode",
            b"output_mode",
            "output_schema",
            b"output_schema",
            "state_schema",
            b"state_schema",
            "timeout_conf",
            b"timeout_conf",
        ],
    ) -> None: ...

global___ApplyInPandasWithState = ApplyInPandasWithState

class CollectMetrics(google.protobuf.message.Message):
"""Collect arbitrary (named) metrics from a dataset."""

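
As the stub shows, `HasField` covers only the two message-typed fields (`input`, `func`), matching proto3 presence rules; cleared scalar strings simply fall back to their empty defaults. A quick illustration (assuming the generated classes):

```python
from pyspark.sql.connect.proto import relations_pb2 as pb

msg = pb.ApplyInPandasWithState(output_mode="Update", timeout_conf="NoTimeout")
assert not msg.HasField("func")  # submessage fields support presence checks
msg.ClearField("output_mode")
assert msg.output_mode == ""  # proto3 scalars reset to the empty default
```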
3 changes: 3 additions & 0 deletions python/pyspark/sql/pandas/group_ops.py
@@ -273,6 +273,9 @@ def applyInPandasWithState(
        .. versionadded:: 3.4.0
        .. versionchanged:: 3.5.0
            Supports Spark Connect.
        Parameters
        ----------
        func : function
7 changes: 0 additions & 7 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2928,13 +2928,6 @@ def test_unsupported_functions(self):
            with self.assertRaises(NotImplementedError):
                getattr(df, f)()

    def test_unsupported_group_functions(self):
        # SPARK-41927: Disable unsupported functions.
        cg = self.connect.read.table(self.tbl_name).groupBy("id")
        for f in ("applyInPandasWithState",):
            with self.assertRaises(NotImplementedError):
                getattr(cg, f)()

    def test_unsupported_session_functions(self):
        # SPARK-41934: Disable unsupported functions.
