diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index f69485af7..1387db0bd 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -126,6 +126,56 @@ DataFusion's DataFrame API offers a wide range of operations: # Drop columns df = df.drop("temporary_column") +Column Names as Function Arguments +---------------------------------- + +Some ``DataFrame`` methods accept column names when an argument refers to an +existing column. These include: + +* :py:meth:`~datafusion.DataFrame.select` +* :py:meth:`~datafusion.DataFrame.sort` +* :py:meth:`~datafusion.DataFrame.drop` +* :py:meth:`~datafusion.DataFrame.join` (``on`` argument) +* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns) + +See the full function documentation for details on any specific function. + +Note that :py:meth:`~datafusion.DataFrame.join_on` expects ``col()``/``column()`` expressions rather than plain strings. + +For such methods, you can pass column names directly: + +.. code-block:: python + + from datafusion import col, functions as f + + df.sort('id') + df.aggregate('id', [f.count(col('value'))]) + +The same operation can also be written with explicit column expressions, using either ``col()`` or ``column()``: + +.. code-block:: python + + from datafusion import col, column, functions as f + + df.sort(col('id')) + df.aggregate(column('id'), [f.count(col('value'))]) + +Note that ``column()`` is an alias of ``col()``, so you can use either name; the example above shows both in action. + +Whenever an argument represents an expression—such as in +:py:meth:`~datafusion.DataFrame.filter` or +:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference +columns. The comparison and arithmetic operators on ``Expr`` will automatically +convert any non-``Expr`` value into a literal expression, so writing + +.. code-block:: python + + from datafusion import col + df.filter(col("age") > 21) + +is equivalent to using ``lit(21)`` explicitly. Use ``lit()`` (also available +as ``literal()``) when you need to construct a literal expression directly. + Terminal Operations ------------------- diff --git a/python/datafusion/context.py b/python/datafusion/context.py index bce51d644..b6e728b51 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -22,16 +22,16 @@ import warnings from typing import TYPE_CHECKING, Any, Protocol -import pyarrow as pa - try: from warnings import deprecated # Python 3.13+ except ImportError: from typing_extensions import deprecated # Python 3.12 +import pyarrow as pa + from datafusion.catalog import Catalog, CatalogProvider, Table from datafusion.dataframe import DataFrame -from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list +from datafusion.expr import SortKey, sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF @@ -39,12 +39,14 @@ from ._internal import SessionConfig as SessionConfigInternal from ._internal import SessionContext as SessionContextInternal from ._internal import SQLOptions as SQLOptionsInternal +from ._internal import expr as expr_internal if TYPE_CHECKING: import pathlib + from collections.abc import Sequence import pandas as pd - import polars as pl + import polars as pl # type: ignore[import] from datafusion.plan import ExecutionPlan, LogicalPlan @@ -553,7 +555,7 @@ def register_listing_table( table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".parquet", schema: pa.Schema | None = None, - file_sort_order: list[list[Expr | SortExpr]] | None = None, + file_sort_order: Sequence[Sequence[SortKey]] | None = None, ) -> None: """Register multiple files as a single table. @@ -567,23 +569,20 @@ def register_listing_table( table_partition_cols: Partition columns. file_extension: File extension of the provided table. schema: The data source schema. - file_sort_order: Sort order for the file. + file_sort_order: Sort order for the file. Each sort key can be + specified as a column name (``str``), an expression + (``Expr``), or a ``SortExpr``. """ if table_partition_cols is None: table_partition_cols = [] table_partition_cols = self._convert_table_partition_cols(table_partition_cols) - file_sort_order_raw = ( - [sort_list_to_raw_sort_list(f) for f in file_sort_order] - if file_sort_order is not None - else None - ) self.ctx.register_listing_table( name, str(path), table_partition_cols, file_extension, schema, - file_sort_order_raw, + self._convert_file_sort_order(file_sort_order), ) def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: @@ -808,7 +807,7 @@ def register_parquet( file_extension: str = ".parquet", skip_metadata: bool = True, schema: pa.Schema | None = None, - file_sort_order: list[list[SortExpr]] | None = None, + file_sort_order: Sequence[Sequence[SortKey]] | None = None, ) -> None: """Register a Parquet file as a table. @@ -827,7 +826,9 @@ def register_parquet( that may be in the file schema. This can help avoid schema conflicts due to metadata. schema: The data source schema. - file_sort_order: Sort order for the file. + file_sort_order: Sort order for the file. Each sort key can be + specified as a column name (``str``), an expression + (``Expr``), or a ``SortExpr``. """ if table_partition_cols is None: table_partition_cols = [] @@ -840,9 +841,7 @@ def register_parquet( file_extension, skip_metadata, schema, - [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] - if file_sort_order is not None - else None, + self._convert_file_sort_order(file_sort_order), ) def register_csv( @@ -1099,7 +1098,7 @@ def read_parquet( file_extension: str = ".parquet", skip_metadata: bool = True, schema: pa.Schema | None = None, - file_sort_order: list[list[Expr | SortExpr]] | None = None, + file_sort_order: Sequence[Sequence[SortKey]] | None = None, ) -> DataFrame: """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`. @@ -1116,7 +1115,9 @@ def read_parquet( schema: An optional schema representing the parquet files. If None, the parquet reader will try to infer it based on data in the file. - file_sort_order: Sort order for the file. + file_sort_order: Sort order for the file. Each sort key can be + specified as a column name (``str``), an expression + (``Expr``), or a ``SortExpr``. Returns: DataFrame representation of the read Parquet files @@ -1124,11 +1125,7 @@ def read_parquet( if table_partition_cols is None: table_partition_cols = [] table_partition_cols = self._convert_table_partition_cols(table_partition_cols) - file_sort_order = ( - [sort_list_to_raw_sort_list(f) for f in file_sort_order] - if file_sort_order is not None - else None - ) + file_sort_order = self._convert_file_sort_order(file_sort_order) return DataFrame( self.ctx.read_parquet( str(path), @@ -1179,6 +1176,24 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: """Execute the ``plan`` and return the results.""" return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions)) + @staticmethod + def _convert_file_sort_order( + file_sort_order: Sequence[Sequence[SortKey]] | None, + ) -> list[list[expr_internal.SortExpr]] | None: + """Convert nested ``SortKey`` sequences into raw sort expressions. + + Each ``SortKey`` can be a column name string, an ``Expr``, or a + ``SortExpr`` and will be converted using + :func:`datafusion.expr.sort_list_to_raw_sort_list`. + """ + # Convert each ``SortKey`` in the provided sort order to the low-level + # representation expected by the Rust bindings. + return ( + [sort_list_to_raw_sort_list(f) for f in file_sort_order] + if file_sort_order is not None + else None + ) + @staticmethod def _convert_table_partition_cols( table_partition_cols: list[tuple[str, str | pa.DataType]], diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 181c29db4..68e6fe5a8 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -22,6 +22,7 @@ from __future__ import annotations import warnings +from collections.abc import Sequence from typing import ( TYPE_CHECKING, Any, @@ -40,20 +41,25 @@ from datafusion._internal import DataFrame as DataFrameInternal from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal -from datafusion.expr import Expr, SortExpr, sort_or_default +from datafusion.expr import ( + Expr, + SortKey, + ensure_expr, + ensure_expr_list, + expr_list_to_raw_expr_list, + sort_list_to_raw_sort_list, +) from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.record_batch import RecordBatchStream if TYPE_CHECKING: import pathlib - from typing import Callable, Sequence + from typing import Callable import pandas as pd import polars as pl import pyarrow as pa - from datafusion._internal import expr as expr_internal - from enum import Enum @@ -401,9 +407,7 @@ def select(self, *exprs: Expr | str) -> DataFrame: df = df.select("a", col("b"), col("a").alias("alternate_a")) """ - exprs_internal = [ - Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs - ] + exprs_internal = expr_list_to_raw_expr_list(exprs) return DataFrame(self.df.select(*exprs_internal)) def drop(self, *columns: str) -> DataFrame: @@ -421,9 +425,17 @@ def filter(self, *predicates: Expr) -> DataFrame: """Return a DataFrame for which ``predicate`` evaluates to ``True``. Rows for which ``predicate`` evaluates to ``False`` or ``None`` are filtered - out. If more than one predicate is provided, these predicates will be - combined as a logical AND. If more complex logic is required, see the - logical operations in :py:mod:`~datafusion.functions`. + out. If more than one predicate is provided, these predicates will be + combined as a logical AND. Each ``predicate`` must be an + :class:`~datafusion.expr.Expr` created using helper functions such as + :func:`datafusion.col` or :func:`datafusion.lit`. + If more complex logic is required, see the logical operations in + :py:mod:`~datafusion.functions`. + + Example:: + + from datafusion import col, lit + df.filter(col("a") > lit(1)) Args: predicates: Predicate expression(s) to filter the DataFrame. @@ -433,12 +445,20 @@ def filter(self, *predicates: Expr) -> DataFrame: """ df = self.df for p in predicates: - df = df.filter(p.expr) + df = df.filter(ensure_expr(p)) return DataFrame(df) def with_column(self, name: str, expr: Expr) -> DataFrame: """Add an additional column to the DataFrame. + The ``expr`` must be an :class:`~datafusion.expr.Expr` constructed with + :func:`datafusion.col` or :func:`datafusion.lit`. + + Example:: + + from datafusion import col, lit + df.with_column("b", col("a") + lit(1)) + Args: name: Name of the column to add. expr: Expression to compute the column. @@ -446,23 +466,27 @@ def with_column(self, name: str, expr: Expr) -> DataFrame: Returns: DataFrame with the new column. """ - return DataFrame(self.df.with_column(name, expr.expr)) + return DataFrame(self.df.with_column(name, ensure_expr(expr))) def with_columns( self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr ) -> DataFrame: """Add columns to the DataFrame. - By passing expressions, iteratables of expressions, or named expressions. To - pass named expressions use the form name=Expr. + By passing expressions, iterables of expressions, or named expressions. + All expressions must be :class:`~datafusion.expr.Expr` objects created via + :func:`datafusion.col` or :func:`datafusion.lit`. + To pass named expressions use the form ``name=Expr``. - Example usage: The following will add 4 columns labeled a, b, c, and d:: + Example usage: The following will add 4 columns labeled ``a``, ``b``, ``c``, + and ``d``:: + from datafusion import col, lit df = df.with_columns( - lit(0).alias('a'), - [lit(1).alias('b'), lit(2).alias('c')], + col("x").alias("a"), + [lit(1).alias("b"), col("y").alias("c")], d=lit(3) - ) + ) Args: exprs: Either a single expression or an iterable of expressions to add. @@ -471,24 +495,10 @@ def with_columns( Returns: DataFrame with the new columns added. """ - - def _simplify_expression( - *exprs: Expr | Iterable[Expr], **named_exprs: Expr - ) -> list[expr_internal.Expr]: - expr_list = [] - for expr in exprs: - if isinstance(expr, Expr): - expr_list.append(expr.expr) - elif isinstance(expr, Iterable): - expr_list.extend(inner_expr.expr for inner_expr in expr) - else: - raise NotImplementedError - if named_exprs: - for alias, expr in named_exprs.items(): - expr_list.append(expr.alias(alias).expr) - return expr_list - - expressions = _simplify_expression(*exprs, **named_exprs) + expressions = ensure_expr_list(exprs) + for alias, expr in named_exprs.items(): + ensure_expr(expr) + expressions.append(expr.alias(alias).expr) return DataFrame(self.df.with_columns(expressions)) @@ -510,37 +520,47 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame: return DataFrame(self.df.with_column_renamed(old_name, new_name)) def aggregate( - self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr + self, + group_by: Sequence[Expr | str] | Expr | str, + aggs: Sequence[Expr] | Expr, ) -> DataFrame: """Aggregates the rows of the current DataFrame. Args: - group_by: List of expressions to group by. - aggs: List of expressions to aggregate. + group_by: Sequence of expressions or column names to group by. + aggs: Sequence of expressions to aggregate. Returns: DataFrame after aggregation. """ - group_by = group_by if isinstance(group_by, list) else [group_by] - aggs = aggs if isinstance(aggs, list) else [aggs] + group_by_list = ( + list(group_by) + if isinstance(group_by, Sequence) and not isinstance(group_by, (Expr, str)) + else [group_by] + ) + aggs_list = ( + list(aggs) + if isinstance(aggs, Sequence) and not isinstance(aggs, Expr) + else [aggs] + ) - group_by = [e.expr for e in group_by] - aggs = [e.expr for e in aggs] - return DataFrame(self.df.aggregate(group_by, aggs)) + group_by_exprs = expr_list_to_raw_expr_list(group_by_list) + aggs_exprs = ensure_expr_list(aggs_list) + return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs)) - def sort(self, *exprs: Expr | SortExpr) -> DataFrame: - """Sort the DataFrame by the specified sorting expressions. + def sort(self, *exprs: SortKey) -> DataFrame: + """Sort the DataFrame by the specified sorting expressions or column names. Note that any expression can be turned into a sort expression by - calling its` ``sort`` method. + calling its ``sort`` method. Args: - exprs: Sort expressions, applied in order. + exprs: Sort expressions or column names, applied in order. Returns: DataFrame after sorting. """ - exprs_raw = [sort_or_default(expr) for expr in exprs] + exprs_raw = sort_list_to_raw_sort_list(exprs) return DataFrame(self.df.sort(*exprs_raw)) def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame: @@ -752,8 +772,14 @@ def join_on( ) -> DataFrame: """Join two :py:class:`DataFrame` using the specified expressions. - On expressions are used to support in-equality predicates. Equality - predicates are correctly optimized + Join predicates must be :class:`~datafusion.expr.Expr` objects, typically + built with :func:`datafusion.col`. On expressions are used to support + in-equality predicates. Equality predicates are correctly optimized. + + Example:: + + from datafusion import col + df.join_on(other_df, col("id") == col("other_id")) Args: right: Other DataFrame to join with. @@ -764,7 +790,7 @@ def join_on( Returns: DataFrame after join. """ - exprs = [expr.expr for expr in on_exprs] + exprs = [ensure_expr(expr) for expr in on_exprs] return DataFrame(self.df.join_on(right.df, exprs, how)) def explain(self, verbose: bool = False, analyze: bool = False) -> None: diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index b51560400..5d1180bd1 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -22,7 +22,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Optional +import typing as _typing +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Sequence import pyarrow as pa @@ -31,14 +32,23 @@ except ImportError: from typing_extensions import deprecated # Python 3.12 -from datafusion.common import DataTypeMap, NullTreatment, RexType +from datafusion.common import NullTreatment from ._internal import expr as expr_internal from ._internal import functions as functions_internal if TYPE_CHECKING: + from collections.abc import Sequence + + # Type-only imports + from datafusion.common import DataTypeMap, RexType from datafusion.plan import LogicalPlan + +# Standard error message for invalid expression types +# Mention both alias forms of column and literal helpers +EXPR_TYPE_ERROR = "Use col()/column() or lit()/literal() to construct expressions" + # The following are imported from the internal representation. We may choose to # give these all proper wrappers, or to simply leave as is. These were added # in order to support passing the `test_imports` unit test. @@ -126,6 +136,7 @@ WindowExpr = expr_internal.WindowExpr __all__ = [ + "EXPR_TYPE_ERROR", "Aggregate", "AggregateFunction", "Alias", @@ -195,6 +206,7 @@ "SimilarTo", "Sort", "SortExpr", + "SortKey", "Subquery", "SubqueryAlias", "TableScan", @@ -212,19 +224,97 @@ "WindowExpr", "WindowFrame", "WindowFrameBound", + "ensure_expr", + "ensure_expr_list", ] +def ensure_expr(value: _typing.Union[Expr, Any]) -> expr_internal.Expr: + """Return the internal expression from ``Expr`` or raise ``TypeError``. + + This helper rejects plain strings and other non-:class:`Expr` values so + higher level APIs consistently require explicit :func:`~datafusion.col` or + :func:`~datafusion.lit` expressions. + + Args: + value: Candidate expression or other object. + + Returns: + The internal expression representation. + + Raises: + TypeError: If ``value`` is not an instance of :class:`Expr`. + """ + if not isinstance(value, Expr): + raise TypeError(EXPR_TYPE_ERROR) + return value.expr + + +def ensure_expr_list( + exprs: Iterable[_typing.Union[Expr, Iterable[Expr]]], +) -> list[expr_internal.Expr]: + """Flatten an iterable of expressions, validating each via ``ensure_expr``. + + Args: + exprs: Possibly nested iterable containing expressions. + + Returns: + A flat list of raw expressions. + + Raises: + TypeError: If any item is not an instance of :class:`Expr`. + """ + + def _iter( + items: Iterable[_typing.Union[Expr, Iterable[Expr]]], + ) -> Iterable[expr_internal.Expr]: + for expr in items: + if isinstance(expr, Iterable) and not isinstance( + expr, (Expr, str, bytes, bytearray) + ): + # Treat string-like objects as atomic to surface standard errors + yield from _iter(expr) + else: + yield ensure_expr(expr) + + return list(_iter(exprs)) + + +def _to_raw_expr(value: _typing.Union[Expr, str]) -> expr_internal.Expr: + """Convert a Python expression or column name to its raw variant. + + Args: + value: Candidate expression or column name. + + Returns: + The internal :class:`~datafusion._internal.expr.Expr` representation. + + Raises: + TypeError: If ``value`` is neither an :class:`Expr` nor ``str``. + """ + if isinstance(value, str): + return Expr.column(value).expr + if isinstance(value, Expr): + return value.expr + error = ( + "Expected Expr or column name, found:" + f" {type(value).__name__}. {EXPR_TYPE_ERROR}." + ) + raise TypeError(error) + + def expr_list_to_raw_expr_list( expr_list: Optional[list[Expr] | Expr], ) -> Optional[list[expr_internal.Expr]]: - """Helper function to convert an optional list to raw expressions.""" - if isinstance(expr_list, Expr): + """Convert a sequence of expressions or column names to raw expressions.""" + if isinstance(expr_list, (Expr, str)): expr_list = [expr_list] - return [e.expr for e in expr_list] if expr_list is not None else None + if expr_list is None: + return None + return [_to_raw_expr(e) for e in expr_list] -def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: +def sort_or_default(e: _typing.Union[Expr, SortExpr]) -> expr_internal.SortExpr: """Helper function to return a default Sort if an Expr is provided.""" if isinstance(e, SortExpr): return e.raw_sort @@ -232,12 +322,21 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: def sort_list_to_raw_sort_list( - sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr], + sort_list: Optional[_typing.Union[Sequence[SortKey], SortKey]], ) -> Optional[list[expr_internal.SortExpr]]: """Helper function to return an optional sort list to raw variant.""" - if isinstance(sort_list, (Expr, SortExpr)): + if isinstance(sort_list, (Expr, SortExpr, str)): sort_list = [sort_list] - return [sort_or_default(e) for e in sort_list] if sort_list is not None else None + if sort_list is None: + return None + raw_sort_list = [] + for item in sort_list: + if isinstance(item, SortExpr): + raw_sort_list.append(sort_or_default(item)) + else: + raw_expr = _to_raw_expr(item) # may raise ``TypeError`` + raw_sort_list.append(sort_or_default(Expr(raw_expr))) + return raw_sort_list class Expr: @@ -352,7 +451,7 @@ def __invert__(self) -> Expr: """Binary not (~).""" return Expr(self.expr.__invert__()) - def __getitem__(self, key: str | int | slice) -> Expr: + def __getitem__(self, key: str | int) -> Expr: """Retrieve sub-object. If ``key`` is a string, returns the subfield of the struct. @@ -530,13 +629,13 @@ def is_not_null(self) -> Expr: """Returns ``True`` if this expression is not null.""" return Expr(self.expr.is_not_null()) - def fill_nan(self, value: Any | Expr | None = None) -> Expr: + def fill_nan(self, value: Optional[_typing.Union[Any, Expr]] = None) -> Expr: """Fill NaN values with a provided value.""" if not isinstance(value, Expr): value = Expr.literal(value) return Expr(functions_internal.nanvl(self.expr, value.expr)) - def fill_null(self, value: Any | Expr | None = None) -> Expr: + def fill_null(self, value: Optional[_typing.Union[Any, Expr]] = None) -> Expr: """Fill NULL values with a provided value.""" if not isinstance(value, Expr): value = Expr.literal(value) @@ -549,7 +648,7 @@ def fill_null(self, value: Any | Expr | None = None) -> Expr: bool: pa.bool_(), } - def cast(self, to: pa.DataType[Any] | type[float | int | str | bool]) -> Expr: + def cast(self, to: _typing.Union[pa.DataType[Any], type]) -> Expr: """Cast to a new data type.""" if not isinstance(to, pa.DataType): try: @@ -622,7 +721,7 @@ def column_name(self, plan: LogicalPlan) -> str: """Compute the output column name based on the provided logical plan.""" return self.expr.column_name(plan._raw_plan) - def order_by(self, *exprs: Expr | SortExpr) -> ExprFuncBuilder: + def order_by(self, *exprs: _typing.Union[Expr, SortExpr]) -> ExprFuncBuilder: """Set the ordering for a window or aggregate function. This function will create an :py:class:`ExprFuncBuilder` that can be used to @@ -687,7 +786,7 @@ def over(self, window: Window) -> Expr: window: Window definition """ partition_by_raw = expr_list_to_raw_expr_list(window._partition_by) - order_by_raw = sort_list_to_raw_sort_list(window._order_by) + order_by_raw = window._order_by window_frame_raw = ( window._window_frame.window_frame if window._window_frame is not None @@ -1171,9 +1270,16 @@ class Window: def __init__( self, - partition_by: Optional[list[Expr] | Expr] = None, + partition_by: Optional[_typing.Union[list[Expr], Expr]] = None, window_frame: Optional[WindowFrame] = None, - order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None, + order_by: Optional[ + _typing.Union[ + list[_typing.Union[SortExpr, Expr, str]], + Expr, + SortExpr, + str, + ] + ] = None, null_treatment: Optional[NullTreatment] = None, ) -> None: """Construct a window definition. @@ -1186,7 +1292,7 @@ def __init__( """ self._partition_by = partition_by self._window_frame = window_frame - self._order_by = order_by + self._order_by = sort_list_to_raw_sort_list(order_by) self._null_treatment = null_treatment @@ -1244,7 +1350,7 @@ def __init__(self, frame_bound: expr_internal.WindowFrameBound) -> None: """Constructs a window frame bound.""" self.frame_bound = frame_bound - def get_offset(self) -> int | None: + def get_offset(self) -> Optional[int]: """Returns the offset of the window frame.""" return self.frame_bound.get_offset() @@ -1326,3 +1432,6 @@ def nulls_first(self) -> bool: def __repr__(self) -> str: """Generate a string representation of this expression.""" return self.raw_sort.__repr__() + + +SortKey = _typing.Union[Expr, SortExpr, str] diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 7ee4929a8..648efef79 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -28,6 +28,7 @@ CaseBuilder, Expr, SortExpr, + SortKey, WindowFrame, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, @@ -429,7 +430,7 @@ def window( name: str, args: list[Expr], partition_by: list[Expr] | Expr | None = None, - order_by: list[Expr | SortExpr] | Expr | SortExpr | None = None, + order_by: list[SortKey] | SortKey | None = None, window_frame: WindowFrame | None = None, ctx: SessionContext | None = None, ) -> Expr: @@ -440,6 +441,10 @@ def window( lag use:: df.select(functions.lag(col("a")).partition_by(col("b")).build()) + + The ``order_by`` parameter accepts column names or expressions, e.g.:: + + window("lag", [col("a")], order_by="ts") """ args = [a.expr for a in args] partition_by_raw = expr_list_to_raw_expr_list(partition_by) @@ -1723,7 +1728,7 @@ def array_agg( expression: Expr, distinct: bool = False, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Aggregate values into an array. @@ -1738,7 +1743,11 @@ def array_agg( expression: Values to combine into an array distinct: If True, a single entry for each distinct value will be in the result filter: If provided, only compute against rows for which the filter is True - order_by: Order the resultant array values + order_by: Order the resultant array values. Accepts column names or expressions. + + For example:: + + df.select(array_agg(col("a"), order_by="b")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None @@ -2222,7 +2231,7 @@ def regr_syy( def first_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the first value in a group of values. @@ -2235,8 +2244,13 @@ def first_value( Args: expression: Argument to perform bitwise calculation on filter: If provided, only compute against rows for which the filter is True - order_by: Set the ordering of the expression to evaluate + order_by: Set the ordering of the expression to evaluate. Accepts + column names or expressions. null_treatment: Assign whether to respect or ignore null values. + + For example:: + + df.select(first_value(col("a"), order_by="ts")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None @@ -2254,7 +2268,7 @@ def first_value( def last_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the last value in a group of values. @@ -2267,8 +2281,13 @@ def last_value( Args: expression: Argument to perform bitwise calculation on filter: If provided, only compute against rows for which the filter is True - order_by: Set the ordering of the expression to evaluate + order_by: Set the ordering of the expression to evaluate. Accepts + column names or expressions. null_treatment: Assign whether to respect or ignore null values. + + For example:: + + df.select(last_value(col("a"), order_by="ts")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None @@ -2287,7 +2306,7 @@ def nth_value( expression: Expr, n: int, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the n-th value in a group of values. @@ -2301,8 +2320,13 @@ def nth_value( expression: Argument to perform bitwise calculation on n: Index of value to return. Starts at 1. filter: If provided, only compute against rows for which the filter is True - order_by: Set the ordering of the expression to evaluate + order_by: Set the ordering of the expression to evaluate. Accepts + column names or expressions. null_treatment: Assign whether to respect or ignore null values. + + For example:: + + df.select(nth_value(col("a"), 2, order_by="ts")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None @@ -2408,7 +2432,7 @@ def lead( shift_offset: int = 1, default_value: Optional[Any] = None, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a lead window function. @@ -2437,7 +2461,12 @@ def lead( shift_offset: Number of rows following the current row. default_value: Value to return if shift_offet row does not exist. partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + lead(col("b"), order_by="ts") """ if not isinstance(default_value, pa.Scalar) and default_value is not None: default_value = pa.scalar(default_value) @@ -2461,7 +2490,7 @@ def lag( shift_offset: int = 1, default_value: Optional[Any] = None, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a lag window function. @@ -2487,7 +2516,12 @@ def lag( shift_offset: Number of rows before the current row. default_value: Value to return if shift_offet row does not exist. partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + lag(col("b"), order_by="ts") """ if not isinstance(default_value, pa.Scalar): default_value = pa.scalar(default_value) @@ -2508,7 +2542,7 @@ def lag( def row_number( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a row number window function. @@ -2527,7 +2561,12 @@ def row_number( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + row_number(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2542,7 +2581,7 @@ def row_number( def rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a rank window function. @@ -2566,7 +2605,12 @@ def rank( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + rank(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2581,7 +2625,7 @@ def rank( def dense_rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a dense_rank window function. @@ -2600,7 +2644,12 @@ def dense_rank( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + dense_rank(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2615,7 +2664,7 @@ def dense_rank( def percent_rank( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a percent_rank window function. @@ -2635,7 +2684,12 @@ def percent_rank( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + percent_rank(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2650,7 +2704,7 @@ def percent_rank( def cume_dist( partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a cumulative distribution window function. @@ -2670,7 +2724,12 @@ def cume_dist( Args: partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + cume_dist(order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2686,7 +2745,7 @@ def cume_dist( def ntile( groups: int, partition_by: Optional[list[Expr] | Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Create a n-tile window function. @@ -2709,7 +2768,12 @@ def ntile( Args: groups: Number of groups for the n-tile to be divided into. partition_by: Expressions to partition the window frame on. - order_by: Set ordering within the window frame. + order_by: Set ordering within the window frame. Accepts + column names or expressions. + + For example:: + + ntile(3, order_by="points") """ partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) @@ -2727,7 +2791,7 @@ def string_agg( expression: Expr, delimiter: str, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, + order_by: Optional[list[SortKey] | SortKey] = None, ) -> Expr: """Concatenates the input strings. @@ -2742,7 +2806,12 @@ def string_agg( expression: Argument to perform bitwise calculation on delimiter: Text to place between each value of expression filter: If provided, only compute against rows for which the filter is True - order_by: Set the ordering of the expression to evaluate + order_by: Set the ordering of the expression to evaluate. Accepts + column names or expressions. + + For example:: + + df.select(string_agg(col("a"), ",", order_by="b")) """ order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 343d32a92..11317cf3d 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -34,6 +34,9 @@ column, literal, ) +from datafusion import ( + col as df_col, +) from datafusion import ( functions as f, ) @@ -43,7 +46,7 @@ get_formatter, reset_formatter, ) -from datafusion.expr import Window +from datafusion.expr import EXPR_TYPE_ERROR, Window from pyarrow.csv import write_csv MB = 1024 * 1024 @@ -227,6 +230,14 @@ def test_select_mixed_expr_string(df): assert result.column(1) == pa.array([1, 2, 3]) +def test_select_unsupported(df): + with pytest.raises( + TypeError, + match=f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}", + ): + df.select(1) + + def test_filter(df): df1 = df.filter(column("a") > literal(2)).select( column("a") + column("b"), @@ -268,6 +279,47 @@ def test_sort(df): assert table.to_pydict() == expected +def test_sort_string_and_expression_equivalent(df): + from datafusion import col + + result_str = df.sort("a").to_pydict() + result_expr = df.sort(col("a")).to_pydict() + assert result_str == result_expr + + +def test_sort_unsupported(df): + with pytest.raises( + TypeError, + match=f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}", + ): + df.sort(1) + + +def test_aggregate_string_and_expression_equivalent(df): + from datafusion import col + + result_str = df.aggregate("a", [f.count()]).sort("a").to_pydict() + result_expr = df.aggregate(col("a"), [f.count()]).sort("a").to_pydict() + assert result_str == result_expr + + +def test_aggregate_tuple_group_by(df): + result_list = df.aggregate(["a"], [f.count()]).sort("a").to_pydict() + result_tuple = df.aggregate(("a",), [f.count()]).sort("a").to_pydict() + assert result_tuple == result_list + + +def test_aggregate_tuple_aggs(df): + result_list = df.aggregate("a", [f.count()]).sort("a").to_pydict() + result_tuple = df.aggregate("a", (f.count(),)).sort("a").to_pydict() + assert result_tuple == result_list + + +def test_filter_string_unsupported(df): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + df.filter("a > 1") + + def test_drop(df): df = df.drop("c") @@ -337,6 +389,13 @@ def test_with_column(df): assert result.column(2) == pa.array([5, 7, 9]) +def test_with_column_invalid_expr(df): + with pytest.raises( + TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" + ): + df.with_column("c", "a") + + def test_with_columns(df): df = df.with_columns( (column("a") + column("b")).alias("c"), @@ -368,6 +427,17 @@ def test_with_columns(df): assert result.column(6) == pa.array([5, 7, 9]) +def test_with_columns_invalid_expr(df): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + df.with_columns("a") + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + df.with_columns(c="a") + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + df.with_columns(["a"]) + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + df.with_columns(c=["a"]) + + def test_cast(df): df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())}) expected = pa.schema( @@ -526,6 +596,29 @@ def test_join_on(): assert table.to_pydict() == expected +def test_join_on_invalid_expr(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([4, 5])], + names=["a", "b"], + ) + df = ctx.create_dataframe([[batch]], "l") + df1 = ctx.create_dataframe([[batch]], "r") + + with pytest.raises( + TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" + ): + df.join_on(df1, "a") + + +def test_aggregate_invalid_aggs(df): + with pytest.raises( + TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)" + ): + df.aggregate([], "a") + + def test_distinct(): ctx = SessionContext() @@ -713,6 +806,13 @@ def test_distinct(): ), [1, 1, 1, 1, 5, 5, 5], ), + ( + "first_value_order_by_string", + f.first_value(column("a")).over( + Window(partition_by=[column("c")], order_by="b") + ), + [1, 1, 1, 1, 5, 5, 5], + ), ( "last_value", f.last_value(column("a")).over( @@ -755,6 +855,27 @@ def test_window_functions(partitioned_df, name, expr, result): assert table.sort_by("a").to_pydict() == expected +@pytest.mark.parametrize("partition", ["c", df_col("c")]) +def test_rank_partition_by_accepts_string(partitioned_df, partition): + """Passing a string to partition_by should match using col().""" + df = partitioned_df.select( + f.rank(order_by=column("a"), partition_by=partition).alias("r") + ) + table = pa.Table.from_batches(df.sort(column("a")).collect()) + assert table.column("r").to_pylist() == [1, 2, 3, 4, 1, 2, 3] + + +@pytest.mark.parametrize("partition", ["c", df_col("c")]) +def test_window_partition_by_accepts_string(partitioned_df, partition): + """Window.partition_by accepts string identifiers.""" + expr = f.first_value(column("a")).over( + Window(partition_by=partition, order_by=column("b")) + ) + df = partitioned_df.select(expr.alias("fv")) + table = pa.Table.from_batches(df.sort(column("a")).collect()) + assert table.column("fv").to_pylist() == [1, 1, 1, 1, 5, 5, 5] + + @pytest.mark.parametrize( ("units", "start_bound", "end_bound"), [ @@ -825,6 +946,69 @@ def test_window_frame_defaults_match_postgres(partitioned_df): assert df_2.sort(col_a).to_pydict() == expected +def _build_last_value_df(df): + return df.select( + f.last_value(column("a")) + .over( + Window( + partition_by=[column("c")], + order_by=[column("b")], + window_frame=WindowFrame("rows", None, None), + ) + ) + .alias("expr"), + f.last_value(column("a")) + .over( + Window( + partition_by=[column("c")], + order_by="b", + window_frame=WindowFrame("rows", None, None), + ) + ) + .alias("str"), + ) + + +def _build_nth_value_df(df): + return df.select( + f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])).alias("expr"), + f.nth_value(column("b"), 3).over(Window(order_by="a")).alias("str"), + ) + + +def _build_rank_df(df): + return df.select( + f.rank(order_by=[column("b")]).alias("expr"), + f.rank(order_by="b").alias("str"), + ) + + +def _build_array_agg_df(df): + return df.aggregate( + [column("c")], + [ + f.array_agg(column("a"), order_by=[column("a")]).alias("expr"), + f.array_agg(column("a"), order_by="a").alias("str"), + ], + ).sort(column("c")) + + +@pytest.mark.parametrize( + ("builder", "expected"), + [ + pytest.param(_build_last_value_df, [3, 3, 3, 3, 6, 6, 6], id="last_value"), + pytest.param(_build_nth_value_df, [None, None, 7, 7, 7, 7, 7], id="nth_value"), + pytest.param(_build_rank_df, [1, 1, 3, 3, 5, 6, 6], id="rank"), + pytest.param(_build_array_agg_df, [[0, 1, 2, 3], [4, 5, 6]], id="array_agg"), + ], +) +def test_order_by_string_equivalence(partitioned_df, builder, expected): + df = builder(partitioned_df) + table = pa.Table.from_batches(df.collect()) + assert table.column("expr").to_pylist() == expected + assert table.column("expr").to_pylist() == table.column("str").to_pylist() + + def test_html_formatter_cell_dimension(df, clean_formatter_state): """Test configuring the HTML formatter with different options.""" # Configure with custom settings @@ -2680,3 +2864,34 @@ def test_show_from_empty_batch(capsys) -> None: ctx.create_dataframe([[batch]]).show() out = capsys.readouterr().out assert "| a |" in out + + +@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]]) +def test_register_parquet_file_sort_order(ctx, tmp_path, file_sort_order): + table = pa.table({"a": [1, 2]}) + path = tmp_path / "file.parquet" + pa.parquet.write_table(table, path) + ctx.register_parquet("t", path, file_sort_order=file_sort_order) + assert "t" in ctx.catalog().schema().names() + + +@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]]) +def test_register_listing_table_file_sort_order(ctx, tmp_path, file_sort_order): + table = pa.table({"a": [1, 2]}) + dir_path = tmp_path / "dir" + dir_path.mkdir() + pa.parquet.write_table(table, dir_path / "file.parquet") + ctx.register_listing_table( + "t", dir_path, schema=table.schema, file_sort_order=file_sort_order + ) + assert "t" in ctx.catalog().schema().names() + + +@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]]) +def test_read_parquet_file_sort_order(tmp_path, file_sort_order): + ctx = SessionContext() + table = pa.table({"a": [1, 2]}) + path = tmp_path / "data.parquet" + pa.parquet.write_table(table, path) + df = ctx.read_parquet(path, file_sort_order=file_sort_order) + assert df.collect()[0].column(0).to_pylist() == [1, 2] diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index cfeb07c1f..810d419cf 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import re from datetime import datetime, timezone import pyarrow as pa @@ -28,6 +29,7 @@ literal_with_metadata, ) from datafusion.expr import ( + EXPR_TYPE_ERROR, Aggregate, AggregateFunction, BinaryExpr, @@ -47,6 +49,8 @@ TransactionEnd, TransactionStart, Values, + ensure_expr, + ensure_expr_list, ) @@ -880,3 +884,25 @@ def test_literal_metadata(ctx): for expected_field in expected_schema: actual_field = result[0].schema.field(expected_field.name) assert expected_field.metadata == actual_field.metadata + + +def test_ensure_expr(): + e = col("a") + assert ensure_expr(e) is e.expr + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + ensure_expr("a") + + +def test_ensure_expr_list_string(): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + ensure_expr_list("a") + + +def test_ensure_expr_list_bytes(): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + ensure_expr_list(b"a") + + +def test_ensure_expr_list_bytearray(): + with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): + ensure_expr_list(bytearray(b"a"))