From f9cafb89cdb74952eb8c96979b62aa7371b875a0 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 1 Sep 2025 11:23:19 +0800 Subject: [PATCH 01/11] refactor: improve DataFrame expression handling, type checking, and docs - Refactor expression handling and `_simplify_expression` for stronger type checking and clearer error handling - Improve type annotations for `file_sort_order` and `order_by` to support string inputs - Refactor DataFrame `filter` method to better validate predicates - Replace internal error message variable with public constant - Clarify usage of `col()` and `column()` in DataFrame examples --- docs/source/user-guide/dataframe/index.rst | 45 +++++++++++++ python/datafusion/context.py | 6 +- python/datafusion/dataframe.py | 77 ++++++++++++++-------- python/datafusion/expr.py | 49 +++++++++++--- python/datafusion/functions.py | 28 ++++---- python/tests/test_dataframe.py | 65 ++++++++++++++++++ 6 files changed, 218 insertions(+), 52 deletions(-) diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index f69485af7..22c0de4f6 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -126,6 +126,51 @@ DataFusion's DataFrame API offers a wide range of operations: # Drop columns df = df.drop("temporary_column") +String Columns and Expressions +------------------------------ + +Some ``DataFrame`` methods accept plain strings when an argument refers to an +existing column. These include: + +* :py:meth:`~datafusion.DataFrame.select` +* :py:meth:`~datafusion.DataFrame.sort` +* :py:meth:`~datafusion.DataFrame.drop` +* :py:meth:`~datafusion.DataFrame.join` (``on`` argument) +* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns) + +For such methods, you can pass column names directly: + +.. code-block:: python + + from datafusion import col, column, functions as f + + df.sort('id') + df.aggregate('id', [f.count(col('value'))]) + +The same operation can also be written with an explicit column expression: + +.. code-block:: python + + from datafusion import col, column, functions as f + + df.sort(col('id')) + df.aggregate(col('id'), [f.count(col('value'))]) + +Note that ``column()`` is an alias of ``col()``, so you can use either name. + +Whenever an argument represents an expression—such as in +:py:meth:`~datafusion.DataFrame.filter` or +:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference columns +and wrap constant values with ``lit()`` (also available as ``literal()``): + +.. code-block:: python + + from datafusion import col, lit + df.filter(col('age') > lit(21)) + +Without ``lit()`` DataFusion would treat ``21`` as a column name rather than a +constant value. + Terminal Operations ------------------- diff --git a/python/datafusion/context.py b/python/datafusion/context.py index bce51d644..2ea68cff7 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -553,7 +553,7 @@ def register_listing_table( table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".parquet", schema: pa.Schema | None = None, - file_sort_order: list[list[Expr | SortExpr]] | None = None, + file_sort_order: list[list[Expr | SortExpr | str]] | None = None, ) -> None: """Register multiple files as a single table. 
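The hunk above only widens the ``file_sort_order`` annotation, so the intent is easiest to see at a call site. A minimal usage sketch, assuming a local ``data/`` directory of Parquet files pre-sorted by a ``ts`` column (the directory, table names, and column are illustrative, not part of this patch):

    from datafusion import SessionContext, col

    ctx = SessionContext()
    # A bare column name is now a valid sort key...
    ctx.register_listing_table("events", "data/", file_sort_order=[["ts"]])
    # ...and stays equivalent to spelling out the sort expression.
    ctx.register_listing_table(
        "events_expr",
        "data/",
        file_sort_order=[[col("ts").sort(ascending=True)]],
    )

String keys only work end to end once the expr.py hunks later in this patch teach ``sort_list_to_raw_sort_list`` to accept ``str``.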
@@ -808,7 +808,7 @@ def register_parquet( file_extension: str = ".parquet", skip_metadata: bool = True, schema: pa.Schema | None = None, - file_sort_order: list[list[SortExpr]] | None = None, + file_sort_order: list[list[Expr | SortExpr | str]] | None = None, ) -> None: """Register a Parquet file as a table. @@ -1099,7 +1099,7 @@ def read_parquet( file_extension: str = ".parquet", skip_metadata: bool = True, schema: pa.Schema | None = None, - file_sort_order: list[list[Expr | SortExpr]] | None = None, + file_sort_order: list[list[Expr | SortExpr | str]] | None = None, ) -> DataFrame: """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`. diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 61cb09438..89dbef221 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -40,7 +40,13 @@ from datafusion._internal import DataFrame as DataFrameInternal from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal -from datafusion.expr import Expr, SortExpr, sort_or_default +from datafusion.expr import ( + _EXPR_TYPE_ERROR, + Expr, + SortExpr, + expr_list_to_raw_expr_list, + sort_list_to_raw_sort_list, +) from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.record_batch import RecordBatchStream @@ -394,9 +400,7 @@ def select(self, *exprs: Expr | str) -> DataFrame: df = df.select("a", col("b"), col("a").alias("alternate_a")) """ - exprs_internal = [ - Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs - ] + exprs_internal = expr_list_to_raw_expr_list(exprs) return DataFrame(self.df.select(*exprs_internal)) def drop(self, *columns: str) -> DataFrame: @@ -426,6 +430,8 @@ def filter(self, *predicates: Expr) -> DataFrame: """ df = self.df for p in predicates: + if not isinstance(p, Expr): + raise TypeError(_EXPR_TYPE_ERROR) df = df.filter(p.expr) return DataFrame(df) @@ -439,6 +445,8 @@ def with_column(self, name: str, expr: Expr) -> DataFrame: Returns: DataFrame with the new column. 
""" + if not isinstance(expr, Expr): + raise TypeError(_EXPR_TYPE_ERROR) return DataFrame(self.df.with_column(name, expr.expr)) def with_columns( @@ -468,17 +476,22 @@ def with_columns( def _simplify_expression( *exprs: Expr | Iterable[Expr], **named_exprs: Expr ) -> list[expr_internal.Expr]: - expr_list = [] + expr_list: list[expr_internal.Expr] = [] for expr in exprs: - if isinstance(expr, Expr): - expr_list.append(expr.expr) - elif isinstance(expr, Iterable): - expr_list.extend(inner_expr.expr for inner_expr in expr) - else: - raise NotImplementedError - if named_exprs: - for alias, expr in named_exprs.items(): - expr_list.append(expr.alias(alias).expr) + if isinstance(expr, str) or ( + isinstance(expr, Iterable) + and not isinstance(expr, Expr) + and any(isinstance(inner, str) for inner in expr) + ): + raise TypeError(_EXPR_TYPE_ERROR) + try: + expr_list.extend(expr_list_to_raw_expr_list(expr)) + except TypeError as err: + raise TypeError(_EXPR_TYPE_ERROR) from err + for alias, expr in named_exprs.items(): + if not isinstance(expr, Expr): + raise TypeError(_EXPR_TYPE_ERROR) + expr_list.append(expr.alias(alias).expr) return expr_list expressions = _simplify_expression(*exprs, **named_exprs) @@ -503,37 +516,43 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame: return DataFrame(self.df.with_column_renamed(old_name, new_name)) def aggregate( - self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr + self, + group_by: list[Expr | str] | Expr | str, + aggs: list[Expr] | Expr, ) -> DataFrame: """Aggregates the rows of the current DataFrame. Args: - group_by: List of expressions to group by. + group_by: List of expressions or column names to group by. aggs: List of expressions to aggregate. Returns: DataFrame after aggregation. """ - group_by = group_by if isinstance(group_by, list) else [group_by] - aggs = aggs if isinstance(aggs, list) else [aggs] + group_by_list = group_by if isinstance(group_by, list) else [group_by] + aggs_list = aggs if isinstance(aggs, list) else [aggs] - group_by = [e.expr for e in group_by] - aggs = [e.expr for e in aggs] - return DataFrame(self.df.aggregate(group_by, aggs)) + group_by_exprs = expr_list_to_raw_expr_list(group_by_list) + aggs_exprs = [] + for agg in aggs_list: + if not isinstance(agg, Expr): + raise TypeError(_EXPR_TYPE_ERROR) + aggs_exprs.append(agg.expr) + return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs)) - def sort(self, *exprs: Expr | SortExpr) -> DataFrame: - """Sort the DataFrame by the specified sorting expressions. + def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame: + """Sort the DataFrame by the specified sorting expressions or column names. Note that any expression can be turned into a sort expression by - calling its` ``sort`` method. + calling its ``sort`` method. Args: - exprs: Sort expressions, applied in order. + exprs: Sort expressions or column names, applied in order. Returns: DataFrame after sorting. """ - exprs_raw = [sort_or_default(expr) for expr in exprs] + exprs_raw = sort_list_to_raw_sort_list(list(exprs)) return DataFrame(self.df.sort(*exprs_raw)) def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame: @@ -757,7 +776,11 @@ def join_on( Returns: DataFrame after join. 
""" - exprs = [expr.expr for expr in on_exprs] + exprs = [] + for expr in on_exprs: + if not isinstance(expr, Expr): + raise TypeError(_EXPR_TYPE_ERROR) + exprs.append(expr.expr) return DataFrame(self.df.join_on(right.df, exprs, how)) def explain(self, verbose: bool = False, analyze: bool = False) -> None: diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index c0b495717..ca106b60a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -22,7 +22,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence import pyarrow as pa @@ -39,6 +39,10 @@ if TYPE_CHECKING: from datafusion.plan import LogicalPlan + +# Standard error message for invalid expression types +_EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions" + # The following are imported from the internal representation. We may choose to # give these all proper wrappers, or to simply leave as is. These were added # in order to support passing the `test_imports` unit test. @@ -216,12 +220,26 @@ def expr_list_to_raw_expr_list( - expr_list: Optional[list[Expr] | Expr], + expr_list: Optional[Sequence[Expr | str] | Expr | str], ) -> Optional[list[expr_internal.Expr]]: - """Helper function to convert an optional list to raw expressions.""" - if isinstance(expr_list, Expr): + """Convert a sequence of expressions or column names to raw expressions.""" + if isinstance(expr_list, (Expr, str)): expr_list = [expr_list] - return [e.expr for e in expr_list] if expr_list is not None else None + if expr_list is None: + return None + raw_exprs: list[expr_internal.Expr] = [] + for e in expr_list: + if isinstance(e, str): + raw_exprs.append(Expr.column(e).expr) + elif isinstance(e, Expr): + raw_exprs.append(e.expr) + else: + error = ( + "Expected Expr or column name, found:" + f" {type(e).__name__}. {_EXPR_TYPE_ERROR}." + ) + raise TypeError(error) + return raw_exprs def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: @@ -232,12 +250,27 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: def sort_list_to_raw_sort_list( - sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr], + sort_list: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str], ) -> Optional[list[expr_internal.SortExpr]]: """Helper function to return an optional sort list to raw variant.""" - if isinstance(sort_list, (Expr, SortExpr)): + if isinstance(sort_list, (Expr, SortExpr, str)): sort_list = [sort_list] - return [sort_or_default(e) for e in sort_list] if sort_list is not None else None + if sort_list is None: + return None + raw_sort_list = [] + for item in sort_list: + if isinstance(item, str): + expr_obj = Expr.column(item) + elif isinstance(item, (Expr, SortExpr)): + expr_obj = item + else: + error = ( + "Expected Expr or column name, found:" + f" {type(item).__name__}. {_EXPR_TYPE_ERROR}." 
+ ) + raise TypeError(error) + raw_sort_list.append(sort_or_default(expr_obj)) + return raw_sort_list class Expr: diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 34068805c..bb95b042e 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -429,7 +429,7 @@ def window( name: str, args: list[Expr], partition_by: list[Expr] | Expr | None = None, - order_by: list[Expr | SortExpr] | Expr | SortExpr | None = None, + order_by: list[Expr | SortExpr | str] | Expr | SortExpr | str | None = None, window_frame: WindowFrame | None = None, ctx: SessionContext | None = None, ) -> Expr: @@ -1723,7 +1723,7 @@ def array_agg( expression: Expr, distinct: bool = False, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Aggregate values into an array. @@ -2222,7 +2222,7 @@ def regr_syy( def first_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the first value in a group of values. @@ -2254,7 +2254,7 @@ def first_value( def last_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the last value in a group of values. @@ -2287,7 +2287,7 @@ def nth_value( expression: Expr, n: int, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the n-th value in a group of values. @@ -2408,7 +2408,7 @@ def lead( shift_offset: int = 1, default_value: Optional[Any] = None, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Create a lead window function. @@ -2461,7 +2461,7 @@ def lag( shift_offset: int = 1, default_value: Optional[Any] = None, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Create a lag window function. @@ -2508,7 +2508,7 @@ def lag( def row_number( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Create a row number window function. @@ -2542,7 +2542,7 @@ def row_number( def rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Create a rank window function. 
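To make the widened ``order_by`` annotations above concrete, here is a hedged sketch of a call site that now passes bare column names; ``df`` and the ``team``/``points`` columns are assumed to exist and are not defined by this patch:

    from datafusion import col, functions as f

    df.select(
        f.rank(partition_by=[col("team")], order_by="points").alias("team_rank"),
        f.first_value(col("points"), order_by="points").alias("first_points"),
    )

Either form resolves through ``sort_list_to_raw_sort_list``, so ``order_by="points"`` and ``order_by=col("points").sort()`` produce the same window ordering.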
@@ -2581,7 +2581,7 @@ def rank( def dense_rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Create a dense_rank window function. @@ -2615,7 +2615,7 @@ def dense_rank( def percent_rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Create a percent_rank window function. @@ -2650,7 +2650,7 @@ def percent_rank( def cume_dist( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Create a cumulative distribution window function. @@ -2686,7 +2686,7 @@ def cume_dist( def ntile( groups: int, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Create a n-tile window function. @@ -2727,7 +2727,7 @@ def string_agg( expression: Expr, delimiter: str, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, ) -> Expr: """Concatenates the input strings. diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 0cd56219a..92747f46d 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -33,6 +33,7 @@ WindowFrame, column, literal, + col, ) from datafusion import ( functions as f, @@ -227,6 +228,13 @@ def test_select_mixed_expr_string(df): assert result.column(1) == pa.array([1, 2, 3]) +def test_select_unsupported(df): + with pytest.raises( + TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)" + ): + df.select(1) + + def test_filter(df): df1 = df.filter(column("a") > literal(2)).select( column("a") + column("b"), @@ -268,6 +276,32 @@ def test_sort(df): assert table.to_pydict() == expected +def test_sort_string_and_expression_equivalent(df): + from datafusion import col + + result_str = df.sort("a").to_pydict() + result_expr = df.sort(col("a")).to_pydict() + assert result_str == result_expr + + +def test_sort_unsupported(df): + with pytest.raises( + TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)" + ): + df.sort(1) + + +def test_aggregate_string_and_expression_equivalent(df): + result_str = df.aggregate("a", [f.count()]).to_pydict() + result_expr = df.aggregate(col("a"), [f.count()]).to_pydict() + assert result_str == result_expr + + +def test_filter_string_unsupported(df): + with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"): + df.filter("a > 1") + + def test_drop(df): df = df.drop("c") @@ -337,6 +371,11 @@ def test_with_column(df): assert result.column(2) == pa.array([5, 7, 9]) +def test_with_column_invalid_expr(df): + with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + df.with_column("c", "a") + + def test_with_columns(df): df = df.with_columns( (column("a") + column("b")).alias("c"), @@ -368,6 +407,13 @@ def test_with_columns(df): assert result.column(6) == pa.array([5, 7, 9]) +def test_with_columns_invalid_expr(df): + with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + df.with_columns("a") + with 
pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + df.with_columns(c="a") + + def test_cast(df): df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())}) expected = pa.schema( @@ -526,6 +572,25 @@ def test_join_on(): assert table.to_pydict() == expected +def test_join_on_invalid_expr(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([4, 5])], + names=["a", "b"], + ) + df = ctx.create_dataframe([[batch]], "l") + df1 = ctx.create_dataframe([[batch]], "r") + + with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + df.join_on(df1, "a") + + +def test_aggregate_invalid_aggs(df): + with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + df.aggregate([], "a") + + def test_distinct(): ctx = SessionContext() From 91167b085c1a7d9e7aab19efa25d6ff41c83558d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 1 Sep 2025 18:32:05 +0800 Subject: [PATCH 02/11] refactor: unify expression and sorting logic; improve docs and error handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update `order_by` handling in Window class for better type support - Improve type checking in DataFrame expression handling - Replace `Expr`/`SortExpr` with `SortKey` in file_sort_order and related functions - Simplify file_sort_order handling in SessionContext - Rename `_EXPR_TYPE_ERROR` → `EXPR_TYPE_ERROR` for consistency - Clarify usage of `col()` vs `column()` in DataFrame examples - Enhance documentation for file_sort_order in SessionContext --- docs/source/user-guide/dataframe/index.rst | 8 ++-- python/datafusion/context.py | 47 ++++++++++++---------- python/datafusion/dataframe.py | 34 ++++++++-------- python/datafusion/expr.py | 19 +++++---- python/datafusion/functions.py | 29 ++++++------- python/tests/test_dataframe.py | 14 +++++-- 6 files changed, 85 insertions(+), 66 deletions(-) diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index 22c0de4f6..95abf3f3b 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -142,21 +142,21 @@ For such methods, you can pass column names directly: .. code-block:: python - from datafusion import col, column, functions as f + from datafusion import col, functions as f df.sort('id') df.aggregate('id', [f.count(col('value'))]) -The same operation can also be written with an explicit column expression: +The same operation can also be written with explicit column expressions, using either ``col()`` or ``column()``: .. code-block:: python from datafusion import col, column, functions as f df.sort(col('id')) - df.aggregate(col('id'), [f.count(col('value'))]) + df.aggregate(column('id'), [f.count(col('value'))]) -Note that ``column()`` is an alias of ``col()``, so you can use either name. +Note that ``column()`` is an alias of ``col()``, so you can use either name; the example above shows both in action. 
Whenever an argument represents an expression—such as in :py:meth:`~datafusion.DataFrame.filter` or diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 2ea68cff7..8f05aed2f 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -31,7 +31,7 @@ from datafusion.catalog import Catalog, CatalogProvider, Table from datafusion.dataframe import DataFrame -from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list +from datafusion.expr import SortKey, sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF @@ -553,7 +553,7 @@ def register_listing_table( table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".parquet", schema: pa.Schema | None = None, - file_sort_order: list[list[Expr | SortExpr | str]] | None = None, + file_sort_order: list[list[SortKey]] | None = None, ) -> None: """Register multiple files as a single table. @@ -567,23 +567,20 @@ def register_listing_table( table_partition_cols: Partition columns. file_extension: File extension of the provided table. schema: The data source schema. - file_sort_order: Sort order for the file. + file_sort_order: Sort order for the file. Each sort key can be + specified as a column name (``str``), an expression + (``Expr``), or a ``SortExpr``. """ if table_partition_cols is None: table_partition_cols = [] table_partition_cols = self._convert_table_partition_cols(table_partition_cols) - file_sort_order_raw = ( - [sort_list_to_raw_sort_list(f) for f in file_sort_order] - if file_sort_order is not None - else None - ) self.ctx.register_listing_table( name, str(path), table_partition_cols, file_extension, schema, - file_sort_order_raw, + self._convert_file_sort_order(file_sort_order), ) def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: @@ -808,7 +805,7 @@ def register_parquet( file_extension: str = ".parquet", skip_metadata: bool = True, schema: pa.Schema | None = None, - file_sort_order: list[list[Expr | SortExpr | str]] | None = None, + file_sort_order: list[list[SortKey]] | None = None, ) -> None: """Register a Parquet file as a table. @@ -827,7 +824,9 @@ def register_parquet( that may be in the file schema. This can help avoid schema conflicts due to metadata. schema: The data source schema. - file_sort_order: Sort order for the file. + file_sort_order: Sort order for the file. Each sort key can be + specified as a column name (``str``), an expression + (``Expr``), or a ``SortExpr``. """ if table_partition_cols is None: table_partition_cols = [] @@ -840,9 +839,7 @@ def register_parquet( file_extension, skip_metadata, schema, - [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] - if file_sort_order is not None - else None, + self._convert_file_sort_order(file_sort_order), ) def register_csv( @@ -1099,7 +1096,7 @@ def read_parquet( file_extension: str = ".parquet", skip_metadata: bool = True, schema: pa.Schema | None = None, - file_sort_order: list[list[Expr | SortExpr | str]] | None = None, + file_sort_order: list[list[SortKey]] | None = None, ) -> DataFrame: """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`. @@ -1116,7 +1113,9 @@ def read_parquet( schema: An optional schema representing the parquet files. If None, the parquet reader will try to infer it based on data in the file. - file_sort_order: Sort order for the file. + file_sort_order: Sort order for the file. 
Each sort key can be + specified as a column name (``str``), an expression + (``Expr``), or a ``SortExpr``. Returns: DataFrame representation of the read Parquet files @@ -1124,11 +1123,7 @@ def read_parquet( if table_partition_cols is None: table_partition_cols = [] table_partition_cols = self._convert_table_partition_cols(table_partition_cols) - file_sort_order = ( - [sort_list_to_raw_sort_list(f) for f in file_sort_order] - if file_sort_order is not None - else None - ) + file_sort_order = self._convert_file_sort_order(file_sort_order) return DataFrame( self.ctx.read_parquet( str(path), @@ -1179,6 +1174,16 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: """Execute the ``plan`` and return the results.""" return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions)) + @staticmethod + def _convert_file_sort_order( + file_sort_order: list[list[Expr | SortExpr | str]] | None, + ) -> list[list[Any]] | None: + return ( + [sort_list_to_raw_sort_list(f) for f in file_sort_order] + if file_sort_order is not None + else None + ) + @staticmethod def _convert_table_partition_cols( table_partition_cols: list[tuple[str, str | pa.DataType]], diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 89dbef221..dc19836d8 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -41,9 +41,9 @@ from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal from datafusion.expr import ( - _EXPR_TYPE_ERROR, + EXPR_TYPE_ERROR, Expr, - SortExpr, + SortKey, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, ) @@ -431,7 +431,7 @@ def filter(self, *predicates: Expr) -> DataFrame: df = self.df for p in predicates: if not isinstance(p, Expr): - raise TypeError(_EXPR_TYPE_ERROR) + raise TypeError(EXPR_TYPE_ERROR) df = df.filter(p.expr) return DataFrame(df) @@ -446,7 +446,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame: DataFrame with the new column. 
""" if not isinstance(expr, Expr): - raise TypeError(_EXPR_TYPE_ERROR) + raise TypeError(EXPR_TYPE_ERROR) return DataFrame(self.df.with_column(name, expr.expr)) def with_columns( @@ -478,19 +478,21 @@ def _simplify_expression( ) -> list[expr_internal.Expr]: expr_list: list[expr_internal.Expr] = [] for expr in exprs: - if isinstance(expr, str) or ( - isinstance(expr, Iterable) - and not isinstance(expr, Expr) - and any(isinstance(inner, str) for inner in expr) - ): - raise TypeError(_EXPR_TYPE_ERROR) + if isinstance(expr, str): + raise TypeError(EXPR_TYPE_ERROR) + if isinstance(expr, Iterable) and not isinstance(expr, Expr): + expr_value = list(expr) + if any(isinstance(inner, str) for inner in expr_value): + raise TypeError(EXPR_TYPE_ERROR) + else: + expr_value = expr try: - expr_list.extend(expr_list_to_raw_expr_list(expr)) + expr_list.extend(expr_list_to_raw_expr_list(expr_value)) except TypeError as err: - raise TypeError(_EXPR_TYPE_ERROR) from err + raise TypeError(EXPR_TYPE_ERROR) from err for alias, expr in named_exprs.items(): if not isinstance(expr, Expr): - raise TypeError(_EXPR_TYPE_ERROR) + raise TypeError(EXPR_TYPE_ERROR) expr_list.append(expr.alias(alias).expr) return expr_list @@ -536,11 +538,11 @@ def aggregate( aggs_exprs = [] for agg in aggs_list: if not isinstance(agg, Expr): - raise TypeError(_EXPR_TYPE_ERROR) + raise TypeError(EXPR_TYPE_ERROR) aggs_exprs.append(agg.expr) return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs)) - def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame: + def sort(self, *exprs: SortKey) -> DataFrame: """Sort the DataFrame by the specified sorting expressions or column names. Note that any expression can be turned into a sort expression by @@ -779,7 +781,7 @@ def join_on( exprs = [] for expr in on_exprs: if not isinstance(expr, Expr): - raise TypeError(_EXPR_TYPE_ERROR) + raise TypeError(EXPR_TYPE_ERROR) exprs.append(expr.expr) return DataFrame(self.df.join_on(right.df, exprs, how)) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index ca106b60a..362ae44c1 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -22,7 +22,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Union import pyarrow as pa @@ -41,7 +41,9 @@ # Standard error message for invalid expression types -_EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions" +EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions" + +SortKey = Union["Expr", "SortExpr", str] # The following are imported from the internal representation. We may choose to # give these all proper wrappers, or to simply leave as is. These were added @@ -199,6 +201,7 @@ "SimilarTo", "Sort", "SortExpr", + "SortKey", "Subquery", "SubqueryAlias", "TableScan", @@ -236,7 +239,7 @@ def expr_list_to_raw_expr_list( else: error = ( "Expected Expr or column name, found:" - f" {type(e).__name__}. {_EXPR_TYPE_ERROR}." + f" {type(e).__name__}. {EXPR_TYPE_ERROR}." 
) raise TypeError(error) return raw_exprs @@ -250,7 +253,7 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: def sort_list_to_raw_sort_list( - sort_list: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str], + sort_list: Optional[list[SortKey] | SortKey], ) -> Optional[list[expr_internal.SortExpr]]: """Helper function to return an optional sort list to raw variant.""" if isinstance(sort_list, (Expr, SortExpr, str)): @@ -266,7 +269,7 @@ def sort_list_to_raw_sort_list( else: error = ( "Expected Expr or column name, found:" - f" {type(item).__name__}. {_EXPR_TYPE_ERROR}." + f" {type(item).__name__}. {EXPR_TYPE_ERROR}." ) raise TypeError(error) raw_sort_list.append(sort_or_default(expr_obj)) @@ -693,7 +696,7 @@ def over(self, window: Window) -> Expr: window: Window definition """ partition_by_raw = expr_list_to_raw_expr_list(window._partition_by) - order_by_raw = sort_list_to_raw_sort_list(window._order_by) + order_by_raw = window._order_by window_frame_raw = ( window._window_frame.window_frame if window._window_frame is not None @@ -1179,7 +1182,7 @@ def __init__( self, partition_by: Optional[list[Expr] | Expr] = None, window_frame: Optional[WindowFrame] = None, - order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None, + order_by: Optional[list[SortExpr | Expr | str] | Expr | SortExpr | str] = None, null_treatment: Optional[NullTreatment] = None, ) -> None: """Construct a window definition. @@ -1192,7 +1195,7 @@ def __init__( """ self._partition_by = partition_by self._window_frame = window_frame - self._order_by = order_by + self._order_by = sort_list_to_raw_sort_list(order_by) self._null_treatment = null_treatment diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index bb95b042e..5f887c074 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -28,6 +28,7 @@ CaseBuilder, Expr, SortExpr, + SortKey, WindowFrame, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, @@ -429,7 +430,7 @@ def window( name: str, args: list[Expr], partition_by: list[Expr] | Expr | None = None, - order_by: list[Expr | SortExpr | str] | Expr | SortExpr | str | None = None, + order_by: list[SortKey] | SortKey | None = None, window_frame: WindowFrame | None = None, ctx: SessionContext | None = None, ) -> Expr: @@ -1723,7 +1724,7 @@ def array_agg( expression: Expr, distinct: bool = False, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Aggregate values into an array. @@ -2222,7 +2223,7 @@ def regr_syy( def first_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the first value in a group of values. @@ -2254,7 +2255,7 @@ def first_value( def last_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the last value in a group of values. 
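The ``Window`` change above (converting ``order_by`` eagerly in ``__init__``) is exercised by the new ``first_value_order_by_string`` test case further down; condensed into a standalone sketch, where ``df`` is the test fixture DataFrame and ``a``, ``b``, ``c`` are its columns:

    from datafusion import column, functions as f
    from datafusion.expr import Window

    df.select(
        f.first_value(column("a")).over(
            Window(partition_by=[column("c")], order_by="b")
        )
    )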
@@ -2287,7 +2288,7 @@ def nth_value( expression: Expr, n: int, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the n-th value in a group of values. @@ -2408,7 +2409,7 @@ def lead( shift_offset: int = 1, default_value: Optional[Any] = None, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a lead window function. @@ -2461,7 +2462,7 @@ def lag( shift_offset: int = 1, default_value: Optional[Any] = None, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a lag window function. @@ -2508,7 +2509,7 @@ def lag( def row_number( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a row number window function. @@ -2542,7 +2543,7 @@ def row_number( def rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a rank window function. @@ -2581,7 +2582,7 @@ def rank( def dense_rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a dense_rank window function. @@ -2615,7 +2616,7 @@ def dense_rank( def percent_rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a percent_rank window function. @@ -2650,7 +2651,7 @@ def percent_rank( def cume_dist( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a cumulative distribution window function. @@ -2686,7 +2687,7 @@ def cume_dist( def ntile( groups: int, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a n-tile window function. @@ -2727,7 +2728,7 @@ def string_agg( expression: Expr, delimiter: str, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Concatenates the input strings. 
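Taken together, the ``order_by`` signatures above all funnel through the new ``SortKey`` alias (``Expr | SortExpr | str``). A small sketch of how that reads at a call site; ``df`` and the column names are assumptions for illustration only:

    from datafusion import col, functions as f
    from datafusion.expr import SortKey

    # Sort keys may mix plain column names with explicit sort expressions.
    keys: list[SortKey] = ["ts", col("amount").sort(ascending=False)]
    df.select(f.string_agg(col("name"), ",", order_by=keys))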
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 92747f46d..22d7ef58a 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -33,7 +33,6 @@ WindowFrame, column, literal, - col, ) from datafusion import ( functions as f, @@ -292,8 +291,10 @@ def test_sort_unsupported(df): def test_aggregate_string_and_expression_equivalent(df): - result_str = df.aggregate("a", [f.count()]).to_pydict() - result_expr = df.aggregate(col("a"), [f.count()]).to_pydict() + from datafusion import col + + result_str = df.aggregate("a", [f.count()]).sort("a").to_pydict() + result_expr = df.aggregate(col("a"), [f.count()]).sort("a").to_pydict() assert result_str == result_expr @@ -778,6 +779,13 @@ def test_distinct(): ), [1, 1, 1, 1, 5, 5, 5], ), + ( + "first_value_order_by_string", + f.first_value(column("a")).over( + Window(partition_by=[column("c")], order_by="b") + ), + [1, 1, 1, 1, 5, 5, 5], + ), ( "last_value", f.last_value(column("a")).over( From 54687a295085ad648f63259ccc8fb0289aecd68a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 1 Sep 2025 20:17:04 +0800 Subject: [PATCH 03/11] feat: add ensure_expr helper for validation; refine expression handling, sorting, and docs - Introduce `ensure_expr` helper and improve internal expression validation - Update error messages and tests to consistently use `EXPR_TYPE_ERROR` - Refactor expression handling with `_to_raw_expr`, `_ensure_expr`, and `SortKey` - Improve type safety and consistency in sort key definitions and file sort order - Add parameterized parquet sorting tests - Enhance DataFrame docstrings with clearer guidance and usage examples - Fix minor typos and error message clarity --- python/datafusion/context.py | 18 ++++-- python/datafusion/dataframe.py | 109 +++++++++++++++++---------------- python/datafusion/expr.py | 87 +++++++++++++++++--------- python/datafusion/functions.py | 94 ++++++++++++++++++++++++---- python/tests/test_dataframe.py | 64 ++++++++++++++++--- python/tests/test_expr.py | 10 +++ 6 files changed, 273 insertions(+), 109 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 8f05aed2f..2b7b66c68 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -20,7 +20,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, Sequence import pyarrow as pa @@ -553,7 +553,7 @@ def register_listing_table( table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".parquet", schema: pa.Schema | None = None, - file_sort_order: list[list[SortKey]] | None = None, + file_sort_order: Sequence[Sequence[SortKey]] | None = None, ) -> None: """Register multiple files as a single table. @@ -805,7 +805,7 @@ def register_parquet( file_extension: str = ".parquet", skip_metadata: bool = True, schema: pa.Schema | None = None, - file_sort_order: list[list[SortKey]] | None = None, + file_sort_order: Sequence[Sequence[SortKey]] | None = None, ) -> None: """Register a Parquet file as a table. @@ -1096,7 +1096,7 @@ def read_parquet( file_extension: str = ".parquet", skip_metadata: bool = True, schema: pa.Schema | None = None, - file_sort_order: list[list[SortKey]] | None = None, + file_sort_order: Sequence[Sequence[SortKey]] | None = None, ) -> DataFrame: """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`. 
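The switch to ``Sequence[Sequence[SortKey]]`` above is covered by the parametrized Parquet tests added at the end of this patch; in short, both spellings below describe the same file sort order (the path and column ``a`` are placeholders):

    from datafusion import SessionContext, col

    ctx = SessionContext()
    path = "data.parquet"  # placeholder: any Parquet file sorted by column "a"
    ctx.read_parquet(path, file_sort_order=[["a"]])
    ctx.read_parquet(path, file_sort_order=[[col("a")]])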
@@ -1176,8 +1176,16 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: @staticmethod def _convert_file_sort_order( - file_sort_order: list[list[Expr | SortExpr | str]] | None, + file_sort_order: Sequence[Sequence[SortKey]] | None, ) -> list[list[Any]] | None: + """Convert nested ``SortKey`` sequences into raw sort representations. + + Each ``SortKey`` can be a column name string, an ``Expr``, or a + ``SortExpr`` and will be converted using + :func:`datafusion.expr.sort_list_to_raw_sort_list`. + """ + # Convert each ``SortKey`` in the provided sort order to the low-level + # representation expected by the Rust bindings. return ( [sort_list_to_raw_sort_list(f) for f in file_sort_order] if file_sort_order is not None diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index dc19836d8..c5a78d4ac 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -41,9 +41,9 @@ from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal from datafusion.expr import ( - EXPR_TYPE_ERROR, Expr, SortKey, + ensure_expr, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, ) @@ -58,8 +58,6 @@ import polars as pl import pyarrow as pa - from datafusion._internal import expr as expr_internal - from enum import Enum @@ -418,9 +416,17 @@ def filter(self, *predicates: Expr) -> DataFrame: """Return a DataFrame for which ``predicate`` evaluates to ``True``. Rows for which ``predicate`` evaluates to ``False`` or ``None`` are filtered - out. If more than one predicate is provided, these predicates will be - combined as a logical AND. If more complex logic is required, see the - logical operations in :py:mod:`~datafusion.functions`. + out. If more than one predicate is provided, these predicates will be + combined as a logical AND. Each ``predicate`` must be an + :class:`~datafusion.expr.Expr` created using helper functions such as + :func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not + accepted. If more complex logic is required, see the logical operations in + :py:mod:`~datafusion.functions`. + + Example:: + + from datafusion import col, lit + df.filter(col("a") > lit(1)) Args: predicates: Predicate expression(s) to filter the DataFrame. @@ -430,14 +436,21 @@ def filter(self, *predicates: Expr) -> DataFrame: """ df = self.df for p in predicates: - if not isinstance(p, Expr): - raise TypeError(EXPR_TYPE_ERROR) - df = df.filter(p.expr) + df = df.filter(ensure_expr(p)) return DataFrame(df) def with_column(self, name: str, expr: Expr) -> DataFrame: """Add an additional column to the DataFrame. + The ``expr`` must be an :class:`~datafusion.expr.Expr` constructed with + :func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not + accepted. + + Example:: + + from datafusion import col, lit + df.with_column("b", col("a") + lit(1)) + Args: name: Name of the column to add. expr: Expression to compute the column. @@ -445,25 +458,27 @@ def with_column(self, name: str, expr: Expr) -> DataFrame: Returns: DataFrame with the new column. """ - if not isinstance(expr, Expr): - raise TypeError(EXPR_TYPE_ERROR) - return DataFrame(self.df.with_column(name, expr.expr)) + return DataFrame(self.df.with_column(name, ensure_expr(expr))) def with_columns( self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr ) -> DataFrame: """Add columns to the DataFrame. 
- By passing expressions, iteratables of expressions, or named expressions. To - pass named expressions use the form name=Expr. + By passing expressions, iterables of expressions, or named expressions. + All expressions must be :class:`~datafusion.expr.Expr` objects created via + :func:`datafusion.col` or :func:`datafusion.lit`; plain strings are not + accepted. To pass named expressions use the form ``name=Expr``. - Example usage: The following will add 4 columns labeled a, b, c, and d:: + Example usage: The following will add 4 columns labeled ``a``, ``b``, ``c``, + and ``d``:: + from datafusion import col, lit df = df.with_columns( - lit(0).alias('a'), - [lit(1).alias('b'), lit(2).alias('c')], + col("x").alias("a"), + [lit(1).alias("b"), col("y").alias("c")], d=lit(3) - ) + ) Args: exprs: Either a single expression or an iterable of expressions to add. @@ -473,30 +488,19 @@ def with_columns( DataFrame with the new columns added. """ - def _simplify_expression( - *exprs: Expr | Iterable[Expr], **named_exprs: Expr - ) -> list[expr_internal.Expr]: - expr_list: list[expr_internal.Expr] = [] - for expr in exprs: + def _iter_exprs(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[Expr | str]: + for expr in items: if isinstance(expr, str): - raise TypeError(EXPR_TYPE_ERROR) - if isinstance(expr, Iterable) and not isinstance(expr, Expr): - expr_value = list(expr) - if any(isinstance(inner, str) for inner in expr_value): - raise TypeError(EXPR_TYPE_ERROR) + yield expr + elif isinstance(expr, Iterable) and not isinstance(expr, Expr): + yield from _iter_exprs(expr) else: - expr_value = expr - try: - expr_list.extend(expr_list_to_raw_expr_list(expr_value)) - except TypeError as err: - raise TypeError(EXPR_TYPE_ERROR) from err - for alias, expr in named_exprs.items(): - if not isinstance(expr, Expr): - raise TypeError(EXPR_TYPE_ERROR) - expr_list.append(expr.alias(alias).expr) - return expr_list - - expressions = _simplify_expression(*exprs, **named_exprs) + yield expr + + expressions = [ensure_expr(e) for e in _iter_exprs(exprs)] + for alias, expr in named_exprs.items(): + ensure_expr(expr) + expressions.append(expr.alias(alias).expr) return DataFrame(self.df.with_columns(expressions)) @@ -535,11 +539,7 @@ def aggregate( aggs_list = aggs if isinstance(aggs, list) else [aggs] group_by_exprs = expr_list_to_raw_expr_list(group_by_list) - aggs_exprs = [] - for agg in aggs_list: - if not isinstance(agg, Expr): - raise TypeError(EXPR_TYPE_ERROR) - aggs_exprs.append(agg.expr) + aggs_exprs = [ensure_expr(agg) for agg in aggs_list] return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs)) def sort(self, *exprs: SortKey) -> DataFrame: @@ -554,7 +554,7 @@ def sort(self, *exprs: SortKey) -> DataFrame: Returns: DataFrame after sorting. """ - exprs_raw = sort_list_to_raw_sort_list(list(exprs)) + exprs_raw = sort_list_to_raw_sort_list(exprs) return DataFrame(self.df.sort(*exprs_raw)) def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame: @@ -766,8 +766,15 @@ def join_on( ) -> DataFrame: """Join two :py:class:`DataFrame` using the specified expressions. - On expressions are used to support in-equality predicates. Equality - predicates are correctly optimized + Join predicates must be :class:`~datafusion.expr.Expr` objects, typically + built with :func:`datafusion.col`; plain strings are not accepted. On + expressions are used to support in-equality predicates. Equality predicates + are correctly optimized. 
+ + Example:: + + from datafusion import col + df.join_on(other_df, col("id") == col("other_id")) Args: right: Other DataFrame to join with. @@ -778,11 +785,7 @@ def join_on( Returns: DataFrame after join. """ - exprs = [] - for expr in on_exprs: - if not isinstance(expr, Expr): - raise TypeError(EXPR_TYPE_ERROR) - exprs.append(expr.expr) + exprs = [ensure_expr(expr) for expr in on_exprs] return DataFrame(self.df.join_on(right.df, exprs, how)) def explain(self, verbose: bool = False, analyze: bool = False) -> None: diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 362ae44c1..362efd6b1 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -22,7 +22,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence import pyarrow as pa @@ -41,9 +41,8 @@ # Standard error message for invalid expression types -EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions" - -SortKey = Union["Expr", "SortExpr", str] +# Mention both alias forms of column and literal helpers +EXPR_TYPE_ERROR = "Use col()/column() or lit()/literal() to construct expressions" # The following are imported from the internal representation. We may choose to # give these all proper wrappers, or to simply leave as is. These were added @@ -219,9 +218,54 @@ "WindowExpr", "WindowFrame", "WindowFrameBound", + "ensure_expr", ] +def ensure_expr(value: Expr | Any) -> expr_internal.Expr: + """Return the internal expression from ``Expr`` or raise ``TypeError``. + + This helper rejects plain strings and other non-:class:`Expr` values so + higher level APIs consistently require explicit :func:`~datafusion.col` or + :func:`~datafusion.lit` expressions. + + Args: + value: Candidate expression or other object. + + Returns: + The internal expression representation. + + Raises: + TypeError: If ``value`` is not an instance of :class:`Expr`. + """ + if not isinstance(value, Expr): + raise TypeError(EXPR_TYPE_ERROR) + return value.expr + + +def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: + """Convert a Python expression or column name to its raw variant. + + Args: + value: Candidate expression or column name. + + Returns: + The internal :class:`~datafusion._internal.expr.Expr` representation. + + Raises: + TypeError: If ``value`` is neither an :class:`Expr` nor ``str``. + """ + if isinstance(value, str): + return Expr.column(value).expr + if isinstance(value, Expr): + return value.expr + error = ( + "Expected Expr or column name, found:" + f" {type(value).__name__}. {EXPR_TYPE_ERROR}." + ) + raise TypeError(error) + + def expr_list_to_raw_expr_list( expr_list: Optional[Sequence[Expr | str] | Expr | str], ) -> Optional[list[expr_internal.Expr]]: @@ -230,30 +274,18 @@ def expr_list_to_raw_expr_list( expr_list = [expr_list] if expr_list is None: return None - raw_exprs: list[expr_internal.Expr] = [] - for e in expr_list: - if isinstance(e, str): - raw_exprs.append(Expr.column(e).expr) - elif isinstance(e, Expr): - raw_exprs.append(e.expr) - else: - error = ( - "Expected Expr or column name, found:" - f" {type(e).__name__}. {EXPR_TYPE_ERROR}." 
- ) - raise TypeError(error) - return raw_exprs + return [_to_raw_expr(e) for e in expr_list] def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: - """Helper function to return a default Sort if an Expr is provided.""" + """Return a :class:`SortExpr`, defaulting attributes when necessary.""" if isinstance(e, SortExpr): return e.raw_sort return SortExpr(e, ascending=True, nulls_first=True).raw_sort def sort_list_to_raw_sort_list( - sort_list: Optional[list[SortKey] | SortKey], + sort_list: Optional[Sequence[SortKey] | SortKey], ) -> Optional[list[expr_internal.SortExpr]]: """Helper function to return an optional sort list to raw variant.""" if isinstance(sort_list, (Expr, SortExpr, str)): @@ -262,17 +294,11 @@ def sort_list_to_raw_sort_list( return None raw_sort_list = [] for item in sort_list: - if isinstance(item, str): - expr_obj = Expr.column(item) - elif isinstance(item, (Expr, SortExpr)): - expr_obj = item + if isinstance(item, SortExpr): + raw_sort_list.append(sort_or_default(item)) else: - error = ( - "Expected Expr or column name, found:" - f" {type(item).__name__}. {EXPR_TYPE_ERROR}." - ) - raise TypeError(error) - raw_sort_list.append(sort_or_default(expr_obj)) + raw_expr = _to_raw_expr(item) # may raise ``TypeError`` + raw_sort_list.append(sort_or_default(Expr(raw_expr))) return raw_sort_list @@ -1335,3 +1361,6 @@ def nulls_first(self) -> bool: def __repr__(self) -> str: """Generate a string representation of this expression.""" return self.raw_sort.__repr__() + + +SortKey = Expr | SortExpr | str diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 5f887c074..ccf16c80c 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -441,6 +441,10 @@ def window( lag use:: df.select(functions.lag(col("a")).partition_by(col("b")).build()) + + The ``order_by`` parameter accepts column names or expressions, e.g.:: + + window("lag", [col("a")], order_by="ts") """ args = [a.expr for a in args] partition_by_raw = expr_list_to_raw_expr_list(partition_by) @@ -1739,7 +1743,11 @@ def array_agg( expression: Values to combine into an array distinct: If True, a single entry for each distinct value will be in the result filter: If provided, only compute against rows for which the filter is True - order_by: Order the resultant array values + order_by: Order the resultant array values. Accepts column names or expressions. + + For example:: + + df.select(array_agg(col("a"), order_by="b")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None @@ -2236,8 +2244,13 @@ def first_value( Args: expression: Argument to perform bitwise calculation on filter: If provided, only compute against rows for which the filter is True - order_by: Set the ordering of the expression to evaluate + order_by: Set the ordering of the expression to evaluate. Accepts + column names or expressions. null_treatment: Assign whether to respect or ignore null values. + + For example:: + + df.select(first_value(col("a"), order_by="ts")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None @@ -2268,8 +2281,13 @@ def last_value( Args: expression: Argument to perform bitwise calculation on filter: If provided, only compute against rows for which the filter is True - order_by: Set the ordering of the expression to evaluate + order_by: Set the ordering of the expression to evaluate. Accepts + column names or expressions. 
null_treatment: Assign whether to respect or ignore null values. + + For example:: + + df.select(last_value(col("a"), order_by="ts")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None @@ -2302,8 +2320,13 @@ def nth_value( expression: Argument to perform bitwise calculation on n: Index of value to return. Starts at 1. filter: If provided, only compute against rows for which the filter is True - order_by: Set the ordering of the expression to evaluate + order_by: Set the ordering of the expression to evaluate. Accepts + column names or expressions. null_treatment: Assign whether to respect or ignore null values. + + For example:: + + df.select(nth_value(col("a"), 2, order_by="ts")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None @@ -2438,7 +2461,12 @@ def lead( shift_offset: Number of rows following the current row. default_value: Value to return if shift_offet row does not exist. partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + lead(col("b"), order_by="ts") """ if not isinstance(default_value, pa.Scalar) and default_value is not None: default_value = pa.scalar(default_value) @@ -2488,7 +2516,12 @@ def lag( shift_offset: Number of rows before the current row. default_value: Value to return if shift_offet row does not exist. partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + lag(col("b"), order_by="ts") """ if not isinstance(default_value, pa.Scalar): default_value = pa.scalar(default_value) @@ -2528,7 +2561,12 @@ def row_number( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + row_number(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2567,7 +2605,12 @@ def rank( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + rank(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2601,7 +2644,12 @@ def dense_rank( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + dense_rank(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2636,7 +2684,12 @@ def percent_rank( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. 
+ + For example:: + + percent_rank(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2671,7 +2724,12 @@ def cume_dist( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + cume_dist(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2710,7 +2768,12 @@ def ntile( Args: groups: Number of groups for the n-tile to be divided into. partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + ntile(3, order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2743,7 +2806,12 @@ def string_agg( expression: Argument to perform bitwise calculation on delimiter: Text to place between each value of expression filter: If provided, only compute against rows for which the filter is True - order_by: Set the ordering of the expression to evaluate + order_by: Set the ordering of the expression to evaluate. Accepts + column names or expressions. + + For example:: + + df.select(string_agg(col("a"), ",", order_by="b")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 22d7ef58a..aac091f98 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -34,6 +34,9 @@ column, literal, ) +from datafusion import ( + col as df_col, +) from datafusion import ( functions as f, ) @@ -43,7 +46,7 @@ get_formatter, reset_formatter, ) -from datafusion.expr import Window +from datafusion.expr import EXPR_TYPE_ERROR, Window from pyarrow.csv import write_csv MB = 1024 * 1024 @@ -229,7 +232,8 @@ def test_select_mixed_expr_string(df): def test_select_unsupported(df): with pytest.raises( - TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)" + TypeError, + match=f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}", ): df.select(1) @@ -285,7 +289,8 @@ def test_sort_string_and_expression_equivalent(df): def test_sort_unsupported(df): with pytest.raises( - TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)" + TypeError, + match=f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}", ): df.sort(1) @@ -299,7 +304,7 @@ def test_aggregate_string_and_expression_equivalent(df): def test_filter_string_unsupported(df): - with pytest.raises(TypeError, match=r"col\(\) or lit\(\)"): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): df.filter("a > 1") @@ -373,7 +378,9 @@ def test_with_column(df): def test_with_column_invalid_expr(df): - with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + with pytest.raises( + TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" + ): df.with_column("c", "a") @@ -409,9 +416,13 @@ def test_with_columns(df): def test_with_columns_invalid_expr(df): - with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + with pytest.raises( + TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" + ): df.with_columns("a") - with pytest.raises(TypeError, match=r"Use col\(\) 
or lit\(\)"): + with pytest.raises( + TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" + ): df.with_columns(c="a") @@ -583,12 +594,16 @@ def test_join_on_invalid_expr(): df = ctx.create_dataframe([[batch]], "l") df1 = ctx.create_dataframe([[batch]], "r") - with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + with pytest.raises( + TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" + ): df.join_on(df1, "a") def test_aggregate_invalid_aggs(df): - with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"): + with pytest.raises( + TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" + ): df.aggregate([], "a") @@ -2753,3 +2768,34 @@ def test_show_from_empty_batch(capsys) -> None: ctx.create_dataframe([[batch]]).show() out = capsys.readouterr().out assert "| a |" in out + + +@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]]) +def test_register_parquet_file_sort_order(ctx, tmp_path, file_sort_order): + table = pa.table({"a": [1, 2]}) + path = tmp_path / "file.parquet" + pa.parquet.write_table(table, path) + ctx.register_parquet("t", path, file_sort_order=file_sort_order) + assert "t" in ctx.catalog().schema().names() + + +@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]]) +def test_register_listing_table_file_sort_order(ctx, tmp_path, file_sort_order): + table = pa.table({"a": [1, 2]}) + dir_path = tmp_path / "dir" + dir_path.mkdir() + pa.parquet.write_table(table, dir_path / "file.parquet") + ctx.register_listing_table( + "t", dir_path, schema=table.schema, file_sort_order=file_sort_order + ) + assert "t" in ctx.catalog().schema().names() + + +@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]]) +def test_read_parquet_file_sort_order(tmp_path, file_sort_order): + ctx = SessionContext() + table = pa.table({"a": [1, 2]}) + path = tmp_path / "data.parquet" + pa.parquet.write_table(table, path) + df = ctx.read_parquet(path, file_sort_order=file_sort_order) + assert df.collect()[0].column(0).to_pylist() == [1, 2] diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index cfeb07c1f..b3e280ed6 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+import re from datetime import datetime, timezone import pyarrow as pa @@ -28,6 +29,7 @@ literal_with_metadata, ) from datafusion.expr import ( + EXPR_TYPE_ERROR, Aggregate, AggregateFunction, BinaryExpr, @@ -47,6 +49,7 @@ TransactionEnd, TransactionStart, Values, + ensure_expr, ) @@ -880,3 +883,10 @@ def test_literal_metadata(ctx): for expected_field in expected_schema: actual_field = result[0].schema.field(expected_field.name) assert expected_field.metadata == actual_field.metadata + + +def test_ensure_expr(): + e = col("a") + assert ensure_expr(e) is e.expr + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + ensure_expr("a") From f591617c0a3cc46ef76f7f1c440b04185d5c93c0 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 2 Sep 2025 12:47:28 +0800 Subject: [PATCH 04/11] Refactor and enhance expression handling, test coverage, and documentation - Introduced `ensure_expr_list` to validate and flatten nested expressions, treating strings as atomic - Updated expression utilities to improve consistency across aggregation and window functions - Consolidated and expanded parameterized tests for string equivalence in ranking and window functions - Exposed `EXPR_TYPE_ERROR` for consistent error messaging across modules and tests - Improved internal sort logic using `expr_internal.SortExpr` - Clarified expectations for `join_on` expressions in documentation - Standardized imports and improved test clarity for maintainability --- docs/source/user-guide/dataframe/index.rst | 2 + python/datafusion/context.py | 8 +- python/datafusion/dataframe.py | 38 ++++---- python/datafusion/expr.py | 33 ++++++- python/tests/test_dataframe.py | 108 +++++++++++++++++++-- python/tests/test_expr.py | 16 +++ 6 files changed, 176 insertions(+), 29 deletions(-) diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index 95abf3f3b..11cd8b3dc 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -138,6 +138,8 @@ existing column. These include: * :py:meth:`~datafusion.DataFrame.join` (``on`` argument) * :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns) +Note that :py:meth:`~datafusion.DataFrame.join_on` expects ``col()``/``column()`` expressions rather than plain strings. + For such methods, you can pass column names directly: .. code-block:: python diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 2b7b66c68..34130ff9c 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -20,7 +20,8 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Protocol, Sequence +from typing import TYPE_CHECKING, Any, Protocol +from collections.abc import Sequence import pyarrow as pa @@ -39,6 +40,7 @@ from ._internal import SessionConfig as SessionConfigInternal from ._internal import SessionContext as SessionContextInternal from ._internal import SQLOptions as SQLOptionsInternal +from ._internal import expr as expr_internal if TYPE_CHECKING: import pathlib @@ -1177,8 +1179,8 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: @staticmethod def _convert_file_sort_order( file_sort_order: Sequence[Sequence[SortKey]] | None, - ) -> list[list[Any]] | None: - """Convert nested ``SortKey`` sequences into raw sort representations. + ) -> list[list[expr_internal.SortExpr]] | None: + """Convert nested ``SortKey`` sequences into raw sort expressions. 
Each ``SortKey`` can be a column name string, an ``Expr``, or a ``SortExpr`` and will be converted using diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c5a78d4ac..b91b5babb 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -22,6 +22,7 @@ from __future__ import annotations import warnings +from collections.abc import Sequence from typing import ( TYPE_CHECKING, Any, @@ -44,6 +45,7 @@ Expr, SortKey, ensure_expr, + ensure_expr_list, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, ) @@ -52,7 +54,7 @@ if TYPE_CHECKING: import pathlib - from typing import Callable, Sequence + from typing import Callable import pandas as pd import polars as pl @@ -487,17 +489,7 @@ def with_columns( Returns: DataFrame with the new columns added. """ - - def _iter_exprs(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[Expr | str]: - for expr in items: - if isinstance(expr, str): - yield expr - elif isinstance(expr, Iterable) and not isinstance(expr, Expr): - yield from _iter_exprs(expr) - else: - yield expr - - expressions = [ensure_expr(e) for e in _iter_exprs(exprs)] + expressions = ensure_expr_list(exprs) for alias, expr in named_exprs.items(): ensure_expr(expr) expressions.append(expr.alias(alias).expr) @@ -523,23 +515,31 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame: def aggregate( self, - group_by: list[Expr | str] | Expr | str, - aggs: list[Expr] | Expr, + group_by: Sequence[Expr | str] | Expr | str, + aggs: Sequence[Expr] | Expr, ) -> DataFrame: """Aggregates the rows of the current DataFrame. Args: - group_by: List of expressions or column names to group by. - aggs: List of expressions to aggregate. + group_by: Sequence of expressions or column names to group by. + aggs: Sequence of expressions to aggregate. Returns: DataFrame after aggregation. """ - group_by_list = group_by if isinstance(group_by, list) else [group_by] - aggs_list = aggs if isinstance(aggs, list) else [aggs] + group_by_list = ( + list(group_by) + if isinstance(group_by, Sequence) and not isinstance(group_by, (Expr, str)) + else [group_by] + ) + aggs_list = ( + list(aggs) + if isinstance(aggs, Sequence) and not isinstance(aggs, Expr) + else [aggs] + ) group_by_exprs = expr_list_to_raw_expr_list(group_by_list) - aggs_exprs = [ensure_expr(agg) for agg in aggs_list] + aggs_exprs = ensure_expr_list(aggs_list) return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs)) def sort(self, *exprs: SortKey) -> DataFrame: diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 362efd6b1..b97fb3598 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -22,7 +22,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional +from collections.abc import Sequence import pyarrow as pa @@ -131,6 +132,7 @@ WindowExpr = expr_internal.WindowExpr __all__ = [ + "EXPR_TYPE_ERROR", "Aggregate", "AggregateFunction", "Alias", @@ -219,6 +221,7 @@ "WindowFrame", "WindowFrameBound", "ensure_expr", + "ensure_expr_list", ] @@ -243,6 +246,34 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr: return value.expr +def ensure_expr_list( + exprs: Iterable[Expr | Iterable[Expr]], +) -> list[expr_internal.Expr]: + """Flatten an iterable of expressions, validating each via ``ensure_expr``. + + Args: + exprs: Possibly nested iterable containing expressions. + + Returns: + A flat list of raw expressions. 
+ + Raises: + TypeError: If any item is not an instance of :class:`Expr`. + """ + + def _iter(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[expr_internal.Expr]: + for expr in items: + if isinstance(expr, Iterable) and not isinstance( + expr, (Expr, str, bytes, bytearray) + ): + # Treat string-like objects as atomic to surface standard errors + yield from _iter(expr) + else: + yield ensure_expr(expr) + + return list(_iter(exprs)) + + def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index aac091f98..c6de47579 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -303,6 +303,18 @@ def test_aggregate_string_and_expression_equivalent(df): assert result_str == result_expr +def test_aggregate_tuple_group_by(df): + result_list = df.aggregate(["a"], [f.count()]).sort("a").to_pydict() + result_tuple = df.aggregate(("a",), [f.count()]).sort("a").to_pydict() + assert result_tuple == result_list + + +def test_aggregate_tuple_aggs(df): + result_list = df.aggregate("a", [f.count()]).sort("a").to_pydict() + result_tuple = df.aggregate("a", (f.count(),)).sort("a").to_pydict() + assert result_tuple == result_list + + def test_filter_string_unsupported(df): with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): df.filter("a > 1") @@ -416,14 +428,14 @@ def test_with_columns(df): def test_with_columns_invalid_expr(df): - with pytest.raises( - TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" - ): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): df.with_columns("a") - with pytest.raises( - TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" - ): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): df.with_columns(c="a") + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + df.with_columns(["a"]) + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + df.with_columns(c=["a"]) def test_cast(df): @@ -843,6 +855,27 @@ def test_window_functions(partitioned_df, name, expr, result): assert table.sort_by("a").to_pydict() == expected +@pytest.mark.parametrize("partition", ["c", df_col("c")]) +def test_rank_partition_by_accepts_string(partitioned_df, partition): + """Passing a string to partition_by should match using col().""" + df = partitioned_df.select( + f.rank(order_by=column("a"), partition_by=partition).alias("r") + ) + table = pa.Table.from_batches(df.sort(column("a")).collect()) + assert table.column("r").to_pylist() == [1, 2, 3, 4, 1, 2, 3] + + +@pytest.mark.parametrize("partition", ["c", df_col("c")]) +def test_window_partition_by_accepts_string(partitioned_df, partition): + """Window.partition_by accepts string identifiers.""" + expr = f.first_value(column("a")).over( + Window(partition_by=partition, order_by=column("b")) + ) + df = partitioned_df.select(expr.alias("fv")) + table = pa.Table.from_batches(df.sort(column("a")).collect()) + assert table.column("fv").to_pylist() == [1, 1, 1, 1, 5, 5, 5] + + @pytest.mark.parametrize( ("units", "start_bound", "end_bound"), [ @@ -913,6 +946,69 @@ def test_window_frame_defaults_match_postgres(partitioned_df): assert df_2.sort(col_a).to_pydict() == expected +def _build_last_value_df(df): + return df.select( + f.last_value(column("a")) + .over( + Window( + partition_by=[column("c")], + order_by=[column("b")], + window_frame=WindowFrame("rows", None, None), + ) + ) + .alias("expr"), + 
f.last_value(column("a")) + .over( + Window( + partition_by=[column("c")], + order_by="b", + window_frame=WindowFrame("rows", None, None), + ) + ) + .alias("str"), + ) + + +def _build_nth_value_df(df): + return df.select( + f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])).alias("expr"), + f.nth_value(column("b"), 3).over(Window(order_by="a")).alias("str"), + ) + + +def _build_rank_df(df): + return df.select( + f.rank(order_by=[column("b")]).alias("expr"), + f.rank(order_by="b").alias("str"), + ) + + +def _build_array_agg_df(df): + return df.aggregate( + [column("c")], + [ + f.array_agg(column("a"), order_by=[column("a")]).alias("expr"), + f.array_agg(column("a"), order_by="a").alias("str"), + ], + ).sort(column("c")) + + +@pytest.mark.parametrize( + ("builder", "expected"), + [ + pytest.param(_build_last_value_df, [3, 3, 3, 3, 6, 6, 6], id="last_value"), + pytest.param(_build_nth_value_df, [None, None, 7, 7, 7, 7, 7], id="nth_value"), + pytest.param(_build_rank_df, [1, 1, 3, 3, 5, 6, 6], id="rank"), + pytest.param(_build_array_agg_df, [[0, 1, 2, 3], [4, 5, 6]], id="array_agg"), + ], +) +def test_order_by_string_equivalence(partitioned_df, builder, expected): + df = builder(partitioned_df) + table = pa.Table.from_batches(df.collect()) + assert table.column("expr").to_pylist() == expected + assert table.column("expr").to_pylist() == table.column("str").to_pylist() + + def test_html_formatter_cell_dimension(df, clean_formatter_state): """Test configuring the HTML formatter with different options.""" # Configure with custom settings diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index b3e280ed6..810d419cf 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -50,6 +50,7 @@ TransactionStart, Values, ensure_expr, + ensure_expr_list, ) @@ -890,3 +891,18 @@ def test_ensure_expr(): assert ensure_expr(e) is e.expr with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): ensure_expr("a") + + +def test_ensure_expr_list_string(): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + ensure_expr_list("a") + + +def test_ensure_expr_list_bytes(): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + ensure_expr_list(b"a") + + +def test_ensure_expr_list_bytearray(): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + ensure_expr_list(bytearray(b"a")) From 31a648f3636eb44b966eb2a0758d0851e184469e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 2 Sep 2025 17:23:57 +0800 Subject: [PATCH 05/11] refactor: update docstring for sort_or_default function to clarify its purpose --- python/datafusion/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index b97fb3598..902fa2a6a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -309,7 +309,7 @@ def expr_list_to_raw_expr_list( def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: - """Return a :class:`SortExpr`, defaulting attributes when necessary.""" + """Helper function to return a default Sort if an Expr is provided.""" if isinstance(e, SortExpr): return e.raw_sort return SortExpr(e, ascending=True, nulls_first=True).raw_sort From 37307b0ac873df7fe8137192402af7d37d0d3508 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 2 Sep 2025 17:43:25 +0800 Subject: [PATCH 06/11] fix Ruff errors --- python/datafusion/context.py | 8 ++++---- python/datafusion/expr.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git 
a/python/datafusion/context.py b/python/datafusion/context.py index 34130ff9c..b6e728b51 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -21,15 +21,14 @@ import warnings from typing import TYPE_CHECKING, Any, Protocol -from collections.abc import Sequence - -import pyarrow as pa try: from warnings import deprecated # Python 3.13+ except ImportError: from typing_extensions import deprecated # Python 3.12 +import pyarrow as pa + from datafusion.catalog import Catalog, CatalogProvider, Table from datafusion.dataframe import DataFrame from datafusion.expr import SortKey, sort_list_to_raw_sort_list @@ -44,9 +43,10 @@ if TYPE_CHECKING: import pathlib + from collections.abc import Sequence import pandas as pd - import polars as pl + import polars as pl # type: ignore[import] from datafusion.plan import ExecutionPlan, LogicalPlan diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 902fa2a6a..4775e9181 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -23,7 +23,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional -from collections.abc import Sequence import pyarrow as pa @@ -38,6 +37,9 @@ from ._internal import functions as functions_internal if TYPE_CHECKING: + from collections.abc import Sequence + + from datafusion.common import DataTypeMap, RexType from datafusion.plan import LogicalPlan From 05cd237da45f7d593433f14f71b3997ae9f8b2bb Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 2 Sep 2025 18:12:13 +0800 Subject: [PATCH 07/11] refactor: update type hints to use typing.Union for better clarity and consistency --- python/datafusion/expr.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 4775e9181..bd01307f7 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -22,7 +22,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional +import typing as _typing +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Sequence import pyarrow as pa @@ -227,7 +228,7 @@ ] -def ensure_expr(value: Expr | Any) -> expr_internal.Expr: +def ensure_expr(value: _typing.Union["Expr", Any]) -> expr_internal.Expr: """Return the internal expression from ``Expr`` or raise ``TypeError``. This helper rejects plain strings and other non-:class:`Expr` values so @@ -249,7 +250,7 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr: def ensure_expr_list( - exprs: Iterable[Expr | Iterable[Expr]], + exprs: Iterable[_typing.Union["Expr", Iterable["Expr"]]], ) -> list[expr_internal.Expr]: """Flatten an iterable of expressions, validating each via ``ensure_expr``. @@ -263,7 +264,7 @@ def ensure_expr_list( TypeError: If any item is not an instance of :class:`Expr`. """ - def _iter(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[expr_internal.Expr]: + def _iter(items: Iterable[_typing.Union["Expr", Iterable["Expr"]]]) -> Iterable[expr_internal.Expr]: for expr in items: if isinstance(expr, Iterable) and not isinstance( expr, (Expr, str, bytes, bytearray) @@ -276,7 +277,7 @@ def _iter(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[expr_internal.Expr return list(_iter(exprs)) -def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: +def _to_raw_expr(value: _typing.Union["Expr", str]) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. 
Args: @@ -300,7 +301,7 @@ def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: def expr_list_to_raw_expr_list( - expr_list: Optional[Sequence[Expr | str] | Expr | str], + expr_list: Optional[_typing.Union[Sequence[_typing.Union["Expr", str]], "Expr", str]], ) -> Optional[list[expr_internal.Expr]]: """Convert a sequence of expressions or column names to raw expressions.""" if isinstance(expr_list, (Expr, str)): @@ -310,7 +311,7 @@ def expr_list_to_raw_expr_list( return [_to_raw_expr(e) for e in expr_list] -def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: +def sort_or_default(e: _typing.Union["Expr", "SortExpr"]) -> expr_internal.SortExpr: """Helper function to return a default Sort if an Expr is provided.""" if isinstance(e, SortExpr): return e.raw_sort @@ -318,7 +319,7 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: def sort_list_to_raw_sort_list( - sort_list: Optional[Sequence[SortKey] | SortKey], + sort_list: Optional[_typing.Union[Sequence["SortKey"], "SortKey"]], ) -> Optional[list[expr_internal.SortExpr]]: """Helper function to return an optional sort list to raw variant.""" if isinstance(sort_list, (Expr, SortExpr, str)): @@ -447,7 +448,7 @@ def __invert__(self) -> Expr: """Binary not (~).""" return Expr(self.expr.__invert__()) - def __getitem__(self, key: str | int) -> Expr: + def __getitem__(self, key: _typing.Union[str, int]) -> "Expr": """Retrieve sub-object. If ``key`` is a string, returns the subfield of the struct. @@ -598,13 +599,13 @@ def is_not_null(self) -> Expr: """Returns ``True`` if this expression is not null.""" return Expr(self.expr.is_not_null()) - def fill_nan(self, value: Any | Expr | None = None) -> Expr: + def fill_nan(self, value: Optional[_typing.Union[Any, "Expr"]] = None) -> "Expr": """Fill NaN values with a provided value.""" if not isinstance(value, Expr): value = Expr.literal(value) return Expr(functions_internal.nanvl(self.expr, value.expr)) - def fill_null(self, value: Any | Expr | None = None) -> Expr: + def fill_null(self, value: Optional[_typing.Union[Any, "Expr"]] = None) -> "Expr": """Fill NULL values with a provided value.""" if not isinstance(value, Expr): value = Expr.literal(value) @@ -617,7 +618,7 @@ def fill_null(self, value: Any | Expr | None = None) -> Expr: bool: pa.bool_(), } - def cast(self, to: pa.DataType[Any] | type[float | int | str | bool]) -> Expr: + def cast(self, to: _typing.Union[pa.DataType[Any], type]) -> "Expr": """Cast to a new data type.""" if not isinstance(to, pa.DataType): try: @@ -690,7 +691,7 @@ def column_name(self, plan: LogicalPlan) -> str: """Compute the output column name based on the provided logical plan.""" return self.expr.column_name(plan._raw_plan) - def order_by(self, *exprs: Expr | SortExpr) -> ExprFuncBuilder: + def order_by(self, *exprs: _typing.Union["Expr", "SortExpr"]) -> ExprFuncBuilder: """Set the ordering for a window or aggregate function. This function will create an :py:class:`ExprFuncBuilder` that can be used to @@ -1239,9 +1240,9 @@ class Window: def __init__( self, - partition_by: Optional[list[Expr] | Expr] = None, + partition_by: Optional[_typing.Union[list["Expr"], "Expr"]] = None, window_frame: Optional[WindowFrame] = None, - order_by: Optional[list[SortExpr | Expr | str] | Expr | SortExpr | str] = None, + order_by: Optional[_typing.Union[list[_typing.Union["SortExpr", "Expr", str]], "Expr", "SortExpr", str]] = None, null_treatment: Optional[NullTreatment] = None, ) -> None: """Construct a window definition. 
@@ -1312,7 +1313,7 @@ def __init__(self, frame_bound: expr_internal.WindowFrameBound) -> None: """Constructs a window frame bound.""" self.frame_bound = frame_bound - def get_offset(self) -> int | None: + def get_offset(self) -> Optional[int]: """Returns the offset of the window frame.""" return self.frame_bound.get_offset() @@ -1396,4 +1397,4 @@ def __repr__(self) -> str: return self.raw_sort.__repr__() -SortKey = Expr | SortExpr | str +SortKey = _typing.Union[Expr, SortExpr, str] From 28619d924346875e58177c727e93501ac6c9d0b8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 2 Sep 2025 18:19:54 +0800 Subject: [PATCH 08/11] fix Ruff errors --- python/datafusion/expr.py | 51 ++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index bd01307f7..bdf762691 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -32,7 +32,7 @@ except ImportError: from typing_extensions import deprecated # Python 3.12 -from datafusion.common import DataTypeMap, NullTreatment, RexType +from datafusion.common import NullTreatment from ._internal import expr as expr_internal from ._internal import functions as functions_internal @@ -40,8 +40,16 @@ if TYPE_CHECKING: from collections.abc import Sequence + # These are only imported for type checking to avoid runtime + # evaluation issues with typing constructs. from datafusion.common import DataTypeMap, RexType + # Make the datafusion package available to type checkers for + # fully-qualified string-literal annotations. + import datafusion # type: ignore from datafusion.plan import LogicalPlan +# Note: DataTypeMap and RexType are only available for type checking. +# We intentionally avoid importing them at runtime to prevent evaluation +# issues with complex typing constructs. # Standard error message for invalid expression types @@ -228,7 +236,7 @@ ] -def ensure_expr(value: _typing.Union["Expr", Any]) -> expr_internal.Expr: +def ensure_expr(value: _typing.Union[Expr, Any]) -> expr_internal.Expr: """Return the internal expression from ``Expr`` or raise ``TypeError``. This helper rejects plain strings and other non-:class:`Expr` values so @@ -250,7 +258,7 @@ def ensure_expr(value: _typing.Union["Expr", Any]) -> expr_internal.Expr: def ensure_expr_list( - exprs: Iterable[_typing.Union["Expr", Iterable["Expr"]]], + exprs: Iterable[_typing.Union[Expr, Iterable[Expr]]], ) -> list[expr_internal.Expr]: """Flatten an iterable of expressions, validating each via ``ensure_expr``. @@ -264,7 +272,9 @@ def ensure_expr_list( TypeError: If any item is not an instance of :class:`Expr`. """ - def _iter(items: Iterable[_typing.Union["Expr", Iterable["Expr"]]]) -> Iterable[expr_internal.Expr]: + def _iter( + items: Iterable[_typing.Union[Expr, Iterable[Expr]]], + ) -> Iterable[expr_internal.Expr]: for expr in items: if isinstance(expr, Iterable) and not isinstance( expr, (Expr, str, bytes, bytearray) @@ -277,7 +287,7 @@ def _iter(items: Iterable[_typing.Union["Expr", Iterable["Expr"]]]) -> Iterable[ return list(_iter(exprs)) -def _to_raw_expr(value: _typing.Union["Expr", str]) -> expr_internal.Expr: +def _to_raw_expr(value: _typing.Union[Expr, str]) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. 
Args: @@ -301,7 +311,7 @@ def _to_raw_expr(value: _typing.Union["Expr", str]) -> expr_internal.Expr: def expr_list_to_raw_expr_list( - expr_list: Optional[_typing.Union[Sequence[_typing.Union["Expr", str]], "Expr", str]], + expr_list: Optional[_typing.Union[Sequence[_typing.Union[Expr, str]], Expr, str]], ) -> Optional[list[expr_internal.Expr]]: """Convert a sequence of expressions or column names to raw expressions.""" if isinstance(expr_list, (Expr, str)): @@ -311,7 +321,7 @@ def expr_list_to_raw_expr_list( return [_to_raw_expr(e) for e in expr_list] -def sort_or_default(e: _typing.Union["Expr", "SortExpr"]) -> expr_internal.SortExpr: +def sort_or_default(e: _typing.Union[Expr, SortExpr]) -> expr_internal.SortExpr: """Helper function to return a default Sort if an Expr is provided.""" if isinstance(e, SortExpr): return e.raw_sort @@ -319,7 +329,7 @@ def sort_or_default(e: _typing.Union["Expr", "SortExpr"]) -> expr_internal.SortE def sort_list_to_raw_sort_list( - sort_list: Optional[_typing.Union[Sequence["SortKey"], "SortKey"]], + sort_list: Optional[_typing.Union[Sequence[SortKey], SortKey]], ) -> Optional[list[expr_internal.SortExpr]]: """Helper function to return an optional sort list to raw variant.""" if isinstance(sort_list, (Expr, SortExpr, str)): @@ -448,7 +458,7 @@ def __invert__(self) -> Expr: """Binary not (~).""" return Expr(self.expr.__invert__()) - def __getitem__(self, key: _typing.Union[str, int]) -> "Expr": + def __getitem__(self, key: _typing.Union[str, int]) -> Expr: """Retrieve sub-object. If ``key`` is a string, returns the subfield of the struct. @@ -599,13 +609,13 @@ def is_not_null(self) -> Expr: """Returns ``True`` if this expression is not null.""" return Expr(self.expr.is_not_null()) - def fill_nan(self, value: Optional[_typing.Union[Any, "Expr"]] = None) -> "Expr": + def fill_nan(self, value: Optional[_typing.Union[Any, Expr]] = None) -> Expr: """Fill NaN values with a provided value.""" if not isinstance(value, Expr): value = Expr.literal(value) return Expr(functions_internal.nanvl(self.expr, value.expr)) - def fill_null(self, value: Optional[_typing.Union[Any, "Expr"]] = None) -> "Expr": + def fill_null(self, value: Optional[_typing.Union[Any, Expr]] = None) -> Expr: """Fill NULL values with a provided value.""" if not isinstance(value, Expr): value = Expr.literal(value) @@ -618,7 +628,7 @@ def fill_null(self, value: Optional[_typing.Union[Any, "Expr"]] = None) -> "Expr bool: pa.bool_(), } - def cast(self, to: _typing.Union[pa.DataType[Any], type]) -> "Expr": + def cast(self, to: _typing.Union[pa.DataType[Any], type]) -> Expr: """Cast to a new data type.""" if not isinstance(to, pa.DataType): try: @@ -645,7 +655,7 @@ def between(self, low: Any, high: Any, negated: bool = False) -> Expr: return Expr(self.expr.between(low.expr, high.expr, negated=negated)) - def rex_type(self) -> RexType: + def rex_type(self) -> "datafusion.common.RexType": """Return the Rex Type of this expression. A Rex (Row Expression) specifies a single row of data.That specification @@ -654,7 +664,7 @@ def rex_type(self) -> RexType: """ return self.expr.rex_type() - def types(self) -> DataTypeMap: + def types(self) -> "datafusion.common.DataTypeMap": """Return the ``DataTypeMap``. 
Returns: @@ -691,7 +701,7 @@ def column_name(self, plan: LogicalPlan) -> str: """Compute the output column name based on the provided logical plan.""" return self.expr.column_name(plan._raw_plan) - def order_by(self, *exprs: _typing.Union["Expr", "SortExpr"]) -> ExprFuncBuilder: + def order_by(self, *exprs: _typing.Union[Expr, SortExpr]) -> ExprFuncBuilder: """Set the ordering for a window or aggregate function. This function will create an :py:class:`ExprFuncBuilder` that can be used to @@ -1240,9 +1250,16 @@ class Window: def __init__( self, - partition_by: Optional[_typing.Union[list["Expr"], "Expr"]] = None, + partition_by: Optional[_typing.Union[list[Expr], Expr]] = None, window_frame: Optional[WindowFrame] = None, - order_by: Optional[_typing.Union[list[_typing.Union["SortExpr", "Expr", str]], "Expr", "SortExpr", str]] = None, + order_by: Optional[ + _typing.Union[ + list[_typing.Union[SortExpr, Expr, str]], + Expr, + SortExpr, + str, + ] + ] = None, null_treatment: Optional[NullTreatment] = None, ) -> None: """Construct a window definition. From 9adbf4f21cc7d8402816fe1816d2b1ac26d25fd5 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 2 Sep 2025 18:30:48 +0800 Subject: [PATCH 09/11] refactor: simplify type hints by removing unnecessary imports for type checking --- python/datafusion/expr.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index bdf762691..ba20ba9b6 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -40,16 +40,9 @@ if TYPE_CHECKING: from collections.abc import Sequence - # These are only imported for type checking to avoid runtime - # evaluation issues with typing constructs. + # Type-only imports from datafusion.common import DataTypeMap, RexType - # Make the datafusion package available to type checkers for - # fully-qualified string-literal annotations. - import datafusion # type: ignore from datafusion.plan import LogicalPlan -# Note: DataTypeMap and RexType are only available for type checking. -# We intentionally avoid importing them at runtime to prevent evaluation -# issues with complex typing constructs. # Standard error message for invalid expression types @@ -655,23 +648,23 @@ def between(self, low: Any, high: Any, negated: bool = False) -> Expr: return Expr(self.expr.between(low.expr, high.expr, negated=negated)) - def rex_type(self) -> "datafusion.common.RexType": + def rex_type(self) -> "RexType": # type: ignore[call-arg] """Return the Rex Type of this expression. A Rex (Row Expression) specifies a single row of data.That specification could include user defined functions or types. RexType identifies the row as one of the possible valid ``RexType``. """ - return self.expr.rex_type() + return self.expr.rex_type() # type: ignore - def types(self) -> "datafusion.common.DataTypeMap": + def types(self) -> "DataTypeMap": # type: ignore[call-arg] """Return the ``DataTypeMap``. Returns: DataTypeMap which represents the PythonType, Arrow DataType, and SqlType Enum which this expression represents. """ - return self.expr.types() + return self.expr.types() # type: ignore def python_value(self) -> Any: """Extracts the Expr value into a PyObject. 
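
A minimal usage sketch for readers of this series, showing how the helpers the
diffs above add to ``python/datafusion/expr.py`` fit together. It assumes only
names introduced or exported in these patches (``col``, ``ensure_expr_list``,
``sort_list_to_raw_sort_list``, ``EXPR_TYPE_ERROR``) and mirrors the behaviour
asserted in the new tests; it is an illustration, not part of the patches::

    from datafusion import col
    from datafusion.expr import ensure_expr_list, sort_list_to_raw_sort_list

    # Nested iterables of Expr are flattened into a single list of raw
    # expressions; lists and tuples may be mixed freely.
    raw_exprs = ensure_expr_list([col("a"), (col("b"), col("c"))])

    # Strings and bytes are treated as atomic rather than iterated character
    # by character, so a bare column name here would raise TypeError carrying
    # the EXPR_TYPE_ERROR message, as the new tests assert.
    # ensure_expr_list(["a"])

    # Where a sort key is expected, a plain column name is accepted and gets
    # the default ordering (ascending, nulls first).
    raw_sorts = sort_list_to_raw_sort_list(["a", col("b")])
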
From 0a27617f578b965d5949dc85f9f857bfc8847fd5 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 2 Sep 2025 18:32:49 +0800 Subject: [PATCH 10/11] refactor: update type hints for rex_type and types methods to improve clarity --- python/datafusion/expr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index ba20ba9b6..eff0150e8 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -648,23 +648,23 @@ def between(self, low: Any, high: Any, negated: bool = False) -> Expr: return Expr(self.expr.between(low.expr, high.expr, negated=negated)) - def rex_type(self) -> "RexType": # type: ignore[call-arg] + def rex_type(self) -> RexType: # type: ignore[call-arg] """Return the Rex Type of this expression. A Rex (Row Expression) specifies a single row of data.That specification could include user defined functions or types. RexType identifies the row as one of the possible valid ``RexType``. """ - return self.expr.rex_type() # type: ignore + return self.expr.rex_type() - def types(self) -> "DataTypeMap": # type: ignore[call-arg] + def types(self) -> DataTypeMap: # type: ignore[call-arg] """Return the ``DataTypeMap``. Returns: DataTypeMap which represents the PythonType, Arrow DataType, and SqlType Enum which this expression represents. """ - return self.expr.types() # type: ignore + return self.expr.types() def python_value(self) -> Any: """Extracts the Expr value into a PyObject. From 92bc68e284c615441eec69ca84b7b97bc08b327d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 2 Sep 2025 18:36:22 +0800 Subject: [PATCH 11/11] refactor: remove unnecessary type ignore comments from rex_type and types methods --- python/datafusion/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index eff0150e8..3f1544481 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -648,7 +648,7 @@ def between(self, low: Any, high: Any, negated: bool = False) -> Expr: return Expr(self.expr.between(low.expr, high.expr, negated=negated)) - def rex_type(self) -> RexType: # type: ignore[call-arg] + def rex_type(self) -> RexType: """Return the Rex Type of this expression. A Rex (Row Expression) specifies a single row of data.That specification @@ -657,7 +657,7 @@ def rex_type(self) -> RexType: # type: ignore[call-arg] """ return self.expr.rex_type() - def types(self) -> DataTypeMap: # type: ignore[call-arg] + def types(self) -> DataTypeMap: """Return the ``DataTypeMap``. Returns: