Commit d16e92d

shujingyang-db authored and ueshin committed

[SPARK-53029][PYTHON] Support return type coercion for Arrow Python UDTFs
### What changes were proposed in this pull request?

Support return type coercion for Arrow Python UDTFs by doing `arrow_cast` by default.

### Why are the changes needed?

Consistent behavior across Arrow UDFs and Arrow UDTFs.

### Does this PR introduce _any_ user-facing change?

No, Arrow UDTF is not a public API yet.

### How was this patch tested?

New and existing UTs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52140 from shujingyang-db/arrow-udtf-type-corerion.

Lead-authored-by: Shujing Yang <[email protected]>
Co-authored-by: Shujing Yang <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>

1 parent b633ad3 · commit d16e92d
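
A minimal sketch of the behavior this commit enables, adapted from the tests below (it assumes the `arrow_udtf` decorator exercised in the test file; imports of Spark-side helpers are not shown in this diff):

```python
import pyarrow as pa
from typing import Iterator

@arrow_udtf(returnType="id int")  # declared Spark type: int (int32)
class LongToIntUDTF:
    def eval(self) -> Iterator["pa.Table"]:
        # Yields int64 data; previously this raised a schema-mismatch error,
        # now the serializer safely casts int64 -> int32 before returning.
        yield pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})

# LongToIntUDTF() now returns rows (1,), (2,), (3,) typed as int32.
```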

4 files changed: +167 −14 lines

python/pyspark/errors/error-conditions.json

Lines changed: 5 additions & 0 deletions
```diff
@@ -967,6 +967,11 @@
       "Column names of the returned pyarrow.Table do not match specified schema.<missing><extra>"
     ]
   },
+  "RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF": {
+    "message": [
+      "Column names of the returned pyarrow.Table or pyarrow.RecordBatch do not match specified schema. Expected: <expected> Actual: <actual>"
+    ]
+  },
   "RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF": {
     "message": [
       "Column names of the returned pandas.DataFrame do not match specified schema.<missing><extra>"
```

python/pyspark/sql/pandas/serializers.py

Lines changed: 62 additions & 0 deletions

```diff
@@ -227,6 +227,68 @@ def load_stream(self, stream):
                 result_batches.append(batch.column(i))
             yield result_batches

+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+        if arr.type == arrow_type:
+            return arr
+        else:
+            try:
+                # When safe is True, the cast will fail if there's an overflow
+                # or other unsafe conversion.
+                # RecordBatch.cast(...) isn't used as the minimum PyArrow version
+                # required for RecordBatch.cast(...) is v16.0.
+                return arr.cast(target_type=arrow_type, safe=True)
+            except (pa.ArrowInvalid, pa.ArrowTypeError):
+                raise PySparkRuntimeError(
+                    errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",
+                    messageParameters={
+                        "expected": str(arrow_type),
+                        "actual": str(arr.type),
+                    },
+                )
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override to handle type coercion for ArrowUDTF outputs.
+        ArrowUDTF returns an iterator of (pa.RecordBatch, arrow_return_type) tuples.
+        """
+        import pyarrow as pa
+
+        def apply_type_coercion():
+            for batch, arrow_return_type in iterator:
+                assert isinstance(
+                    arrow_return_type, pa.StructType
+                ), f"Expected pa.StructType, got {type(arrow_return_type)}"
+
+                # Handle the empty-struct case specially
+                if batch.num_columns == 0:
+                    coerced_batch = batch  # skip type coercion
+                else:
+                    expected_field_names = arrow_return_type.names
+                    actual_field_names = batch.schema.names
+
+                    if expected_field_names != actual_field_names:
+                        raise PySparkTypeError(
+                            "Target schema's field names are not matching the record batch's "
+                            "field names. "
+                            f"Expected: {expected_field_names}, but got: {actual_field_names}."
+                        )
+
+                    coerced_arrays = []
+                    for i, field in enumerate(arrow_return_type):
+                        original_array = batch.column(i)
+                        coerced_array = self._create_array(original_array, field.type)
+                        coerced_arrays.append(coerced_array)
+                    coerced_batch = pa.RecordBatch.from_arrays(
+                        coerced_arrays, names=arrow_return_type.names
+                    )
+                yield coerced_batch, arrow_return_type
+
+        return super().dump_stream(apply_type_coercion(), stream)
+

 class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
     """
```

python/pyspark/sql/tests/arrow/test_arrow_udtf.py

Lines changed: 98 additions & 6 deletions

```diff
@@ -189,7 +189,10 @@ def eval(self) -> Iterator["pa.Table"]:
             )
             yield result_table

-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
+        with self.assertRaisesRegex(
+            PythonException,
+            "Target schema's field names are not matching the record batch's field names",
+        ):
             result_df = MismatchedSchemaUDTF()
             result_df.collect()

@@ -330,9 +333,10 @@ def eval(self) -> Iterator["pa.Table"]:
             )
             yield result_table

-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
-            result_df = LongToIntUDTF()
-            result_df.collect()
+        # Should succeed with automatic coercion
+        result_df = LongToIntUDTF()
+        expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        assertDataFrameEqual(result_df, expected_df)

     def test_arrow_udtf_type_coercion_string_to_int(self):
         @arrow_udtf(returnType="id int")
@@ -341,15 +345,103 @@ def eval(self) -> Iterator["pa.Table"]:
             # Return string values that cannot be coerced to int
             result_table = pa.table(
                 {
-                    "id": pa.array(["abc", "def", "xyz"], type=pa.string()),
+                    "id": pa.array(["1", "2", "xyz"], type=pa.string()),
                 }
             )
             yield result_table

-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
+        # Should fail with an Arrow cast exception since 'xyz' cannot be cast to int
+        with self.assertRaisesRegex(
+            PythonException,
+            "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF\\] "
+            "Column names of the returned pyarrow.Table or pyarrow.RecordBatch do not match "
+            "specified schema. Expected: int32 Actual: string",
+        ):
             result_df = StringToIntUDTF()
             result_df.collect()

+    def test_arrow_udtf_type_coercion_string_to_int_safe(self):
+        @arrow_udtf(returnType="id int")
+        class StringToIntUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "id": pa.array(["1", "2", "3"], type=pa.string()),
+                    }
+                )
+                yield result_table
+
+        result_df = StringToIntUDTF()
+        expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_type_corecion_int64_to_int32_safe(self):
+        @arrow_udtf(returnType="id int")
+        class Int64ToInt32UDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "id": pa.array([1, 2, 3], type=pa.int64()),  # long values
+                    }
+                )
+                yield result_table
+
+        result_df = Int64ToInt32UDTF()
+        expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_return_type_coercion_success(self):
+        @arrow_udtf(returnType="value int")
+        class CoercionSuccessUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "value": pa.array([10, 20, 30], type=pa.int64()),  # long -> int coercion
+                    }
+                )
+                yield result_table
+
+        result_df = CoercionSuccessUDTF()
+        expected_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_return_type_coercion_overflow(self):
+        @arrow_udtf(returnType="value int")
+        class CoercionOverflowUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                # Return values that will cause overflow when casting long to int
+                result_table = pa.table(
+                    {
+                        "value": pa.array([2147483647 + 1], type=pa.int64()),  # int32 max + 1
+                    }
+                )
+                yield result_table
+
+        # Should fail with a PyArrow overflow exception
+        with self.assertRaises(Exception):
+            result_df = CoercionOverflowUDTF()
+            result_df.collect()
+
+    def test_return_type_coercion_multiple_columns(self):
+        @arrow_udtf(returnType="id int, price float")
+        class MultipleColumnCoercionUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "id": pa.array([1, 2, 3], type=pa.int64()),  # long -> int coercion
+                        "price": pa.array(
+                            [10.5, 20.7, 30.9], type=pa.float64()
+                        ),  # double -> float coercion
+                    }
+                )
+                yield result_table
+
+        result_df = MultipleColumnCoercionUDTF()
+        expected_df = self.spark.createDataFrame(
+            [(1, 10.5), (2, 20.7), (3, 30.9)], "id int, price float"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
     def test_arrow_udtf_with_empty_column_result(self):
         @arrow_udtf(returnType=StructType())
         class EmptyResultUDTF:
```

python/pyspark/worker.py

Lines changed: 2 additions & 8 deletions

```diff
@@ -1970,14 +1970,8 @@ def verify_result(result):
                     },
                 )

-            # Verify the type and the schema of the result.
-            verify_arrow_result(
-                pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
-                assign_cols_by_name=False,
-                expected_cols_and_types=[
-                    (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
-                ],
-            )
+            # We verify the type of the result and do type coercion
+            # in the serializer.
             return result

         # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
```
