Commit dbd765d

ueshin authored and HyukjinKwon committed
[SPARK-53544][PYTHON] Support complex types on observations

### What changes were proposed in this pull request?

Supports complex types on observations.

### Why are the changes needed?

The observations didn't support complex types. For example:

```py
>>> observation = Observation("struct")
>>> df = spark.range(10).observe(
...     observation,
...     F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
... )
```

- classic

```py
>>> df.collect()
[Row(id=0), Row(id=1), Row(id=2), Row(id=3), Row(id=4), Row(id=5), Row(id=6), Row(id=7), Row(id=8), Row(id=9)]
>>> observation.get
{'struct': JavaObject id=o61}
```

- connect

```py
>>> df.collect()
Traceback (most recent call last):
...
pyspark.errors.exceptions.base.PySparkTypeError: [UNSUPPORTED_LITERAL] Unsupported Literal 'struct {
...
```

### Does this PR introduce _any_ user-facing change?

Yes, complex types are available on observations.

```py
>>> df.collect()
[Row(id=0), Row(id=1), Row(id=2), Row(id=3), Row(id=4), Row(id=5), Row(id=6), Row(id=7), Row(id=8), Row(id=9)]
>>>
>>> observation.get
{'struct': Row(rows=10, maxid=9)}
```

### How was this patch tested?

Added the related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52321 from ueshin/issues/SPARK-53544/complex_observation.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 8c49165 commit dbd765d

File tree (4 files changed: 94 additions, 22 deletions)

- python/pyspark/sql/connect/expressions.py
- python/pyspark/sql/observation.py
- python/pyspark/sql/tests/test_observation.py
- sql/api/src/main/scala/org/apache/spark/sql/Observation.scala


python/pyspark/sql/connect/expressions.py

Lines changed: 26 additions & 0 deletions

@@ -39,6 +39,7 @@
 
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.types import (
+    _create_row,
     _from_numpy_type,
     DateType,
     ArrayType,
@@ -58,6 +59,8 @@
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
+    MapType,
+    StructType,
 )
 
 import pyspark.sql.connect.proto as proto
@@ -441,6 +444,29 @@ def _to_value(
                 assert isinstance(dataType, ArrayType)
                 assert elementType == dataType.elementType
             return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
+        elif literal.HasField("map"):
+            keyType = proto_schema_to_pyspark_data_type(literal.map.key_type)
+            valueType = proto_schema_to_pyspark_data_type(literal.map.value_type)
+            if dataType is not None:
+                assert isinstance(dataType, MapType)
+                assert keyType == dataType.keyType
+                assert valueType == dataType.valueType
+            return {
+                LiteralExpression._to_value(k, keyType): LiteralExpression._to_value(v, valueType)
+                for k, v in zip(literal.map.keys, literal.map.values)
+            }
+        elif literal.HasField("struct"):
+            struct_type = cast(
+                StructType, proto_schema_to_pyspark_data_type(literal.struct.struct_type)
+            )
+            if dataType is not None:
+                assert isinstance(dataType, StructType)
+                assert struct_type == dataType
+            values = [
+                LiteralExpression._to_value(v, f.dataType)
+                for v, f in zip(literal.struct.elements, struct_type.fields)
+            ]
+            return _create_row(struct_type.names, values)
 
         raise PySparkTypeError(
             errorClass="UNSUPPORTED_LITERAL",
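
The new struct branch pairs each element of the proto literal with the corresponding field of the decoded `StructType` and rebuilds a `Row` via `_create_row`. A minimal sketch of that final step, using hypothetical plain-Python stand-ins for the already-decoded field names and element values:

```py
from pyspark.sql.types import _create_row

# Hypothetical stand-ins for what the struct branch derives from the proto
# literal: struct_type.names and the per-field values converted by _to_value.
field_names = ["rows", "maxid"]
field_values = [10, 9]

# _create_row builds a Row with those names and values, which is what the
# struct branch returns and what Observation.get ultimately surfaces.
row = _create_row(field_names, field_values)
print(row)        # Row(rows=10, maxid=9)
print(row.maxid)  # 9
```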

python/pyspark/sql/observation.py

Lines changed: 7 additions & 3 deletions

@@ -18,6 +18,8 @@
 from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkAssertionError
+from pyspark.serializers import CPickleSerializer
+from pyspark.sql import Row
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.utils import is_remote
@@ -144,9 +146,11 @@ def get(self) -> Dict[str, Any]:
         if self._jo is None:
             raise PySparkAssertionError(errorClass="NO_OBSERVE_BEFORE_GET", messageParameters={})
 
-        jmap = self._jo.getAsJava()
-        # return a pure Python dict, not jmap which is a py4j JavaMap
-        return {k: v for k, v in jmap.items()}
+        assert self._jvm is not None
+        utils = getattr(self._jvm, "org.apache.spark.sql.api.python.PythonSQLUtils")
+        jrow = self._jo.getRow()
+        row: Row = CPickleSerializer().loads(utils.toPyRow(jrow))
+        return row.asDict(recursive=False)
 
 
 def _test() -> None:
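
Note that `get` now returns `row.asDict(recursive=False)`, so nested struct metrics stay `Row` objects instead of being flattened into plain dicts. A small sketch of that distinction on a hypothetical metrics row shaped like the PR example:

```py
from pyspark.sql import Row

# Hypothetical metrics row, shaped like the unpickled result inside get().
metrics = Row(struct=Row(rows=10, maxid=9))

# recursive=False (what get() uses) keeps nested Rows as Row objects ...
print(metrics.asDict(recursive=False))  # {'struct': Row(rows=10, maxid=9)}
# ... while recursive=True would flatten them into plain dicts.
print(metrics.asDict(recursive=True))   # {'struct': {'rows': 10, 'maxid': 9}}
```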

python/pyspark/sql/tests/test_observation.py

Lines changed: 50 additions & 18 deletions

@@ -18,21 +18,19 @@
 import time
 import unittest
 
-from pyspark.sql import Row
-from pyspark.sql.functions import col, lit, count, sum, mean
+from pyspark.sql import Row, Observation, functions as F
 from pyspark.errors import (
     PySparkAssertionError,
     PySparkTypeError,
     PySparkValueError,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual
 
 
 class DataFrameObservationTestsMixin:
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
-        from pyspark.sql import Observation
-
         df = self.spark.createDataFrame(
             [
                 (1, 1.0, "one"),
@@ -58,11 +56,11 @@ def test_observe(self):
             df.orderBy("id")
             .observe(
                 named_observation,
-                count(lit(1)).alias("cnt"),
-                sum(col("id")).alias("sum"),
-                mean(col("val")).alias("mean"),
+                F.count(F.lit(1)).alias("cnt"),
+                F.sum(F.col("id")).alias("sum"),
+                F.mean(F.col("val")).alias("mean"),
             )
-            .observe(unnamed_observation, count(lit(1)).alias("rows"))
+            .observe(unnamed_observation, F.count(F.lit(1)).alias("rows"))
         )
 
         # test that observe works transparently
@@ -81,7 +79,7 @@ def test_observe(self):
         self.assertEqual(unnamed_observation.get, dict(rows=3))
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            df.observe(named_observation, count(lit(1)).alias("count"))
+            df.observe(named_observation, F.count(F.lit(1)).alias("count"))
 
         self.check_error(
             exception=pe.exception,
@@ -106,7 +104,7 @@ def test_observe(self):
         )
 
         # dataframe.observe requires non-None Columns
-        for args in [(None,), ("id",), (lit(1), None), (lit(1), "id")]:
+        for args in [(None,), ("id",), (F.lit(1), None), (F.lit(1), "id")]:
             with self.subTest(args=args):
                 with self.assertRaises(PySparkTypeError) as pe:
                     df.observe(Observation(), *args)
@@ -140,7 +138,9 @@ def onQueryTerminated(self, event):
         self.spark.streams.addListener(TestListener())
 
         df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-        df = df.observe("metric", count(lit(1)).alias("cnt"), sum(col("value")).alias("sum"))
+        df = df.observe(
+            "metric", F.count(F.lit(1)).alias("cnt"), F.sum(F.col("value")).alias("sum")
+        )
         q = df.writeStream.format("noop").queryName("test").start()
         self.assertTrue(q.isActive)
         time.sleep(10)
@@ -157,15 +157,13 @@ def onQueryTerminated(self, event):
 
     def test_observe_with_same_name_on_different_dataframe(self):
         # SPARK-45656: named observations with the same name on different datasets
-        from pyspark.sql import Observation
-
         observation1 = Observation("named")
         df1 = self.spark.range(50)
-        observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt"))
+        observed_df1 = df1.observe(observation1, F.count(F.lit(1)).alias("cnt"))
 
         observation2 = Observation("named")
         df2 = self.spark.range(100)
-        observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt"))
+        observed_df2 = df2.observe(observation2, F.count(F.lit(1)).alias("cnt"))
 
         observed_df1.collect()
         observed_df2.collect()
@@ -174,8 +172,6 @@ def test_observe_with_same_name_on_different_dataframe(self):
         self.assertEqual(observation2.get, dict(cnt=100))
 
     def test_observe_on_commands(self):
-        from pyspark.sql import Observation
-
         df = self.spark.range(50)
 
         test_table = "test_table"
@@ -190,10 +186,46 @@
         ]:
             with self.subTest(command=command):
                 observation = Observation()
-                observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
+                observed_df = df.observe(observation, F.count(F.lit(1)).alias("cnt"))
                 action(observed_df)
                 self.assertEqual(observation.get, dict(cnt=50))
 
+    def test_observe_with_struct_type(self):
+        observation = Observation("struct")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"struct": Row(rows=10, maxid=9)})
+
+    def test_observe_with_array_type(self):
+        observation = Observation("array")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.array(F.count(F.lit(1))).alias("array"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"array": [10]})
+
+    def test_observe_with_map_type(self):
+        observation = Observation("map")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.create_map(F.lit("count"), F.count(F.lit(1))).alias("map"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"map": {"count": 10}})
+
 
 class DataFrameObservationTests(
     DataFrameObservationTestsMixin,
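
Taken together, the new tests pin down how each complex metric type surfaces through `Observation.get`: structs as `Row`, arrays as `list`, and maps as `dict`. A condensed doctest-style sketch of the same behavior, assuming an active `SparkSession` named `spark`:

```py
from pyspark.sql import Observation, functions as F

obs = Observation("metrics")
df = spark.range(10).observe(
    obs,
    F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
    F.array(F.count(F.lit(1))).alias("array"),
    F.create_map(F.lit("count"), F.count(F.lit(1))).alias("map"),
)
df.collect()  # the observation is only populated once an action runs

# Expected, per the tests above:
# obs.get == {
#     "struct": Row(rows=10, maxid=9),
#     "array": [10],
#     "map": {"count": 10},
# }
```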

sql/api/src/main/scala/org/apache/spark/sql/Observation.scala

Lines changed: 11 additions & 1 deletion

@@ -77,7 +77,7 @@ class Observation(val name: String) {
    */
   @throws[InterruptedException]
   def get: Map[String, Any] = {
-    val row = SparkThreadUtils.awaitResult(future, Duration.Inf)
+    val row = getRow
     row.getValuesMap(row.schema.map(_.name))
   }
 
@@ -134,6 +134,16 @@ class Observation(val name: String) {
   private[sql] def getRowOrEmpty: Option[Row] = {
     Try(SparkThreadUtils.awaitResult(future, 100.millis)).toOption
   }
+
+  /**
+   * Get the observed metrics as a Row.
+   *
+   * @return
+   *   the observed metrics as a `Row`.
+   */
+  private[sql] def getRow: Row = {
+    SparkThreadUtils.awaitResult(future, Duration.Inf)
+  }
 }
 
 /**
