@@ -189,7 +189,10 @@ def eval(self) -> Iterator["pa.Table"]:
189
189
)
190
190
yield result_table
191
191
192
- with self .assertRaisesRegex (PythonException , "Schema at index 0 was different" ):
192
+ with self .assertRaisesRegex (
193
+ PythonException ,
194
+ "Target schema's field names are not matching the record batch's field names" ,
195
+ ):
193
196
result_df = MismatchedSchemaUDTF ()
194
197
result_df .collect ()
195
198
@@ -330,9 +333,10 @@ def eval(self) -> Iterator["pa.Table"]:
330
333
)
331
334
yield result_table
332
335
333
- with self .assertRaisesRegex (PythonException , "Schema at index 0 was different" ):
334
- result_df = LongToIntUDTF ()
335
- result_df .collect ()
336
+ # Should succeed with automatic coercion
337
+ result_df = LongToIntUDTF ()
338
+ expected_df = self .spark .createDataFrame ([(1 ,), (2 ,), (3 ,)], "id int" )
339
+ assertDataFrameEqual (result_df , expected_df )
336
340
337
341
def test_arrow_udtf_type_coercion_string_to_int (self ):
338
342
@arrow_udtf (returnType = "id int" )
@@ -341,15 +345,103 @@ def eval(self) -> Iterator["pa.Table"]:
341
345
# Return string values that cannot be coerced to int
342
346
result_table = pa .table (
343
347
{
344
- "id" : pa .array (["abc " , "def " , "xyz" ], type = pa .string ()),
348
+ "id" : pa .array (["1 " , "2 " , "xyz" ], type = pa .string ()),
345
349
}
346
350
)
347
351
yield result_table
348
352
349
- with self .assertRaisesRegex (PythonException , "Schema at index 0 was different" ):
353
+ # Should fail with Arrow cast exception since string cannot be cast to int
354
+ with self .assertRaisesRegex (
355
+ PythonException ,
356
+ "PySparkRuntimeError: \\ [RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF\\ ] "
357
+ "Column names of the returned pyarrow.Table or pyarrow.RecordBatch do not match "
358
+ "specified schema. Expected: int32 Actual: string" ,
359
+ ):
350
360
result_df = StringToIntUDTF ()
351
361
result_df .collect ()
352
362
363
def test_arrow_udtf_type_coercion_string_to_int_safe(self):
    # The UDTF emits a string column whose entries are all integer literals,
    # while the declared return type is "id int". The test expects the result
    # to equal a DataFrame holding the corresponding int values.
    @arrow_udtf(returnType="id int")
    class StringToIntUDTF:
        def eval(self) -> Iterator["pa.Table"]:
            id_column = pa.array(["1", "2", "3"], type=pa.string())
            yield pa.table({"id": id_column})

    actual = StringToIntUDTF()
    expected = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
    assertDataFrameEqual(actual, expected)
377
+
378
def test_arrow_udtf_type_corecion_int64_to_int32_safe(self):
    # NOTE: "corecion" in the method name is a misspelling of "coercion"; the
    # name is left as-is so the test's identifier does not change.
    #
    # The UDTF emits an int64 column for a return type declared as "id int";
    # every value fits in 32 bits, and the test expects plain int rows back.
    @arrow_udtf(returnType="id int")
    class Int64ToInt32UDTF:
        def eval(self) -> Iterator["pa.Table"]:
            long_ids = pa.array([1, 2, 3], type=pa.int64())  # long values
            yield pa.table({"id": long_ids})

    actual = Int64ToInt32UDTF()
    expected = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
    assertDataFrameEqual(actual, expected)
392
+
393
def test_return_type_coercion_success(self):
    # An int64 "value" column is returned for a schema declared as "value int";
    # the test expects the rows 10, 20, 30 to come back typed as int.
    @arrow_udtf(returnType="value int")
    class CoercionSuccessUDTF:
        def eval(self) -> Iterator["pa.Table"]:
            wide_values = pa.array([10, 20, 30], type=pa.int64())  # long -> int coercion
            yield pa.table({"value": wide_values})

    actual = CoercionSuccessUDTF()
    expected_rows = [(10,), (20,), (30,)]
    expected = self.spark.createDataFrame(expected_rows, "value int")
    assertDataFrameEqual(actual, expected)
407
+
408
def test_return_type_coercion_overflow(self):
    # The UDTF returns an int64 value one past the int32 maximum for a column
    # declared as "value int", so the value cannot be represented in the
    # declared type and collecting the result is expected to fail.
    @arrow_udtf(returnType="value int")
    class CoercionOverflowUDTF:
        def eval(self) -> Iterator["pa.Table"]:
            # Return values that will cause overflow when casting long to int
            result_table = pa.table(
                {
                    "value": pa.array([2147483647 + 1], type=pa.int64()),  # int32 max + 1
                }
            )
            yield result_table

    # Assert on PythonException rather than the bare Exception base class:
    # a blanket Exception would also be satisfied by unrelated failures
    # (e.g. a NameError in the test itself), hiding real regressions. The
    # other failure-path tests in this file assert PythonException as well.
    with self.assertRaises(PythonException):
        result_df = CoercionOverflowUDTF()
        result_df.collect()
424
+
425
def test_return_type_coercion_multiple_columns(self):
    # Both columns are produced with wider Arrow types than the declared
    # return type: int64 for "id int" and float64 for "price float". The test
    # expects the result to equal a DataFrame built with the declared schema.
    @arrow_udtf(returnType="id int, price float")
    class MultipleColumnCoercionUDTF:
        def eval(self) -> Iterator["pa.Table"]:
            columns = {
                # long -> int coercion
                "id": pa.array([1, 2, 3], type=pa.int64()),
                # double -> float coercion
                "price": pa.array([10.5, 20.7, 30.9], type=pa.float64()),
            }
            yield pa.table(columns)

    actual = MultipleColumnCoercionUDTF()
    expected_rows = [(1, 10.5), (2, 20.7), (3, 30.9)]
    expected = self.spark.createDataFrame(expected_rows, "id int, price float")
    assertDataFrameEqual(actual, expected)
444
+
353
445
def test_arrow_udtf_with_empty_column_result (self ):
354
446
@arrow_udtf (returnType = StructType ())
355
447
class EmptyResultUDTF :
0 commit comments