Skip to content

Commit cbbab52

Browse files
committed
Ruff updates after 3.10
1 parent 5d891fe commit cbbab52

File tree

4 files changed

+18
-15
lines changed

4 files changed

+18
-15
lines changed

python/datafusion/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -586,7 +586,7 @@ def with_columns(
586586
if isinstance(expr, str):
587587
expressions.append(self.parse_sql_expr(expr).expr)
588588
elif isinstance(expr, Iterable) and not isinstance(
589-
expr, (Expr, str, bytes, bytearray)
589+
expr, Expr | str | bytes | bytearray
590590
):
591591
expressions.extend(
592592
[
@@ -639,7 +639,7 @@ def aggregate(
639639
"""
640640
group_by_list = (
641641
list(group_by)
642-
if isinstance(group_by, Sequence) and not isinstance(group_by, (Expr, str))
642+
if isinstance(group_by, Sequence) and not isinstance(group_by, Expr | str)
643643
else [group_by]
644644
)
645645
aggs_list = (

python/datafusion/expr.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -271,7 +271,7 @@ def _iter(
271271
) -> Iterable[expr_internal.Expr]:
272272
for expr in items:
273273
if isinstance(expr, Iterable) and not isinstance(
274-
expr, (Expr, str, bytes, bytearray)
274+
expr, Expr | str | bytes | bytearray
275275
):
276276
# Treat string-like objects as atomic to surface standard errors
277277
yield from _iter(expr)
@@ -308,7 +308,7 @@ def expr_list_to_raw_expr_list(
308308
expr_list: Optional[list[Expr] | Expr],
309309
) -> Optional[list[expr_internal.Expr]]:
310310
"""Convert a sequence of expressions or column names to raw expressions."""
311-
if isinstance(expr_list, (Expr, str)):
311+
if isinstance(expr_list, Expr | str):
312312
expr_list = [expr_list]
313313
if expr_list is None:
314314
return None
@@ -326,7 +326,7 @@ def sort_list_to_raw_sort_list(
326326
sort_list: Optional[_typing.Union[Sequence[SortKey], SortKey]],
327327
) -> Optional[list[expr_internal.SortExpr]]:
328328
"""Helper function to return an optional sort list to raw variant."""
329-
if isinstance(sort_list, (Expr, SortExpr, str)):
329+
if isinstance(sort_list, Expr | SortExpr | str):
330330
sort_list = [sort_list]
331331
if sort_list is None:
332332
return None

python/tests/test_functions.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -567,7 +567,7 @@ def test_array_functions(stmt, py_expr):
567567

568568
col = column("arr")
569569
query_result = df.select(stmt(col)).collect()[0].column(0)
570-
for a, b in zip(query_result, py_expr(data)):
570+
for a, b in zip(query_result, py_expr(data), strict=False):
571571
np.testing.assert_array_almost_equal(
572572
np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
573573
)
@@ -582,7 +582,7 @@ def test_array_function_flatten():
582582
stmt = f.flatten(literal(data))
583583
py_expr = [py_flatten(data)]
584584
query_result = df.select(stmt).collect()[0].column(0)
585-
for a, b in zip(query_result, py_expr):
585+
for a, b in zip(query_result, py_expr, strict=False):
586586
np.testing.assert_array_almost_equal(
587587
np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
588588
)
@@ -600,7 +600,7 @@ def test_array_function_cardinality():
600600

601601
query_result = df.select(stmt).collect()[0].column(0)
602602

603-
for a, b in zip(query_result, py_expr):
603+
for a, b in zip(query_result, py_expr, strict=False):
604604
np.testing.assert_array_equal(
605605
np.array([a.as_py()], dtype=int), np.array([b], dtype=int)
606606
)
@@ -631,7 +631,7 @@ def test_make_array_functions(make_func):
631631
]
632632

633633
query_result = df.select(stmt).collect()[0].column(0)
634-
for a, b in zip(query_result, py_expr):
634+
for a, b in zip(query_result, py_expr, strict=False):
635635
np.testing.assert_array_equal(
636636
np.array(a.as_py(), dtype=str), np.array(b, dtype=str)
637637
)
@@ -664,7 +664,7 @@ def test_array_function_obj_tests(stmt, py_expr):
664664
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
665665
df = ctx.create_dataframe([[batch]])
666666
query_result = np.array(df.select(stmt).collect()[0].column(0))
667-
for a, b in zip(query_result, py_expr(data)):
667+
for a, b in zip(query_result, py_expr(data), strict=False):
668668
assert a == b
669669

670670

python/tests/test_sql.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,7 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_ty
194194
result = pa.Table.from_batches(result)
195195

196196
rd = result.to_pydict()
197-
assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1}
197+
assert dict(zip(rd["grp"], rd["cnt"], strict=False)) == {"a": 3, "b": 1}
198198

199199

200200
@pytest.mark.parametrize("path_to_str", [True, False])
@@ -340,7 +340,10 @@ def test_execute(ctx, tmp_path):
340340
result_values.extend(pydict["cnt"])
341341

342342
result_keys, result_values = (
343-
list(t) for t in zip(*sorted(zip(result_keys, result_values)))
343+
list(t)
344+
for t in zip(
345+
*sorted(zip(result_keys, result_values, strict=False)), strict=False
346+
)
344347
)
345348

346349
assert result_keys == [1, 2, 3, 11, 12]
@@ -467,7 +470,7 @@ def test_simple_select(ctx, tmp_path, arr):
467470
# In DF 43.0.0 we now default to having BinaryView and StringView
468471
# so the array that is saved to the parquet is slightly different
469472
# than the array read. Convert to values for comparison.
470-
if isinstance(result, (pa.BinaryViewArray, pa.StringViewArray)):
473+
if isinstance(result, pa.BinaryViewArray | pa.StringViewArray):
471474
arr = arr.tolist()
472475
result = result.tolist()
473476

@@ -524,12 +527,12 @@ def test_register_listing_table(
524527
result = pa.Table.from_batches(result)
525528

526529
rd = result.to_pydict()
527-
assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}
530+
assert dict(zip(rd["grp"], rd["count"], strict=False)) == {"a": 5, "b": 2}
528531

529532
result = ctx.sql(
530533
"SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05' GROUP BY grp" # noqa: E501
531534
).collect()
532535
result = pa.Table.from_batches(result)
533536

534537
rd = result.to_pydict()
535-
assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
538+
assert dict(zip(rd["grp"], rd["count"], strict=False)) == {"a": 3, "b": 2}

0 commit comments

Comments
 (0)