Raise InvalidOperationError up front in GroupBy.agg / LazyGroupBy.agg instead of
relying on backend-specific "failed to aggregate" errors.

The guard checks positional and named aggregations in a single all(...):
`not all(A) and all(B)` parses as `(not all(A)) and all(B)`, which let a
non-aggregating named expression such as `agg(c=nw.col('b'))` through.

diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py
index 7cf07d13c7..f9ab23f3b8 100644
--- a/narwhals/_duckdb/group_by.py
+++ b/narwhals/_duckdb/group_by.py
@@ -46,12 +46,8 @@ def agg(
             *self._keys,
             *(x for expr in exprs for x in expr(self._compliant_frame)),
         ]
-        try:
-            return self._compliant_frame._from_native_frame(
-                self._compliant_frame._native_frame.aggregate(
-                    agg_columns, group_expr=",".join(f'"{key}"' for key in self._keys)
-                )
+        return self._compliant_frame._from_native_frame(
+            self._compliant_frame._native_frame.aggregate(
+                agg_columns, group_expr=",".join(f'"{key}"' for key in self._keys)
             )
-        except ValueError as exc:  # pragma: no cover
-            msg = "Failed to aggregated - does your aggregation function return a scalar?"
-            raise RuntimeError(msg) from exc
+        )
diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py
index 51efb1d821..c1f2015310 100644
--- a/narwhals/_pandas_like/group_by.py
+++ b/narwhals/_pandas_like/group_by.py
@@ -344,11 +344,6 @@ def func(df: Any) -> Any:
         out_names = []
         for expr in exprs:
             results_keys = expr(from_dataframe(df))
-            if not all(len(x) == 1 for x in results_keys):
-                msg = f"Aggregation '{expr._function_name}' failed to aggregate - does your aggregation function return a scalar? \
-                \n\n Please see: https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/"
-
-                raise ValueError(msg)
             for result_keys in results_keys:
                 out_group.append(result_keys._native_series.iloc[0])
                 out_names.append(result_keys.name)
diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py
index 83e7fadf81..4d1c988217 100644
--- a/narwhals/_spark_like/group_by.py
+++ b/narwhals/_spark_like/group_by.py
@@ -162,10 +162,5 @@ def agg_pyspark(
         )
 
     agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
-    try:
-        result_simple = grouped.agg(*agg_columns)
-    except ValueError as exc:  # pragma: no cover
-        msg = "Failed to aggregated - does your aggregation function return a scalar? \
-        \n\n Please see: https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/"
-        raise RuntimeError(msg) from exc
+    result_simple = grouped.agg(*agg_columns)
     return from_dataframe(result_simple)
diff --git a/narwhals/group_by.py b/narwhals/group_by.py
index 76c04fa1f5..83d23b3e99 100644
--- a/narwhals/group_by.py
+++ b/narwhals/group_by.py
@@ -10,6 +10,7 @@
 
 from narwhals.dataframe import DataFrame
 from narwhals.dataframe import LazyFrame
+from narwhals.exceptions import InvalidOperationError
 from narwhals.utils import tupleify
 
 if TYPE_CHECKING:
@@ -109,6 +110,16 @@ def agg(
             │ c   ┆ 3   ┆ 1   │
             └─────┴─────┴─────┘
         """
+        if not all(
+            getattr(x, "_aggregates", True) for x in (*aggs, *named_aggs.values())
+        ):
+            msg = (
+                "Found expression which does not aggregate.\n\n"
+                "All expressions passed to GroupBy.agg must aggregate.\n"
+                "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
+                "but `df.group_by('a').agg(nw.col('b'))` is not."
+            )
+            raise InvalidOperationError(msg)
         aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs)
         return self._df._from_compliant_dataframe(  # type: ignore[return-value]
             self._grouped.agg(*aggs, **named_aggs),
@@ -195,6 +206,16 @@ def agg(
             │ c   ┆ 3   ┆ 1   │
             └─────┴─────┴─────┘
         """
+        if not all(
+            getattr(x, "_aggregates", True) for x in (*aggs, *named_aggs.values())
+        ):
+            msg = (
+                "Found expression which does not aggregate.\n\n"
+                "All expressions passed to GroupBy.agg must aggregate.\n"
+                "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
+                "but `df.group_by('a').agg(nw.col('b'))` is not."
+            )
+            raise InvalidOperationError(msg)
         aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs)
         return self._df._from_compliant_dataframe(  # type: ignore[return-value]
             self._grouped.agg(*aggs, **named_aggs),
diff --git a/tests/group_by_test.py b/tests/group_by_test.py
index fe797dab7b..3ba35fabc7 100644
--- a/tests/group_by_test.py
+++ b/tests/group_by_test.py
@@ -9,6 +9,7 @@
 
 import narwhals.stable.v1 as nw
 from narwhals.exceptions import AnonymousExprError
+from narwhals.exceptions import InvalidOperationError
 from tests.utils import PANDAS_VERSION
 from tests.utils import PYARROW_VERSION
 from tests.utils import Constructor
@@ -45,7 +46,7 @@ def test_invalid_group_by_dask() -> None:
     with pytest.raises(ValueError, match=r"Non-trivial complex aggregation found"):
         nw.from_native(df_dask).group_by("a").agg(nw.col("b").mean().min())
 
-    with pytest.raises(ValueError, match="Non-trivial complex aggregation"):
+    with pytest.raises(InvalidOperationError, match="does not aggregate"):
         nw.from_native(df_dask).group_by("a").agg(nw.col("b"))
 
     with pytest.raises(
@@ -58,7 +59,7 @@ def test_invalid_group_by_dask() -> None:
 @pytest.mark.filterwarnings("ignore:Found complex group-by expression:UserWarning")
 def test_invalid_group_by() -> None:
     df = nw.from_native(df_pandas)
-    with pytest.raises(ValueError, match="does your"):
+    with pytest.raises(InvalidOperationError, match="does not aggregate"):
         df.group_by("a").agg(nw.col("b"))
     with pytest.raises(
         AnonymousExprError,
@@ -366,25 +367,10 @@ def test_group_by_categorical(
     assert_equal_data(result, data)
 
 
-@pytest.mark.filterwarnings("ignore:Found complex group-by expression:UserWarning")
-def test_group_by_shift_raises(
-    constructor: Constructor, request: pytest.FixtureRequest
-) -> None:
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-    if "polars" in str(constructor):
-        # Polars supports all kinds of crazy group-by aggregations, so
-        # we don't check that it errors here.
-        request.applymarker(pytest.mark.xfail)
-    if "cudf" in str(constructor):
-        # This operation fails completely in cuDF anyway, we just let raise its own
-        # error.
-        request.applymarker(pytest.mark.xfail)
+def test_group_by_shift_raises(constructor: Constructor) -> None:
     df_native = {"a": [1, 2, 3], "b": [1, 1, 2]}
     df = nw.from_native(constructor(df_native))
-    with pytest.raises(
-        ValueError, match=".*(failed to aggregate|Non-trivial complex aggregation found)"
-    ):
+    with pytest.raises(InvalidOperationError, match="does not aggregate"):
         df.group_by("b").agg(nw.col("a").shift(1))
 
 