Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: validate expressions shape at narwhals/expr.py level #1845

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import broadcast_and_extract_dataframe_comparand
from narwhals._arrow.utils import broadcast_series
from narwhals._arrow.utils import convert_str_slice_to_int_slice
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import select_rows
from narwhals._arrow.utils import validate_dataframe_comparand
from narwhals._expression_parsing import evaluate_into_exprs
from narwhals.dependencies import is_numpy_array
from narwhals.utils import Implementation
Expand Down Expand Up @@ -311,12 +311,8 @@ def with_columns(
for col_value in new_columns:
col_name = col_value.name

column = validate_dataframe_comparand(
length=length,
other=col_value,
backend_version=self._backend_version,
allow_broadcast=True,
method_name="with_columns",
column = broadcast_and_extract_dataframe_comparand(
length=length, other=col_value, backend_version=self._backend_version
)

native_frame = (
Expand Down Expand Up @@ -494,12 +490,8 @@ def filter(self: Self, *predicates: IntoArrowExpr, **constraints: Any) -> Self:
)
# `[0]` is safe as all_horizontal's expression only returns a single column
mask = expr._call(self)[0]
mask_native = validate_dataframe_comparand(
length=len(self),
other=mask,
backend_version=self._backend_version,
allow_broadcast=False,
method_name="filter",
mask_native = broadcast_and_extract_dataframe_comparand(
length=len(self), other=mask, backend_version=self._backend_version
)
return self._from_native_frame(self._native_frame.filter(mask_native))

Expand Down
16 changes: 1 addition & 15 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pyarrow as pa
import pyarrow.compute as pc

from narwhals.exceptions import ShapeError
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

Expand Down Expand Up @@ -207,13 +206,10 @@ def broadcast_and_extract_native(
return lhs._native_series, rhs


def validate_dataframe_comparand(
def broadcast_and_extract_dataframe_comparand(
length: int,
other: Any,
backend_version: tuple[int, ...],
*,
allow_broadcast: bool,
method_name: str,
) -> Any:
"""Validate RHS of binary operation.

Expand All @@ -225,23 +221,13 @@ def validate_dataframe_comparand(
if isinstance(other, ArrowSeries):
len_other = len(other)
if len_other == 1:
if length > 1 and not allow_broadcast:
msg = (
f"{method_name}'s length: 1 differs from that of the series: {length}"
)
raise ShapeError(msg)

import numpy as np # ignore-banned-import

value = other._native_series[0]
if backend_version < (13,) and hasattr(value, "as_py"):
value = value.as_py()
return pa.array(np.full(shape=length, fill_value=value))

if length != len_other:
msg = f"{method_name}'s length: {len_other} differs from that of the series: {length}"
raise ShapeError(msg)

return other._native_series

from narwhals._arrow.dataframe import ArrowDataFrame # pragma: no cover
Expand Down
23 changes: 7 additions & 16 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import overload

from narwhals._expression_parsing import evaluate_into_exprs
from narwhals._pandas_like.utils import broadcast_and_extract_dataframe_comparand
from narwhals._pandas_like.utils import broadcast_series
from narwhals._pandas_like.utils import convert_str_slice_to_int_slice
from narwhals._pandas_like.utils import create_compliant_series
Expand All @@ -18,7 +19,6 @@
from narwhals._pandas_like.utils import pivot_table
from narwhals._pandas_like.utils import rename
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._pandas_like.utils import validate_dataframe_comparand
from narwhals.dependencies import is_numpy_array
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
Expand Down Expand Up @@ -417,11 +417,8 @@ def filter(self: Self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> S
)
# `[0]` is safe as all_horizontal's expression only returns a single column
mask = expr._call(self)[0]
mask_native = validate_dataframe_comparand(
self._native_frame.index,
mask,
allow_broadcast=False,
method_name="filter",
mask_native = broadcast_and_extract_dataframe_comparand(
self._native_frame.index, mask
)

return self._from_native_frame(self._native_frame.loc[mask_native])
Expand All @@ -442,21 +439,15 @@ def with_columns(
for name in self._native_frame.columns:
if name in new_column_name_to_new_column_map:
to_concat.append(
validate_dataframe_comparand(
index,
new_column_name_to_new_column_map.pop(name),
allow_broadcast=True,
method_name="with_columns",
broadcast_and_extract_dataframe_comparand(
index, new_column_name_to_new_column_map.pop(name)
)
)
else:
to_concat.append(self._native_frame[name])
to_concat.extend(
validate_dataframe_comparand(
index,
new_column_name_to_new_column_map[s],
allow_broadcast=True,
method_name="with_columns",
broadcast_and_extract_dataframe_comparand(
index, new_column_name_to_new_column_map[s]
)
for s in new_column_name_to_new_column_map
)
Expand Down
19 changes: 1 addition & 18 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import TypeVar

from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import ShapeError
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass
Expand Down Expand Up @@ -147,9 +146,7 @@ def broadcast_align_and_extract_native(
return lhs._native_series, rhs


def validate_dataframe_comparand(
index: Any, other: Any, *, allow_broadcast: bool, method_name: str
) -> Any:
def broadcast_and_extract_dataframe_comparand(index: Any, other: Any) -> Any:
"""Validate RHS of binary operation.

If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -161,27 +158,13 @@ def validate_dataframe_comparand(
if isinstance(other, PandasLikeDataFrame):
return NotImplemented
if isinstance(other, PandasLikeSeries):
len_index = len(index)
len_other = other.len()

if len_other == 1:
if len_index > 1 and not allow_broadcast:
msg = (
f"{method_name}'s length: 1 differs from that of the series: "
f"{len_index}"
)
raise ShapeError(msg)
# broadcast
s = other._native_series
return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name)

if len_index != len_other:
msg = (
f"{method_name}'s length: {len_other} differs from that of the series: "
f"{len_index}"
)
raise ShapeError(msg)

if other._native_series.index is not index:
return set_index(
other._native_series,
Expand Down
12 changes: 9 additions & 3 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import LengthChangingExprError
from narwhals.exceptions import OrderDependentExprError
from narwhals.exceptions import ShapeError
from narwhals.schema import Schema
from narwhals.translate import to_native
from narwhals.utils import find_stacklevel
Expand Down Expand Up @@ -139,14 +140,19 @@ def drop(self, *columns: Iterable[str], strict: bool) -> Self:
def filter(
self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
) -> Self:
flat_predicates = flatten(predicates)
if any(
getattr(x, "_aggregates", False) or getattr(x, "_changes_length", False)
for x in flat_predicates
):
msg = "Expressions which aggregate or change length cannot be passed to `filter`."
raise ShapeError(msg)
if not (
len(predicates) == 1
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
):
predicates, constraints = self._flatten_and_extract(
*predicates, **constraints
)
predicates = [self._extract_compliant(v) for v in flat_predicates] # type: ignore[assignment]
return self._from_compliant_dataframe(
self._compliant_frame.filter(*predicates, **constraints),
)
Expand Down
7 changes: 7 additions & 0 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import ShapeError
from narwhals.expr import Expr
from narwhals.translate import from_native
from narwhals.utils import Implementation
Expand Down Expand Up @@ -2101,6 +2102,12 @@ def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None:
if not self._predicates:
msg = "At least one predicate needs to be provided to `narwhals.when`."
raise TypeError(msg)
if any(
getattr(x, "_aggregates", False) or getattr(x, "_changes_length", False)
for x in self._predicates
):
msg = "Expressions which aggregate or change length cannot be passed to `filter`."
raise ShapeError(msg)

def _extract_predicates(self, plx: Any) -> Any:
return [extract_compliant(plx, v) for v in self._predicates]
Expand Down
7 changes: 7 additions & 0 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import narwhals.stable.v1 as nw
from narwhals.exceptions import ShapeError
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data
Expand Down Expand Up @@ -133,6 +134,12 @@ def test_when_then_otherwise_into_expr(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_when_then_invalid(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
with pytest.raises(ShapeError):
df.select(nw.when(nw.col("a").sum() > 1).then("c"))


def test_when_then_otherwise_lit_str(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z")))
Expand Down
13 changes: 1 addition & 12 deletions tests/frame/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,7 @@ def test_filter_with_boolean_list(constructor: Constructor) -> None:
def test_filter_raise_on_agg_predicate(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df = nw.from_native(constructor(data))

context = (
pytest.raises(
ShapeError,
match="filter's length: 1 differs from that of the series: 3",
)
if any(x in str(constructor) for x in ("pandas", "pyarrow", "modin"))
else does_not_raise()
if "polars" in str(constructor)
else pytest.raises(Exception) # type: ignore[arg-type] # noqa: PT011
)
with context:
with pytest.raises(ShapeError):
df.filter(nw.col("a").max() > 2).lazy().collect()


Expand Down
Loading