From 2dd7ed8b1b32e371bd7f801e034285797b5d74a3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 22 Jan 2025 08:44:50 +0000 Subject: [PATCH 1/7] wip --- narwhals/_expression_parsing.py | 26 ++- narwhals/exceptions.py | 7 + narwhals/expr.py | 377 ++++++++++---------------------- narwhals/functions.py | 76 +++---- narwhals/stable/v1/__init__.py | 28 +-- 5 files changed, 183 insertions(+), 331 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index a11db68eb..809965a3d 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -9,12 +9,12 @@ from typing import Sequence from typing import TypeVar from typing import Union -from typing import cast +from typing import cast, TypedDict from typing import overload from narwhals.dependencies import is_numpy_array from narwhals.exceptions import InvalidIntoExprError -from narwhals.exceptions import LengthChangingExprError +from narwhals.exceptions import LengthChangingExprError, MultiOutputExprError from narwhals.utils import Implementation if TYPE_CHECKING: @@ -43,6 +43,12 @@ T = TypeVar("T") +class ExprMetadata(TypedDict): + is_order_dependent: bool + changes_length: bool + aggregates: bool + is_multi_output: bool + def evaluate_into_expr( df: CompliantDataFrame | CompliantLazyFrame, @@ -383,3 +389,19 @@ def operation_aggregates(*args: IntoExpr | Any) -> bool: # expression does not aggregate, then broadcasting will take # place and the result will not be an aggregate. return all(getattr(x, "_aggregates", True) for x in args) + +def operation_is_multi_output(*args: IntoExpr | Any) -> bool: + # None of the comparands can be multi-output + from narwhals.expr import Expr + + if any(isinstance(x, Expr) and x._metadata['is_multi_output'] for x in args[1:]): + msg = ( + "Multi-output expressions cannot appear in the right-hand-side of\n" + "any operation. For example, `nw.col('a', 'b') + nw.col('c')` is \n" + "allowed, but not `nw.col('a') + nw.col('b', 'c')`." + ) + raise MultiOutputExprError(msg) + return args[0]._metadata['is_multi_output'] + +def combine_metadata(lhs, *args: IntoExpr | Any) -> ExprMetadata: + return ExprMetadata(is_order_dependent=operation_is_order_dependent(lhs, *args), changes_length=operation_changes_length(lhs, *args), aggregates=operation_aggregates(lhs, *args), is_multi_output=operation_is_multi_output(lhs, *args)) \ No newline at end of file diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index 6a553fa44..94ba39d26 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -98,6 +98,13 @@ def __init__(self, message: str) -> None: self.message = message super().__init__(self.message) +class MultiOutputExprError(ValueError): + """Exception raised when trying to combine expressions where one has multiple outputs.""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + class UnsupportedDTypeError(ValueError): """Exception raised when trying to convert to a DType which is not supported by the given backend.""" diff --git a/narwhals/expr.py b/narwhals/expr.py index ea64a7696..f4cc2030d 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -6,12 +6,12 @@ from typing import Iterable from typing import Literal from typing import Mapping -from typing import Sequence +from typing import Sequence, TypedDict from narwhals._expression_parsing import extract_compliant from narwhals._expression_parsing import operation_aggregates from narwhals._expression_parsing import operation_changes_length -from narwhals._expression_parsing import operation_is_order_dependent +from narwhals._expression_parsing import operation_is_order_dependent, ExprMetadata, combine_metadata from narwhals.dtypes import _validate_dtype from narwhals.expr_cat import ExprCatNamespace from narwhals.expr_dt import ExprDateTimeNamespace @@ -35,22 +35,19 @@ class Expr: def __init__( self, to_compliant_expr: Callable[[Any], Any], - is_order_dependent: bool, # noqa: FBT001 - changes_length: bool, # noqa: FBT001 - aggregates: bool, # noqa: FBT001 + metadata: ExprMetadata ) -> None: # callable from CompliantNamespace to CompliantExpr self._to_compliant_expr = to_compliant_expr - self._is_order_dependent = is_order_dependent - self._changes_length = changes_length - self._aggregates = aggregates + self._metadata = metadata def __repr__(self) -> str: return ( "Narwhals Expr\n" - f"is_order_dependent: {self._is_order_dependent}\n" - f"changes_length: {self._changes_length}\n" - f"aggregates: {self._aggregates}" + f"is_order_dependent: {self._metadata['is_order_dependent']}\n" + f"changes_length: {self._metadata['changes_length']}\n" + f"aggregates: {self._metadata['aggregates']}" + f"is_multi_output: {self._metadata['is_multi_output']}" ) def _taxicab_norm(self) -> Self: @@ -58,9 +55,7 @@ def _taxicab_norm(self) -> Self: # It's not intended to be used. return self.__class__( lambda plx: self._to_compliant_expr(plx).abs().sum(), - self._is_order_dependent, - self._changes_length, - self._aggregates, + self._metadata ) # --- convert --- @@ -117,11 +112,12 @@ def alias(self, name: str) -> Self: c: [[14,15]] """ + if self._metadata['is_multi_output']: + msg = "Cannot alias multi-output expression. Use `.name.suffix`, `.name.map`" + raise ValueError(msg) return self.__class__( lambda plx: self._to_compliant_expr(plx).alias(name), - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Self: @@ -243,9 +239,7 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: _validate_dtype(dtype) return self.__class__( lambda plx: self._to_compliant_expr(plx).cast(dtype), - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) # --- binary --- @@ -254,9 +248,7 @@ def __eq__(self, other: object) -> Self: # type: ignore[override] lambda plx: self._to_compliant_expr(plx).__eq__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __ne__(self, other: object) -> Self: # type: ignore[override] @@ -264,9 +256,7 @@ def __ne__(self, other: object) -> Self: # type: ignore[override] lambda plx: self._to_compliant_expr(plx).__ne__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __and__(self, other: Any) -> Self: @@ -274,9 +264,7 @@ def __and__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__and__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __rand__(self, other: Any) -> Self: @@ -287,9 +275,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __or__(self, other: Any) -> Self: @@ -297,9 +283,7 @@ def __or__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__or__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __ror__(self, other: Any) -> Self: @@ -310,9 +294,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __add__(self, other: Any) -> Self: @@ -320,9 +302,7 @@ def __add__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__add__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __radd__(self, other: Any) -> Self: @@ -333,9 +313,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __sub__(self, other: Any) -> Self: @@ -343,9 +321,7 @@ def __sub__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__sub__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __rsub__(self, other: Any) -> Self: @@ -356,9 +332,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __truediv__(self, other: Any) -> Self: @@ -366,9 +340,7 @@ def __truediv__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__truediv__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __rtruediv__(self, other: Any) -> Self: @@ -379,9 +351,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __mul__(self, other: Any) -> Self: @@ -389,9 +359,7 @@ def __mul__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__mul__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __rmul__(self, other: Any) -> Self: @@ -402,9 +370,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __le__(self, other: Any) -> Self: @@ -412,9 +378,7 @@ def __le__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__le__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __lt__(self, other: Any) -> Self: @@ -422,9 +386,7 @@ def __lt__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__lt__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __gt__(self, other: Any) -> Self: @@ -432,9 +394,7 @@ def __gt__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__gt__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __ge__(self, other: Any) -> Self: @@ -442,9 +402,7 @@ def __ge__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__ge__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __pow__(self, other: Any) -> Self: @@ -452,9 +410,7 @@ def __pow__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__pow__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __rpow__(self, other: Any) -> Self: @@ -465,9 +421,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __floordiv__(self, other: Any) -> Self: @@ -475,9 +429,7 @@ def __floordiv__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__floordiv__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __rfloordiv__(self, other: Any) -> Self: @@ -488,9 +440,7 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __mod__(self, other: Any) -> Self: @@ -498,9 +448,7 @@ def __mod__(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__mod__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) def __rmod__(self, other: Any) -> Self: @@ -511,18 +459,14 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other) ) # --- unary --- def __invert__(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__invert__(), - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def any(self) -> Self: @@ -576,9 +520,7 @@ def any(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).any(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def all(self) -> Self: @@ -632,9 +574,7 @@ def all(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).all(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def ewm_mean( @@ -737,9 +677,7 @@ def ewm_mean( min_periods=min_periods, ignore_nulls=ignore_nulls, ), - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def mean(self) -> Self: @@ -793,9 +731,7 @@ def mean(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).mean(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def median(self) -> Self: @@ -852,9 +788,7 @@ def median(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).median(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def std(self, *, ddof: int = 1) -> Self: @@ -911,9 +845,7 @@ def std(self, *, ddof: int = 1) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).std(ddof=ddof), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def var(self, *, ddof: int = 1) -> Self: @@ -971,9 +903,7 @@ def var(self, *, ddof: int = 1) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).var(ddof=ddof), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def map_batches( @@ -1050,9 +980,12 @@ def map_batches( function=function, return_dtype=return_dtype ), # safest assumptions - is_order_dependent=True, - changes_length=True, - aggregates=False, + ExprMetadata( + is_order_dependent=True, + changes_length=True, + aggregates=False, + is_multi_output=True + ) ) def skew(self: Self) -> Self: @@ -1106,9 +1039,7 @@ def skew(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).skew(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def sum(self) -> Expr: @@ -1160,9 +1091,7 @@ def sum(self) -> Expr: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).sum(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def min(self) -> Self: @@ -1216,9 +1145,7 @@ def min(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).min(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def max(self) -> Self: @@ -1272,9 +1199,7 @@ def max(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).max(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'aggregates': True}) ) def arg_min(self) -> Self: @@ -1330,9 +1255,12 @@ def arg_min(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_min(), - is_order_dependent=True, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=True, + changes_length=False, + aggregates=True, + is_multi_output=self._metadata['is_multi_output'] + ) ) def arg_max(self) -> Self: @@ -1388,9 +1316,12 @@ def arg_max(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_max(), - is_order_dependent=True, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=True, + changes_length=False, + aggregates=True, + is_multi_output=self._metadata['is_multi_output'] + ) ) def count(self) -> Self: @@ -1444,9 +1375,7 @@ def count(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).count(), - self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'changes_length': False, 'aggregates': True}) ) def n_unique(self) -> Self: @@ -1498,9 +1427,7 @@ def n_unique(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).n_unique(), - self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'changes_length': False, 'aggregates': True}) ) def unique(self) -> Self: @@ -1554,9 +1481,7 @@ def unique(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).unique(), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True}) ) def abs(self) -> Self: @@ -1612,9 +1537,7 @@ def abs(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).abs(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def cum_sum(self: Self, *, reverse: bool = False) -> Self: @@ -1677,9 +1600,7 @@ def cum_sum(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_sum(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def diff(self) -> Self: @@ -1748,9 +1669,7 @@ def diff(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).diff(), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def shift(self, n: int) -> Self: @@ -1822,9 +1741,7 @@ def shift(self, n: int) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).shift(n), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def replace_strict( @@ -1917,9 +1834,7 @@ def replace_strict( lambda plx: self._to_compliant_expr(plx).replace_strict( old, new, return_dtype=return_dtype ), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -1950,9 +1865,7 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: lambda plx: self._to_compliant_expr(plx).sort( descending=descending, nulls_last=nulls_last ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) # --- transform --- @@ -2027,11 +1940,7 @@ def is_between( extract_compliant(plx, upper_bound), closed, ), - is_order_dependent=operation_is_order_dependent( - self, lower_bound, upper_bound - ), - changes_length=self._changes_length, - aggregates=self._aggregates, + combine_metadata(self, lower_bound, upper_bound), ) def is_in(self, other: Any) -> Self: @@ -2097,9 +2006,7 @@ def is_in(self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).is_in( extract_compliant(plx, other) ), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + combine_metadata(self, other) ) else: msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead." @@ -2169,9 +2076,7 @@ def filter(self, *predicates: Any) -> Self: lambda plx: self._to_compliant_expr(plx).filter( *[extract_compliant(plx, pred) for pred in flat_predicates], ), - is_order_dependent=operation_is_order_dependent(*flat_predicates), - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**combine_metadata(self, *flat_predicates), 'is_order_dependent': operation_is_order_dependent(*flat_predicates), 'changes_length': True}) ) def is_null(self) -> Self: @@ -2252,9 +2157,7 @@ def is_null(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_null(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def is_nan(self) -> Self: @@ -2322,9 +2225,7 @@ def is_nan(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_nan(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def arg_true(self) -> Self: @@ -2341,9 +2242,7 @@ def arg_true(self) -> Self: issue_deprecation_warning(msg, _version="1.23.0") return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_true(), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) def fill_null( @@ -2487,9 +2386,7 @@ def fill_null( lambda plx: self._to_compliant_expr(plx).fill_null( value=value, strategy=strategy, limit=limit ), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) # --- partial reduction --- @@ -2552,9 +2449,7 @@ def drop_nulls(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).drop_nulls(), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True}) ) def sample( @@ -2595,9 +2490,7 @@ def sample( lambda plx: self._to_compliant_expr(plx).sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True}) ) def over(self, *keys: str | Iterable[str]) -> Self: @@ -2689,9 +2582,7 @@ def over(self, *keys: str | Iterable[str]) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).over(flatten(keys)), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def is_duplicated(self) -> Self: @@ -2751,9 +2642,7 @@ def is_duplicated(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_duplicated(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def is_unique(self) -> Self: @@ -2813,9 +2702,7 @@ def is_unique(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_unique(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def null_count(self) -> Self: @@ -2874,9 +2761,7 @@ def null_count(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).null_count(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def is_first_distinct(self) -> Self: @@ -2936,9 +2821,7 @@ def is_first_distinct(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_first_distinct(), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def is_last_distinct(self) -> Self: @@ -2998,9 +2881,7 @@ def is_last_distinct(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_last_distinct(), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def quantile( @@ -3071,9 +2952,7 @@ def quantile( """ return self.__class__( lambda plx: self._to_compliant_expr(plx).quantile(quantile, interpolation), - self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'changes_length': False, 'aggregates': True}) ) def head(self, n: int = 10) -> Self: @@ -3101,9 +2980,7 @@ def head(self, n: int = 10) -> Self: issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( lambda plx: self._to_compliant_expr(plx).head(n), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) def tail(self, n: int = 10) -> Self: @@ -3131,9 +3008,7 @@ def tail(self, n: int = 10) -> Self: issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( lambda plx: self._to_compliant_expr(plx).tail(n), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) def round(self, decimals: int = 0) -> Self: @@ -3201,9 +3076,7 @@ def round(self, decimals: int = 0) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).round(decimals), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def len(self) -> Self: @@ -3263,9 +3136,7 @@ def len(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).len(), - self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, 'changes_length': False, 'aggregates': True}) ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -3294,9 +3165,7 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) # need to allow numeric typing @@ -3444,11 +3313,7 @@ def clip( extract_compliant(plx, lower_bound), extract_compliant(plx, upper_bound), ), - is_order_dependent=operation_is_order_dependent( - self, lower_bound, upper_bound - ), - changes_length=self._changes_length, - aggregates=self._aggregates, + ExprMetadata({**combine_metadata(self, lower_bound, upper_bound)}) ) def mode(self: Self) -> Self: @@ -3505,9 +3370,7 @@ def mode(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).mode(), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True}) ) def is_finite(self: Self) -> Self: @@ -3570,9 +3433,7 @@ def is_finite(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_finite(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata ) def cum_count(self: Self, *, reverse: bool = False) -> Self: @@ -3640,9 +3501,7 @@ def cum_count(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_count(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def cum_min(self: Self, *, reverse: bool = False) -> Self: @@ -3710,9 +3569,7 @@ def cum_min(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_min(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def cum_max(self: Self, *, reverse: bool = False) -> Self: @@ -3780,9 +3637,7 @@ def cum_max(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_max(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def cum_prod(self: Self, *, reverse: bool = False) -> Self: @@ -3850,9 +3705,7 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_prod(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def rolling_sum( @@ -3947,9 +3800,7 @@ def rolling_sum( min_periods=min_periods, center=center, ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def rolling_mean( @@ -4044,9 +3895,7 @@ def rolling_mean( min_periods=min_periods, center=center, ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def rolling_var( @@ -4141,9 +3990,7 @@ def rolling_var( lambda plx: self._to_compliant_expr(plx).rolling_var( window_size=window_size, min_periods=min_periods, center=center, ddof=ddof ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def rolling_std( @@ -4241,9 +4088,7 @@ def rolling_std( center=center, ddof=ddof, ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def rank( @@ -4341,9 +4186,7 @@ def rank( lambda plx: self._to_compliant_expr(plx).rank( method=method, descending=descending ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) @property diff --git a/narwhals/functions.py b/narwhals/functions.py index a6e370f27..adc0a3d83 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -12,9 +12,9 @@ from typing import Union from typing import overload -from narwhals._expression_parsing import extract_compliant +from narwhals._expression_parsing import extract_compliant, combine_metadata from narwhals._expression_parsing import operation_aggregates -from narwhals._expression_parsing import operation_changes_length +from narwhals._expression_parsing import operation_changes_length, ExprMetadata from narwhals._expression_parsing import operation_is_order_dependent from narwhals._pandas_like.utils import broadcast_align_and_extract_native from narwhals.dataframe import DataFrame @@ -1392,11 +1392,12 @@ def col(*names: str | Iterable[str]) -> Expr: ---- a: [[3,8]] """ + flat_names = flatten(names) def func(plx: Any) -> Any: - return plx.col(*flatten(names)) + return plx.col(*flat_names) - return Expr(func, is_order_dependent=False, changes_length=False, aggregates=False) + Expr(func, ExprMetadata(is_order_dependent=False, changes_length=False, aggregates=False, is_multi_output=len(flat_names)>1)) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1454,11 +1455,12 @@ def nth(*indices: int | Sequence[int]) -> Expr: ---- a: [[2,4]] """ + flat_indices = flatten(indices) def func(plx: Any) -> Any: - return plx.nth(*flatten(indices)) + return plx.nth(*flat_indices) - return Expr(func, is_order_dependent=False, changes_length=False, aggregates=False) + Expr(func, ExprMetadata(is_order_dependent=False, changes_length=False, aggregates=False, is_multi_output=len(flat_indices)>1)) # Add underscore so it doesn't conflict with builtin `all` @@ -1515,12 +1517,10 @@ def all_() -> Expr: a: [[2,4,6]] b: [[8,10,12]] """ - return Expr( - lambda plx: plx.all(), - is_order_dependent=False, + Expr(lambda plx: plx.all(), ExprMetadata( + is_order_dependent=False, changes_length=False, - aggregates=False, - ) + aggregates=False, is_multi_output=True)) # Add underscore so it doesn't conflict with builtin `len` @@ -1631,9 +1631,10 @@ def sum(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).sum(), - is_order_dependent=False, + ExprMetadata( + is_order_dependent=False, changes_length=False, - aggregates=True, + aggregates=True, is_multi_output=len(columns)>1) ) @@ -1692,9 +1693,10 @@ def mean(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).mean(), - is_order_dependent=False, + ExprMetadata( + is_order_dependent=False, changes_length=False, - aggregates=True, + aggregates=True, is_multi_output=len(columns)>1) ) @@ -1755,9 +1757,10 @@ def median(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).median(), - is_order_dependent=False, + ExprMetadata( + is_order_dependent=False, changes_length=False, - aggregates=True, + aggregates=True, is_multi_output=len(columns)>1) ) @@ -1816,9 +1819,10 @@ def min(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).min(), - is_order_dependent=False, + ExprMetadata( + is_order_dependent=False, changes_length=False, - aggregates=True, + aggregates=True, is_multi_output=len(columns)>1) ) @@ -1877,9 +1881,10 @@ def max(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).max(), - is_order_dependent=False, + ExprMetadata( + is_order_dependent=False, changes_length=False, - aggregates=True, + aggregates=True, is_multi_output=len(columns)>1) ) @@ -1946,9 +1951,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.sum_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) ) @@ -2018,9 +2021,7 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.min_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) ) @@ -2090,9 +2091,7 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.max_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) ) @@ -2284,9 +2283,7 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.all_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) ) @@ -2359,9 +2356,10 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr: return Expr( lambda plx: plx.lit(value, dtype), - is_order_dependent=False, + ExprMetadata( + is_order_dependent=False, changes_length=False, - aggregates=True, + aggregates=True, is_multi_output=len(columns)>1) ) @@ -2439,9 +2437,7 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.any_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) ) @@ -2511,9 +2507,7 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.mean_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index c2f2e7b65..d87e0ae80 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -885,9 +885,7 @@ def head(self, n: int = 10) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).head(n), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) def tail(self, n: int = 10) -> Self: @@ -901,9 +899,7 @@ def tail(self, n: int = 10) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).tail(n), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -918,9 +914,7 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) def unique(self, *, maintain_order: bool | None = None) -> Self: @@ -942,9 +936,7 @@ def unique(self, *, maintain_order: bool | None = None) -> Self: warn(message=msg, category=UserWarning, stacklevel=find_stacklevel()) return self.__class__( lambda plx: self._to_compliant_expr(plx).unique(), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True}) ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -961,9 +953,7 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: lambda plx: self._to_compliant_expr(plx).sort( descending=descending, nulls_last=nulls_last ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) ) def arg_true(self) -> Self: @@ -974,9 +964,7 @@ def arg_true(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_true(), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) def sample( @@ -1010,9 +998,7 @@ def sample( lambda plx: self._to_compliant_expr(plx).sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) ) From d1037c8ee08a227fe6d6f0584458ae7cf12487a7 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 22 Jan 2025 12:01:07 +0000 Subject: [PATCH 2/7] wip --- narwhals/_expression_parsing.py | 36 ++++++--- narwhals/dataframe.py | 4 +- narwhals/expr_cat.py | 6 +- narwhals/expr_dt.py | 88 +++++---------------- narwhals/expr_list.py | 4 +- narwhals/expr_name.py | 24 ++---- narwhals/expr_str.py | 52 ++++--------- narwhals/functions.py | 134 ++++++++++++++++++++------------ narwhals/selectors.py | 55 ++++++++----- narwhals/stable/v1/__init__.py | 39 +++++----- utils/check_api_reference.py | 29 +++---- 11 files changed, 224 insertions(+), 247 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 809965a3d..3bd454900 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -7,14 +7,16 @@ from typing import TYPE_CHECKING from typing import Any from typing import Sequence +from typing import TypedDict from typing import TypeVar from typing import Union -from typing import cast, TypedDict +from typing import cast from typing import overload from narwhals.dependencies import is_numpy_array from narwhals.exceptions import InvalidIntoExprError -from narwhals.exceptions import LengthChangingExprError, MultiOutputExprError +from narwhals.exceptions import LengthChangingExprError +from narwhals.exceptions import MultiOutputExprError from narwhals.utils import Implementation if TYPE_CHECKING: @@ -43,6 +45,7 @@ T = TypeVar("T") + class ExprMetadata(TypedDict): is_order_dependent: bool changes_length: bool @@ -348,7 +351,9 @@ def operation_is_order_dependent(*args: IntoExpr | Any) -> bool: # If an arg is an Expr, we look at `_is_order_dependent`. If it isn't, # it means that it was a scalar (e.g. nw.col('a') + 1) or a column name, # neither of which is order-dependent, so we default to `False`. - return any(getattr(x, "_is_order_dependent", False) for x in args) + from narwhals.expr import Expr + + return any(isinstance(x, Expr) and x._metadata["is_order_dependent"] for x in args) def operation_changes_length(*args: IntoExpr | Any) -> bool: @@ -371,7 +376,9 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: from narwhals.expr import Expr n_exprs = len([x for x in args if isinstance(x, Expr)]) - changes_length = any(isinstance(x, Expr) and x._changes_length for x in args) + changes_length = any( + isinstance(x, Expr) and x._metadata["changes_length"] for x in args + ) if n_exprs > 1 and changes_length: msg = ( "Found multiple expressions at least one of which changes length.\n" @@ -388,20 +395,29 @@ def operation_aggregates(*args: IntoExpr | Any) -> bool: # which is already length-1, so we default to `True`. If any # expression does not aggregate, then broadcasting will take # place and the result will not be an aggregate. - return all(getattr(x, "_aggregates", True) for x in args) + from narwhals.expr import Expr + + return all(isinstance(x, Expr) and x._metadata["aggregates"] for x in args) + def operation_is_multi_output(*args: IntoExpr | Any) -> bool: - # None of the comparands can be multi-output + # Only the first expression is allowed to produce multiple outputs. from narwhals.expr import Expr - if any(isinstance(x, Expr) and x._metadata['is_multi_output'] for x in args[1:]): + if any(isinstance(x, Expr) and x._metadata["is_multi_output"] for x in args[1:]): msg = ( "Multi-output expressions cannot appear in the right-hand-side of\n" "any operation. For example, `nw.col('a', 'b') + nw.col('c')` is \n" "allowed, but not `nw.col('a') + nw.col('b', 'c')`." ) raise MultiOutputExprError(msg) - return args[0]._metadata['is_multi_output'] + return isinstance(args[0], Expr) and args[0]._metadata["is_multi_output"] -def combine_metadata(lhs, *args: IntoExpr | Any) -> ExprMetadata: - return ExprMetadata(is_order_dependent=operation_is_order_dependent(lhs, *args), changes_length=operation_changes_length(lhs, *args), aggregates=operation_aggregates(lhs, *args), is_multi_output=operation_is_multi_output(lhs, *args)) \ No newline at end of file + +def combine_metadata(*args: IntoExpr | Any) -> ExprMetadata: + return ExprMetadata( + is_order_dependent=operation_is_order_dependent(*args), + changes_length=operation_changes_length(*args), + aggregates=operation_aggregates(*args), + is_multi_output=operation_is_multi_output(*args), + ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 28f218274..45cbed559 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3644,7 +3644,7 @@ def _extract_compliant(self, arg: Any) -> Any: msg = "Binary operations between Series and LazyFrame are not supported." raise TypeError(msg) if isinstance(arg, Expr): - if arg._is_order_dependent: + if arg._metadata["is_order_dependent"]: msg = ( "Order-dependent expressions are not supported for use in LazyFrame.\n\n" "Hints:\n" @@ -3655,7 +3655,7 @@ def _extract_compliant(self, arg: Any) -> Any: " they will be supported." ) raise OrderDependentExprError(msg) - if arg._changes_length: + if arg._metadata["changes_length"]: msg = ( "Length-changing expressions are not supported for use in LazyFrame, unless\n" "followed by an aggregation.\n\n" diff --git a/narwhals/expr_cat.py b/narwhals/expr_cat.py index 16dbb3929..3eb09c91d 100644 --- a/narwhals/expr_cat.py +++ b/narwhals/expr_cat.py @@ -4,6 +4,8 @@ from typing import Generic from typing import TypeVar +from narwhals._expression_parsing import ExprMetadata + if TYPE_CHECKING: from typing_extensions import Self @@ -63,7 +65,5 @@ def get_categories(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).cat.get_categories(), - self._expr._is_order_dependent, - changes_length=True, - aggregates=self._expr._aggregates, + ExprMetadata({**self._expr._metadata, "changes_length": True}), # type: ignore[typeddict-item] ) diff --git a/narwhals/expr_dt.py b/narwhals/expr_dt.py index 6ea1fbbdd..582017dd5 100644 --- a/narwhals/expr_dt.py +++ b/narwhals/expr_dt.py @@ -71,10 +71,7 @@ def date(self: Self) -> ExprT: a: [[2012-01-07,2023-03-10]] """ return self._expr.__class__( - lambda plx: self._expr._to_compliant_expr(plx).dt.date(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + lambda plx: self._expr._to_compliant_expr(plx).dt.date(), self._expr._metadata ) def year(self: Self) -> ExprT: @@ -142,10 +139,7 @@ def year(self: Self) -> ExprT: year: [[1978,2024,2065]] """ return self._expr.__class__( - lambda plx: self._expr._to_compliant_expr(plx).dt.year(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + lambda plx: self._expr._to_compliant_expr(plx).dt.year(), self._expr._metadata ) def month(self: Self) -> ExprT: @@ -214,9 +208,7 @@ def month(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.month(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def day(self: Self) -> ExprT: @@ -284,10 +276,7 @@ def day(self: Self) -> ExprT: day: [[1,13,1]] """ return self._expr.__class__( - lambda plx: self._expr._to_compliant_expr(plx).dt.day(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + lambda plx: self._expr._to_compliant_expr(plx).dt.day(), self._expr._metadata ) def hour(self: Self) -> ExprT: @@ -355,10 +344,7 @@ def hour(self: Self) -> ExprT: hour: [[1,5,10]] """ return self._expr.__class__( - lambda plx: self._expr._to_compliant_expr(plx).dt.hour(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + lambda plx: self._expr._to_compliant_expr(plx).dt.hour(), self._expr._metadata ) def minute(self: Self) -> ExprT: @@ -427,9 +413,7 @@ def minute(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.minute(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def second(self: Self) -> ExprT: @@ -496,9 +480,7 @@ def second(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.second(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def millisecond(self: Self) -> ExprT: @@ -565,9 +547,7 @@ def millisecond(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.millisecond(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def microsecond(self: Self) -> ExprT: @@ -634,9 +614,7 @@ def microsecond(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.microsecond(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def nanosecond(self: Self) -> ExprT: @@ -703,9 +681,7 @@ def nanosecond(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.nanosecond(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def ordinal_day(self: Self) -> ExprT: @@ -764,9 +740,7 @@ def ordinal_day(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.ordinal_day(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def weekday(self: Self) -> ExprT: @@ -823,9 +797,7 @@ def weekday(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.weekday(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_minutes(self: Self) -> ExprT: @@ -889,9 +861,7 @@ def total_minutes(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_minutes(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_seconds(self: Self) -> ExprT: @@ -955,9 +925,7 @@ def total_seconds(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_seconds(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_milliseconds(self: Self) -> ExprT: @@ -1026,9 +994,7 @@ def total_milliseconds(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_milliseconds(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_microseconds(self: Self) -> ExprT: @@ -1097,9 +1063,7 @@ def total_microseconds(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_microseconds(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_nanoseconds(self: Self) -> ExprT: @@ -1155,9 +1119,7 @@ def total_nanoseconds(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_nanoseconds(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_string(self: Self, format: str) -> ExprT: # noqa: A002 @@ -1256,9 +1218,7 @@ def to_string(self: Self, format: str) -> ExprT: # noqa: A002 """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.to_string(format), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def replace_time_zone(self: Self, time_zone: str | None) -> ExprT: @@ -1325,9 +1285,7 @@ def replace_time_zone(self: Self, time_zone: str | None) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).dt.replace_time_zone( time_zone ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def convert_time_zone(self: Self, time_zone: str) -> ExprT: @@ -1400,9 +1358,7 @@ def convert_time_zone(self: Self, time_zone: str) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).dt.convert_time_zone( time_zone ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ExprT: @@ -1476,7 +1432,5 @@ def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ExprT: raise ValueError(msg) return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.timestamp(time_unit), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index 0532db5fe..fc6a1227b 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -74,7 +74,5 @@ def len(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).list.len(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) diff --git a/narwhals/expr_name.py b/narwhals/expr_name.py index 706f9427d..975eed7d2 100644 --- a/narwhals/expr_name.py +++ b/narwhals/expr_name.py @@ -60,9 +60,7 @@ def keep(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.keep(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def map(self: Self, function: Callable[[str], str]) -> ExprT: @@ -112,9 +110,7 @@ def map(self: Self, function: Callable[[str], str]) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.map(function), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def prefix(self: Self, prefix: str) -> ExprT: @@ -163,9 +159,7 @@ def prefix(self: Self, prefix: str) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.prefix(prefix), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def suffix(self: Self, suffix: str) -> ExprT: @@ -214,9 +208,7 @@ def suffix(self: Self, suffix: str) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.suffix(suffix), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_lowercase(self: Self) -> ExprT: @@ -262,9 +254,7 @@ def to_lowercase(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.to_lowercase(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_uppercase(self: Self) -> ExprT: @@ -310,7 +300,5 @@ def to_uppercase(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.to_uppercase(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index 90283930a..1618059bf 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -77,9 +77,7 @@ def len_chars(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.len_chars(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def replace( @@ -146,9 +144,7 @@ def replace( lambda plx: self._expr._to_compliant_expr(plx).str.replace( pattern, value, literal=literal, n=n ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def replace_all( @@ -214,9 +210,7 @@ def replace_all( lambda plx: self._expr._to_compliant_expr(plx).str.replace_all( pattern, value, literal=literal ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def strip_chars(self: Self, characters: str | None = None) -> ExprT: @@ -265,9 +259,7 @@ def strip_chars(self: Self, characters: str | None = None) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.strip_chars(characters), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def starts_with(self: Self, prefix: str) -> ExprT: @@ -330,9 +322,7 @@ def starts_with(self: Self, prefix: str) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.starts_with(prefix), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def ends_with(self: Self, suffix: str) -> ExprT: @@ -395,9 +385,7 @@ def ends_with(self: Self, suffix: str) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.ends_with(suffix), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def contains(self: Self, pattern: str, *, literal: bool = False) -> ExprT: @@ -476,9 +464,7 @@ def contains(self: Self, pattern: str, *, literal: bool = False) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).str.contains( pattern, literal=literal ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def slice(self: Self, offset: int, length: int | None = None) -> ExprT: @@ -581,9 +567,7 @@ def slice(self: Self, offset: int, length: int | None = None) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).str.slice( offset=offset, length=length ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def head(self: Self, n: int = 5) -> ExprT: @@ -651,9 +635,7 @@ def head(self: Self, n: int = 5) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.slice(0, n), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def tail(self: Self, n: int = 5) -> ExprT: @@ -723,9 +705,7 @@ def tail(self: Self, n: int = 5) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).str.slice( offset=-n, length=None ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_datetime(self: Self, format: str | None = None) -> ExprT: # noqa: A002 @@ -795,9 +775,7 @@ def to_datetime(self: Self, format: str | None = None) -> ExprT: # noqa: A002 """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_datetime(format=format), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_uppercase(self: Self) -> ExprT: @@ -862,9 +840,7 @@ def to_uppercase(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_uppercase(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_lowercase(self: Self) -> ExprT: @@ -924,7 +900,5 @@ def to_lowercase(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_lowercase(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) diff --git a/narwhals/functions.py b/narwhals/functions.py index adc0a3d83..f0333b841 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -12,10 +12,9 @@ from typing import Union from typing import overload -from narwhals._expression_parsing import extract_compliant, combine_metadata -from narwhals._expression_parsing import operation_aggregates -from narwhals._expression_parsing import operation_changes_length, ExprMetadata -from narwhals._expression_parsing import operation_is_order_dependent +from narwhals._expression_parsing import ExprMetadata +from narwhals._expression_parsing import combine_metadata +from narwhals._expression_parsing import extract_compliant from narwhals._pandas_like.utils import broadcast_align_and_extract_native from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame @@ -1397,7 +1396,15 @@ def col(*names: str | Iterable[str]) -> Expr: def func(plx: Any) -> Any: return plx.col(*flat_names) - Expr(func, ExprMetadata(is_order_dependent=False, changes_length=False, aggregates=False, is_multi_output=len(flat_names)>1)) + return Expr( + func, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=len(flat_names) > 1, + ), + ) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1460,7 +1467,15 @@ def nth(*indices: int | Sequence[int]) -> Expr: def func(plx: Any) -> Any: return plx.nth(*flat_indices) - Expr(func, ExprMetadata(is_order_dependent=False, changes_length=False, aggregates=False, is_multi_output=len(flat_indices)>1)) + return Expr( + func, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=len(flat_indices) > 1, + ), + ) # Add underscore so it doesn't conflict with builtin `all` @@ -1517,10 +1532,15 @@ def all_() -> Expr: a: [[2,4,6]] b: [[8,10,12]] """ - Expr(lambda plx: plx.all(), ExprMetadata( - is_order_dependent=False, - changes_length=False, - aggregates=False, is_multi_output=True)) + return Expr( + lambda plx: plx.all(), + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), + ) # Add underscore so it doesn't conflict with builtin `len` @@ -1573,7 +1593,15 @@ def len_() -> Expr: def func(plx: Any) -> Any: return plx.len() - return Expr(func, is_order_dependent=False, changes_length=False, aggregates=True) + return Expr( + func, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=False, + ), + ) def sum(*columns: str) -> Expr: @@ -1631,10 +1659,12 @@ def sum(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).sum(), - ExprMetadata( - is_order_dependent=False, - changes_length=False, - aggregates=True, is_multi_output=len(columns)>1) + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1693,10 +1723,12 @@ def mean(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).mean(), - ExprMetadata( - is_order_dependent=False, - changes_length=False, - aggregates=True, is_multi_output=len(columns)>1) + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1757,10 +1789,12 @@ def median(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).median(), - ExprMetadata( - is_order_dependent=False, - changes_length=False, - aggregates=True, is_multi_output=len(columns)>1) + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1819,10 +1853,12 @@ def min(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).min(), - ExprMetadata( - is_order_dependent=False, - changes_length=False, - aggregates=True, is_multi_output=len(columns)>1) + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1881,10 +1917,12 @@ def max(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).max(), - ExprMetadata( - is_order_dependent=False, - changes_length=False, - aggregates=True, is_multi_output=len(columns)>1) + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1951,7 +1989,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.sum_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) + ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), ) @@ -2021,7 +2059,7 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.min_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) + ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), ) @@ -2091,7 +2129,7 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.max_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) + ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), ) @@ -2116,9 +2154,7 @@ def then(self, value: IntoExpr | Any) -> Then: lambda plx: plx.when(*self._extract_predicates(plx)).then( extract_compliant(plx, value) ), - is_order_dependent=operation_is_order_dependent(*self._predicates, value), - changes_length=operation_changes_length(*self._predicates, value), - aggregates=operation_aggregates(*self._predicates, value), + combine_metadata(*self._predicates, value), ) @@ -2128,9 +2164,7 @@ def otherwise(self, value: IntoExpr | Any) -> Expr: lambda plx: self._to_compliant_expr(plx).otherwise( extract_compliant(plx, value) ), - is_order_dependent=operation_is_order_dependent(self, value), - changes_length=operation_changes_length(self, value), - aggregates=operation_aggregates(self, value), + combine_metadata(self, value), ) @@ -2283,7 +2317,7 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.all_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) + ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), ) @@ -2356,10 +2390,12 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr: return Expr( lambda plx: plx.lit(value, dtype), - ExprMetadata( - is_order_dependent=False, - changes_length=False, - aggregates=True, is_multi_output=len(columns)>1) + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=False, + ), ) @@ -2437,7 +2473,7 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.any_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) + ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), ) @@ -2507,7 +2543,7 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.mean_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), 'is_multi_output': False}) + ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), ) @@ -2597,7 +2633,5 @@ def concat_str( separator=separator, ignore_nulls=ignore_nulls, ), - is_order_dependent=operation_is_order_dependent(*flat_exprs, *more_exprs), - changes_length=operation_changes_length(*flat_exprs, *more_exprs), - aggregates=operation_aggregates(*flat_exprs, *more_exprs), + combine_metadata(*flat_exprs, *more_exprs), ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index e67424281..5bf449494 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -2,6 +2,7 @@ from typing import Any +from narwhals._expression_parsing import ExprMetadata from narwhals.expr import Expr from narwhals.utils import flatten @@ -54,9 +55,12 @@ def by_dtype(*dtypes: Any) -> Expr: """ return Selector( lambda plx: plx.selectors.by_dtype(flatten(dtypes)), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) @@ -102,9 +106,12 @@ def numeric() -> Expr: """ return Selector( lambda plx: plx.selectors.numeric(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) @@ -150,9 +157,12 @@ def boolean() -> Expr: """ return Selector( lambda plx: plx.selectors.boolean(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) @@ -198,9 +208,12 @@ def string() -> Expr: """ return Selector( lambda plx: plx.selectors.string(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) @@ -246,9 +259,12 @@ def categorical() -> Expr: """ return Selector( lambda plx: plx.selectors.categorical(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) @@ -294,9 +310,12 @@ def all() -> Expr: """ return Selector( lambda plx: plx.selectors.all(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index d87e0ae80..082be6e4a 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -15,6 +15,7 @@ from narwhals import dependencies from narwhals import exceptions from narwhals import selectors +from narwhals._expression_parsing import ExprMetadata from narwhals.dataframe import DataFrame as NwDataFrame from narwhals.dataframe import LazyFrame as NwLazyFrame from narwhals.dependencies import get_polars @@ -885,7 +886,9 @@ def head(self, n: int = 10) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).head(n), - ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def tail(self, n: int = 10) -> Self: @@ -899,7 +902,9 @@ def tail(self, n: int = 10) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).tail(n), - ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -914,7 +919,9 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), - ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def unique(self, *, maintain_order: bool | None = None) -> Self: @@ -936,7 +943,7 @@ def unique(self, *, maintain_order: bool | None = None) -> Self: warn(message=msg, category=UserWarning, stacklevel=find_stacklevel()) return self.__class__( lambda plx: self._to_compliant_expr(plx).unique(), - ExprMetadata({**self._metadata, 'changes_length': True}) + ExprMetadata({**self._metadata, "changes_length": True}), ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -953,7 +960,7 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: lambda plx: self._to_compliant_expr(plx).sort( descending=descending, nulls_last=nulls_last ), - metadata=ExprMetadata({**self._metadata, 'is_order_dependent': True}) + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def arg_true(self) -> Self: @@ -964,7 +971,9 @@ def arg_true(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_true(), - ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def sample( @@ -998,7 +1007,9 @@ def sample( lambda plx: self._to_compliant_expr(plx).sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), - ExprMetadata({**self._metadata, 'changes_length': True, 'is_order_dependent': True}) + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) @@ -1043,12 +1054,7 @@ def _stableify( level=obj._level, ) if isinstance(obj, NwExpr): - return Expr( - obj._to_compliant_expr, - is_order_dependent=obj._is_order_dependent, - changes_length=obj._changes_length, - aggregates=obj._aggregates, - ) + return Expr(obj._to_compliant_expr, obj._metadata) return obj @@ -2006,12 +2012,7 @@ def then(self, value: Any) -> Then: class Then(NwThen, Expr): @classmethod def from_then(cls, then: NwThen) -> Self: - return cls( - then._to_compliant_expr, - is_order_dependent=then._is_order_dependent, - changes_length=then._changes_length, - aggregates=then._aggregates, - ) + return cls(then._to_compliant_expr, then._metadata) def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index e0ebf97a9..743bf36bf 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -6,6 +6,7 @@ import polars as pl import narwhals as nw +from narwhals._expression_parsing import ExprMetadata from narwhals.utils import remove_prefix from narwhals.utils import remove_suffix @@ -44,6 +45,12 @@ "OrderedDict", "Mapping", } +PLACEHOLDER_EXPR_METADATA = ExprMetadata( + is_order_dependent=False, + aggregates=False, + changes_length=False, + is_multi_output=False, +) files = {remove_suffix(i, ".py") for i in os.listdir("narwhals")} @@ -161,9 +168,7 @@ # Expr methods expr_methods = [ i - for i in nw.Expr( - lambda: 0, is_order_dependent=False, changes_length=False, aggregates=False - ).__dir__() + for i in nw.Expr(lambda: 0, PLACEHOLDER_EXPR_METADATA).__dir__() if not i[0].isupper() and i[0] != "_" ] with open("docs/api-reference/expr.md") as fd: @@ -187,12 +192,7 @@ expr_methods = [ i for i in getattr( - nw.Expr( - lambda: 0, - is_order_dependent=False, - changes_length=False, - aggregates=False, - ), + nw.Expr(lambda: 0, PLACEHOLDER_EXPR_METADATA), namespace, ).__dir__() if not i[0].isupper() and i[0] != "_" @@ -236,9 +236,7 @@ # Check Expr vs Series expr = [ i - for i in nw.Expr( - lambda: 0, is_order_dependent=False, changes_length=False, aggregates=False - ).__dir__() + for i in nw.Expr(lambda: 0, PLACEHOLDER_EXPR_METADATA).__dir__() if not i[0].isupper() and i[0] != "_" ] series = [ @@ -260,12 +258,7 @@ expr_internal = [ i for i in getattr( - nw.Expr( - lambda: 0, - is_order_dependent=False, - changes_length=False, - aggregates=False, - ), + nw.Expr(lambda: 0, PLACEHOLDER_EXPR_METADATA), namespace, ).__dir__() if not i[0].isupper() and i[0] != "_" From 86873dc8ec627a2e3aefeb6358f358f195846890 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 22 Jan 2025 13:53:12 +0000 Subject: [PATCH 3/7] wip --- narwhals/_expression_parsing.py | 12 ++++++++--- narwhals/dataframe.py | 3 ++- narwhals/expr.py | 22 ++++++++++---------- narwhals/functions.py | 3 ++- narwhals/group_by.py | 9 +++++---- narwhals/selectors.py | 36 ++++++++++++++++++++++++++++++++- 6 files changed, 64 insertions(+), 21 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 3bd454900..9a468d633 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -397,21 +397,27 @@ def operation_aggregates(*args: IntoExpr | Any) -> bool: # place and the result will not be an aggregate. from narwhals.expr import Expr - return all(isinstance(x, Expr) and x._metadata["aggregates"] for x in args) + return all(x._metadata["aggregates"] for x in args if isinstance(x, Expr)) def operation_is_multi_output(*args: IntoExpr | Any) -> bool: # Only the first expression is allowed to produce multiple outputs. + # oh shoot - do we need to track the number of outputs? from narwhals.expr import Expr + from narwhals.selectors import Selector - if any(isinstance(x, Expr) and x._metadata["is_multi_output"] for x in args[1:]): + # if all(isinstance(x, Selector) for x in args): + # return True + + n_multi_output = len([x for x in args if isinstance(x, Expr) and x._metadata["is_multi_output"]]) + if n_multi_output > 1: msg = ( "Multi-output expressions cannot appear in the right-hand-side of\n" "any operation. For example, `nw.col('a', 'b') + nw.col('c')` is \n" "allowed, but not `nw.col('a') + nw.col('b', 'c')`." ) raise MultiOutputExprError(msg) - return isinstance(args[0], Expr) and args[0]._metadata["is_multi_output"] + return n_multi_output > 0 def combine_metadata(*args: IntoExpr | Any) -> ExprMetadata: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 45cbed559..e68d7f111 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -1,4 +1,5 @@ from __future__ import annotations +from narwhals.expr import Expr from abc import abstractmethod from typing import TYPE_CHECKING @@ -142,7 +143,7 @@ def filter( ) -> Self: flat_predicates = flatten(predicates) if any( - getattr(x, "_aggregates", False) or getattr(x, "_changes_length", False) + isinstance(x, Expr) and (x._metadata['aggregates'] or x._metadata['changes_length']) for x in flat_predicates ): msg = "Expressions which aggregate or change length cannot be passed to `filter`." diff --git a/narwhals/expr.py b/narwhals/expr.py index f4cc2030d..d045ed54a 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -46,7 +46,7 @@ def __repr__(self) -> str: "Narwhals Expr\n" f"is_order_dependent: {self._metadata['is_order_dependent']}\n" f"changes_length: {self._metadata['changes_length']}\n" - f"aggregates: {self._metadata['aggregates']}" + f"aggregates: {self._metadata['aggregates']}\n" f"is_multi_output: {self._metadata['is_multi_output']}" ) @@ -520,7 +520,7 @@ def any(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).any(), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def all(self) -> Self: @@ -574,7 +574,7 @@ def all(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).all(), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def ewm_mean( @@ -731,7 +731,7 @@ def mean(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).mean(), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def median(self) -> Self: @@ -788,7 +788,7 @@ def median(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).median(), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def std(self, *, ddof: int = 1) -> Self: @@ -845,7 +845,7 @@ def std(self, *, ddof: int = 1) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).std(ddof=ddof), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def var(self, *, ddof: int = 1) -> Self: @@ -903,7 +903,7 @@ def var(self, *, ddof: int = 1) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).var(ddof=ddof), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def map_batches( @@ -1039,7 +1039,7 @@ def skew(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).skew(), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def sum(self) -> Expr: @@ -1091,7 +1091,7 @@ def sum(self) -> Expr: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).sum(), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def min(self) -> Self: @@ -1145,7 +1145,7 @@ def min(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).min(), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def max(self) -> Self: @@ -1199,7 +1199,7 @@ def max(self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).max(), - ExprMetadata({**self._metadata, 'aggregates': True}) + ExprMetadata({**self._metadata, 'aggregates': True, 'changes_length': False}) ) def arg_min(self) -> Self: diff --git a/narwhals/functions.py b/narwhals/functions.py index f0333b841..67171dbd9 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -2140,8 +2140,9 @@ def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: msg = "At least one predicate needs to be provided to `narwhals.when`." raise TypeError(msg) if any( - getattr(x, "_aggregates", False) or getattr(x, "_changes_length", False) + x._metadata['aggregates'] or x._metadata['changes_length'] for x in self._predicates + if isinstance(x, Expr) ): msg = "Expressions which aggregate or change length cannot be passed to `filter`." raise ShapeError(msg) diff --git a/narwhals/group_by.py b/narwhals/group_by.py index 83d23b3e9..6c347e75c 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -10,6 +10,7 @@ from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame +from narwhals.expr import Expr from narwhals.exceptions import InvalidOperationError from narwhals.utils import tupleify @@ -110,8 +111,8 @@ def agg( │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ - if not all(getattr(x, "_aggregates", True) for x in aggs) and all( - getattr(x, "_aggregates", True) for x in named_aggs.values() + if not all(isinstance(x, Expr) and x._metadata['aggregates'] for x in aggs) and all( + isinstance(x, Expr) and x._metadata['aggregates'] for x in named_aggs.values() ): msg = ( "Found expression which does not aggregate.\n\n" @@ -206,8 +207,8 @@ def agg( │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ - if not all(getattr(x, "_aggregates", True) for x in aggs) and all( - getattr(x, "_aggregates", True) for x in named_aggs.values() + if not all(isinstance(x, Expr) and x._metadata['aggregates'] for x in aggs) and all( + isinstance(x, Expr) and x._metadata['aggregates'] for x in named_aggs.values() ): msg = ( "Found expression which does not aggregate.\n\n" diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 5bf449494..b782b296a 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -1,4 +1,5 @@ from __future__ import annotations +from narwhals._expression_parsing import extract_compliant from typing import Any @@ -7,7 +8,40 @@ from narwhals.utils import flatten -class Selector(Expr): ... +class Selector(Expr): + def __or__(self, other): + return Selector( + lambda plx: self._to_compliant_expr(plx) | extract_compliant(plx, other), + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), + ) + def __and__(self, other): + return Selector( + lambda plx: self._to_compliant_expr(plx) & extract_compliant(plx, other), + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), + ) + def __sub__(self, other): + return Selector( + lambda plx: self._to_compliant_expr(plx) - extract_compliant(plx, other), + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), + ) + + + def by_dtype(*dtypes: Any) -> Expr: From 801756d4e3447b07e16a2be260b16701565623f5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 22 Jan 2025 14:29:08 +0000 Subject: [PATCH 4/7] wip --- narwhals/group_by.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/narwhals/group_by.py b/narwhals/group_by.py index 6c347e75c..022928eef 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -12,7 +12,7 @@ from narwhals.dataframe import LazyFrame from narwhals.expr import Expr from narwhals.exceptions import InvalidOperationError -from narwhals.utils import tupleify +from narwhals.utils import tupleify, flatten if TYPE_CHECKING: from narwhals.typing import IntoExpr @@ -111,7 +111,8 @@ def agg( │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ - if not all(isinstance(x, Expr) and x._metadata['aggregates'] for x in aggs) and all( + flat_aggs = flatten(aggs) + if not all(isinstance(x, Expr) and x._metadata['aggregates'] for x in flat_aggs) and all( isinstance(x, Expr) and x._metadata['aggregates'] for x in named_aggs.values() ): msg = ( @@ -121,7 +122,7 @@ def agg( "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) + aggs, named_aggs = self._df._flatten_and_extract(*flat_aggs, **named_aggs) return self._df._from_compliant_dataframe( # type: ignore[return-value] self._grouped.agg(*aggs, **named_aggs), ) @@ -207,7 +208,8 @@ def agg( │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ - if not all(isinstance(x, Expr) and x._metadata['aggregates'] for x in aggs) and all( + flat_aggs = flatten(aggs) + if not all(isinstance(x, Expr) and x._metadata['aggregates'] for x in flat_aggs) and all( isinstance(x, Expr) and x._metadata['aggregates'] for x in named_aggs.values() ): msg = ( @@ -217,7 +219,7 @@ def agg( "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) + aggs, named_aggs = self._df._flatten_and_extract(*flat_aggs, **named_aggs) return self._df._from_compliant_dataframe( # type: ignore[return-value] self._grouped.agg(*aggs, **named_aggs), ) From 1123a2e3bbe88d0946d3ce56f1058e277f3e8113 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:52:33 +0000 Subject: [PATCH 5/7] broken --- narwhals/_expression_parsing.py | 57 ++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 9a468d633..387f5e8b3 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -346,14 +346,52 @@ def extract_compliant( return other._compliant_series return other - -def operation_is_order_dependent(*args: IntoExpr | Any) -> bool: - # If an arg is an Expr, we look at `_is_order_dependent`. If it isn't, - # it means that it was a scalar (e.g. nw.col('a') + 1) or a column name, - # neither of which is order-dependent, so we default to `False`. +def arg_aggregates(arg: IntoExpr | Any) -> bool: + from narwhals.expr import Expr + from narwhals.series import Series + if isinstance(arg, Expr): + return arg._metadata['aggregates'] + if isinstance(arg, Series): + return arg.len() == 1 + if isinstance(arg, str): + # Column name, e.g. 'a', gets treated as `nw.col('a')`, + # which doesn't aggregate. + return False + # Scalar + return True + +def arg_changes_length(arg: IntoExpr | Any) -> bool: + from narwhals.expr import Expr + from narwhals.series import Series + if isinstance(arg, Expr): + return arg._metadata['changes_length'] + if isinstance(arg, Series): + return True # safest assumption + if isinstance(arg, str): + # Column name, e.g. 'a', gets treated as `nw.col('a')`, + # which doesn't change length. + return False + # Scalar + return False + +def arg_is_order_dependent(arg: IntoExpr | Any) -> bool: from narwhals.expr import Expr + from narwhals.series import Series + if isinstance(arg, Expr): + return arg._metadata['is_order_dependent'] + if isinstance(arg, Series): + return True # safest assumption + if isinstance(arg, str): + # Column name, e.g. 'a', gets treated as `nw.col('a')`, + # which doesn't change length. + return False + # Scalar + return False + - return any(isinstance(x, Expr) and x._metadata["is_order_dependent"] for x in args) + +def operation_is_order_dependent(*args: IntoExpr | Any) -> bool: + return any(arg_is_order_dependent(x) for x in args) def operation_changes_length(*args: IntoExpr | Any) -> bool: @@ -375,7 +413,8 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: """ from narwhals.expr import Expr - n_exprs = len([x for x in args if isinstance(x, Expr)]) + n_change_length = len([arg_changes_length(x) for x in args]) + n_exprs = len([x for x in args if isinstance(x, (Expr, str))]) changes_length = any( isinstance(x, Expr) and x._metadata["changes_length"] for x in args ) @@ -404,10 +443,6 @@ def operation_is_multi_output(*args: IntoExpr | Any) -> bool: # Only the first expression is allowed to produce multiple outputs. # oh shoot - do we need to track the number of outputs? from narwhals.expr import Expr - from narwhals.selectors import Selector - - # if all(isinstance(x, Selector) for x in args): - # return True n_multi_output = len([x for x in args if isinstance(x, Expr) and x._metadata["is_multi_output"]]) if n_multi_output > 1: From 23509569dcbc095826b76709ffc79d32dc2e054d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:30:32 +0000 Subject: [PATCH 6/7] wip --- narwhals/_expression_parsing.py | 76 +++++++++---------- narwhals/functions.py | 12 +-- narwhals/selectors.py | 13 ++-- narwhals/stable/v1/__init__.py | 2 +- tests/expr_and_series/double_selected_test.py | 8 +- 5 files changed, 55 insertions(+), 56 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 387f5e8b3..b6db9b04c 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -346,39 +346,29 @@ def extract_compliant( return other._compliant_series return other + def arg_aggregates(arg: IntoExpr | Any) -> bool: from narwhals.expr import Expr from narwhals.series import Series + if isinstance(arg, Expr): - return arg._metadata['aggregates'] + return arg._metadata["aggregates"] if isinstance(arg, Series): return arg.len() == 1 - if isinstance(arg, str): + if isinstance(arg, str): # noqa: SIM103 # Column name, e.g. 'a', gets treated as `nw.col('a')`, # which doesn't aggregate. return False # Scalar return True -def arg_changes_length(arg: IntoExpr | Any) -> bool: - from narwhals.expr import Expr - from narwhals.series import Series - if isinstance(arg, Expr): - return arg._metadata['changes_length'] - if isinstance(arg, Series): - return True # safest assumption - if isinstance(arg, str): - # Column name, e.g. 'a', gets treated as `nw.col('a')`, - # which doesn't change length. - return False - # Scalar - return False def arg_is_order_dependent(arg: IntoExpr | Any) -> bool: from narwhals.expr import Expr from narwhals.series import Series + if isinstance(arg, Expr): - return arg._metadata['is_order_dependent'] + return arg._metadata["is_order_dependent"] if isinstance(arg, Series): return True # safest assumption if isinstance(arg, str): @@ -389,11 +379,17 @@ def arg_is_order_dependent(arg: IntoExpr | Any) -> bool: return False - def operation_is_order_dependent(*args: IntoExpr | Any) -> bool: + # If any arg is order-dependent, the whole expression is. return any(arg_is_order_dependent(x) for x in args) +def operation_aggregates(*args: IntoExpr | Any) -> bool: + # If there's a mix of aggregates and non-aggregates, broadcasting + # will happen. The whole operation aggregates if all arguments aggregate. + return all(arg_aggregates(x) for x in args) + + def operation_changes_length(*args: IntoExpr | Any) -> bool: """Track whether operation changes length. @@ -412,12 +408,23 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: scalar, the output changes length """ from narwhals.expr import Expr + from narwhals.series import Series - n_change_length = len([arg_changes_length(x) for x in args]) - n_exprs = len([x for x in args if isinstance(x, (Expr, str))]) - changes_length = any( - isinstance(x, Expr) and x._metadata["changes_length"] for x in args - ) + n_exprs = 0 + changes_length = False + for arg in args: + if isinstance(arg, Expr): + n_exprs += 1 + if arg._metadata["changes_length"]: + changes_length = True + elif isinstance(arg, Series): + n_exprs += 1 + # Safest assumption, although Series are an eager-only + # concept anyway and so the length-changing restrictions + # don't apply to them anyway. + changes_length = True + elif isinstance(arg, str): + n_exprs += 1 if n_exprs > 1 and changes_length: msg = ( "Found multiple expressions at least one of which changes length.\n" @@ -428,37 +435,28 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: return changes_length -def operation_aggregates(*args: IntoExpr | Any) -> bool: - # If an arg is an Expr, we look at `_aggregates`. If it isn't, - # it means that it was a scalar (e.g. nw.col('a').sum() + 1), - # which is already length-1, so we default to `True`. If any - # expression does not aggregate, then broadcasting will take - # place and the result will not be an aggregate. - from narwhals.expr import Expr - - return all(x._metadata["aggregates"] for x in args if isinstance(x, Expr)) - - def operation_is_multi_output(*args: IntoExpr | Any) -> bool: # Only the first expression is allowed to produce multiple outputs. - # oh shoot - do we need to track the number of outputs? from narwhals.expr import Expr - n_multi_output = len([x for x in args if isinstance(x, Expr) and x._metadata["is_multi_output"]]) - if n_multi_output > 1: + if any(isinstance(x, Expr) and x._metadata["is_multi_output"] for x in args[1:]): msg = ( "Multi-output expressions cannot appear in the right-hand-side of\n" "any operation. For example, `nw.col('a', 'b') + nw.col('c')` is \n" "allowed, but not `nw.col('a') + nw.col('b', 'c')`." ) raise MultiOutputExprError(msg) - return n_multi_output > 0 + return isinstance(args[0], Expr) and args[0]._metadata["is_multi_output"] -def combine_metadata(*args: IntoExpr | Any) -> ExprMetadata: +def combine_metadata( + *args: IntoExpr | Any, is_multi_output: bool | None = None +) -> ExprMetadata: return ExprMetadata( is_order_dependent=operation_is_order_dependent(*args), changes_length=operation_changes_length(*args), aggregates=operation_aggregates(*args), - is_multi_output=operation_is_multi_output(*args), + is_multi_output=is_multi_output + if is_multi_output is not None + else operation_is_multi_output(*args), ) diff --git a/narwhals/functions.py b/narwhals/functions.py index feb60bb3d..9e4f28064 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -1990,7 +1990,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.sum_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2060,7 +2060,7 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.min_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2130,7 +2130,7 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.max_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2319,7 +2319,7 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.all_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2475,7 +2475,7 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.any_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2545,7 +2545,7 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.mean_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - ExprMetadata({**combine_metadata(*flat_exprs), "is_multi_output": False}), + combine_metadata(*exprs, is_multi_output=False), ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index b782b296a..51b5936b8 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -1,15 +1,15 @@ from __future__ import annotations -from narwhals._expression_parsing import extract_compliant from typing import Any from narwhals._expression_parsing import ExprMetadata +from narwhals._expression_parsing import extract_compliant from narwhals.expr import Expr from narwhals.utils import flatten class Selector(Expr): - def __or__(self, other): + def __or__(self, other: Selector | Any) -> Selector | Any: return Selector( lambda plx: self._to_compliant_expr(plx) | extract_compliant(plx, other), ExprMetadata( @@ -19,7 +19,8 @@ def __or__(self, other): is_multi_output=True, ), ) - def __and__(self, other): + + def __and__(self, other: Selector | Any) -> Selector | Any: return Selector( lambda plx: self._to_compliant_expr(plx) & extract_compliant(plx, other), ExprMetadata( @@ -29,7 +30,8 @@ def __and__(self, other): is_multi_output=True, ), ) - def __sub__(self, other): + + def __sub__(self, other: Selector | Any) -> Selector | Any: return Selector( lambda plx: self._to_compliant_expr(plx) - extract_compliant(plx, other), ExprMetadata( @@ -41,9 +43,6 @@ def __sub__(self, other): ) - - - def by_dtype(*dtypes: Any) -> Expr: """Select columns based on their dtype. diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index ed1365088..b952bb0cd 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -2014,7 +2014,7 @@ def then(self: Self, value: Any) -> Then: class Then(NwThen, Expr): @classmethod def from_then(cls: type, then: NwThen) -> Then: - return cls(then._to_compliant_expr, then._metadata) + return cls(then._to_compliant_expr, then._metadata) # type: ignore[no-any-return] def otherwise(self: Self, value: Any) -> Expr: return _stableify(super().otherwise(value)) diff --git a/tests/expr_and_series/double_selected_test.py b/tests/expr_and_series/double_selected_test.py index a99c90163..862c1e0d3 100644 --- a/tests/expr_and_series/double_selected_test.py +++ b/tests/expr_and_series/double_selected_test.py @@ -1,6 +1,9 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw +from narwhals.exceptions import MultiOutputExprError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -17,6 +20,5 @@ def test_double_selected(constructor: Constructor) -> None: expected = {"z": [7, 8, 9], "a": [2, 6, 4], "b": [8, 8, 12]} assert_equal_data(result, expected) - result = df.select("a").select(nw.col("a") + nw.all()) - expected = {"a": [2, 6, 4]} - assert_equal_data(result, expected) + with pytest.raises(MultiOutputExprError): + df.select("a").select(nw.col("a") + nw.all()) From 755cc4d644d4f2324e120a985d088495456a19dc Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:53:53 +0000 Subject: [PATCH 7/7] simplify --- narwhals/_arrow/selectors.py | 8 -------- narwhals/_dask/selectors.py | 8 -------- narwhals/_expression_parsing.py | 16 +++++----------- narwhals/_pandas_like/selectors.py | 10 ---------- narwhals/selectors.py | 23 +++++++++++++---------- 5 files changed, 18 insertions(+), 47 deletions(-) diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 36feb5d56..84ab1120b 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -169,11 +169,3 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: ) else: return self._to_expr() & other - - def __invert__(self: Self) -> ArrowSelector: - return ( - ArrowSelectorNamespace( - backend_version=self._backend_version, version=self._version - ).all() - - self - ) diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 9e6cc6302..e084f8dbe 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -177,11 +177,3 @@ def call(df: DaskLazyFrame) -> list[Any]: ) else: return self._to_expr() & other - - def __invert__(self: Self) -> DaskSelector: - return ( - DaskSelectorNamespace( - backend_version=self._backend_version, version=self._version - ).all() - - self - ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index b6db9b04c..e8c2387a5 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -365,17 +365,16 @@ def arg_aggregates(arg: IntoExpr | Any) -> bool: def arg_is_order_dependent(arg: IntoExpr | Any) -> bool: from narwhals.expr import Expr - from narwhals.series import Series if isinstance(arg, Expr): return arg._metadata["is_order_dependent"] - if isinstance(arg, Series): - return True # safest assumption if isinstance(arg, str): # Column name, e.g. 'a', gets treated as `nw.col('a')`, # which doesn't change length. return False - # Scalar + # Scalar or Series + # Series are an eager-only concept anyway and so the order-dependent + # restrictions don't apply to them anyway. return False @@ -408,7 +407,6 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: scalar, the output changes length """ from narwhals.expr import Expr - from narwhals.series import Series n_exprs = 0 changes_length = False @@ -417,14 +415,10 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: n_exprs += 1 if arg._metadata["changes_length"]: changes_length = True - elif isinstance(arg, Series): - n_exprs += 1 - # Safest assumption, although Series are an eager-only - # concept anyway and so the length-changing restrictions - # don't apply to them anyway. - changes_length = True elif isinstance(arg, str): n_exprs += 1 + # Note: Series are an eager-only concept anyway and so the length-changing + # restrictions don't apply to them anyway. if n_exprs > 1 and changes_length: msg = ( "Found multiple expressions at least one of which changes length.\n" diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index b3518283f..4238a6474 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -178,13 +178,3 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) else: return self._to_expr() & other - - def __invert__(self: Self) -> PandasSelector: - return ( - PandasSelectorNamespace( - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ).all() - - self - ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 9a66bb4c6..3524bfd43 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -32,9 +32,9 @@ def __rand__(self: Self, other: Any) -> NoReturn: def __ror__(self: Self, other: Any) -> NoReturn: raise NotImplementedError - def __or__(self, other: Selector | Any) -> Selector: + def __and__(self, other: Selector | Any) -> Selector: return Selector( - lambda plx: self._to_compliant_expr(plx) | extract_compliant(plx, other), + lambda plx: self._to_compliant_expr(plx) & extract_compliant(plx, other), ExprMetadata( is_order_dependent=False, changes_length=False, @@ -43,9 +43,9 @@ def __or__(self, other: Selector | Any) -> Selector: ), ) - def __and__(self, other: Selector | Any) -> Selector: + def __or__(self, other: Selector | Any) -> Selector: return Selector( - lambda plx: self._to_compliant_expr(plx) & extract_compliant(plx, other), + lambda plx: self._to_compliant_expr(plx) | extract_compliant(plx, other), ExprMetadata( is_order_dependent=False, changes_length=False, @@ -65,8 +65,11 @@ def __sub__(self, other: Selector | Any) -> Selector: ), ) + def __invert__(self: Self) -> Selector: + return all() - self + -def by_dtype(*dtypes: Any) -> Expr: +def by_dtype(*dtypes: Any) -> Selector: """Select columns based on their dtype. Arguments: @@ -120,7 +123,7 @@ def by_dtype(*dtypes: Any) -> Expr: ) -def numeric() -> Expr: +def numeric() -> Selector: """Select numeric columns. Returns: @@ -171,7 +174,7 @@ def numeric() -> Expr: ) -def boolean() -> Expr: +def boolean() -> Selector: """Select boolean columns. Returns: @@ -222,7 +225,7 @@ def boolean() -> Expr: ) -def string() -> Expr: +def string() -> Selector: """Select string columns. Returns: @@ -273,7 +276,7 @@ def string() -> Expr: ) -def categorical() -> Expr: +def categorical() -> Selector: """Select categorical columns. Returns: @@ -324,7 +327,7 @@ def categorical() -> Expr: ) -def all() -> Expr: +def all() -> Selector: """Select all columns. Returns: