SNOW-1844466: Support more aggregation functions in pivot methods. (#2915)

Add support for aggregations ``"count"``, ``"median"``, ``np.median``,
``"skew"``, ``"std"``, ``np.std``, ``"var"``, and ``np.var`` in
``pd.pivot_table()``, ``DataFrame.pivot_table()``, and
``pd.crosstab()``.

Snowflake PIVOT now supports all those aggregations.

This commit also expands pivot and crosstab tests to include some
aggregation functions we do not yet support due to Snowflake's PIVOT
limitations.
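
For readers unfamiliar with what these aggregations compute in a pivot, here is a minimal plain-Python sketch (not Snowpark pandas code) of pivoting with ``"median"`` and ``"std"``. The sample rows and the ``pivot`` and ``sample_std`` helpers are hypothetical; in pandas the corresponding call shape would be ``df.pivot_table(index=..., columns=..., values=..., aggfunc="median")``.

```python
import math
from collections import defaultdict
from statistics import median, stdev

# Hypothetical sample rows: (store, quarter, sales).
rows = [
    ("north", "Q1", 10.0),
    ("north", "Q1", 14.0),
    ("north", "Q2", 7.0),
    ("south", "Q1", 3.0),
    ("south", "Q2", 5.0),
    ("south", "Q2", 9.0),
    ("south", "Q2", 13.0),
]


def sample_std(values):
    """Sample standard deviation (ddof=1); NaN when fewer than two values."""
    return stdev(values) if len(values) > 1 else math.nan


def pivot(rows, aggfunc):
    """Group values by (index, column) and reduce each group with aggfunc."""
    groups = defaultdict(list)
    for index_key, column_key, value in rows:
        groups[(index_key, column_key)].append(value)
    table = defaultdict(dict)
    for (index_key, column_key), values in groups.items():
        table[index_key][column_key] = aggfunc(values)
    return {index_key: dict(cells) for index_key, cells in table.items()}


medians = pivot(rows, median)
stds = pivot(rows, sample_std)
print(medians)
print(stds)
```

Cells backed by a single row have no sample standard deviation, which is why the sketch returns NaN there, matching the default pandas ``std`` behavior.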

Fixes SNOW-1844466

---------

Signed-off-by: sfc-gh-mvashishtha <[email protected]>
sfc-gh-mvashishtha authored Jan 28, 2025
1 parent fe13793 commit f61e698
Showing 12 changed files with 228 additions and 123 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -112,6 +112,9 @@
- Added support for `DataFrame.pop` and `Series.pop`.
- Added support for `first` and `last` in `DataFrameGroupBy.agg` and `SeriesGroupBy.agg`.
- Added support for `Index.drop_duplicates`.
+- Added support for aggregations `"count"`, `"median"`, `np.median`,
+`"skew"`, `"std"`, `np.std`, `"var"`, and `np.var` in
+`pd.pivot_table()`, `DataFrame.pivot_table()`, and `pd.crosstab()`.

#### Bug Fixes

100 changes: 46 additions & 54 deletions docs/source/modin/supported/agg_supp.rst

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions docs/source/modin/supported/dataframe_supported.rst
@@ -305,9 +305,11 @@ Methods
| | | | any ``aggfunc`` is not "count", "mean", "min", |
| | | | "max", or "sum". N if ``index`` is None, |
| | | | ``margins`` is True and ``aggfunc`` is "count" |
-| | | | or "mean" or a dictionary. N if ``index`` is None |
-| | | | and ``aggfunc`` is a dictionary containing |
-| | | | lists of aggfuncs to apply. |
+| | | | or "mean" or a dictionary. ``N`` if ``index`` is |
+| | | | None and ``aggfunc`` is a dictionary containing |
+| | | | lists of aggfuncs to apply. ``N`` if ``aggfunc`` is|
+| | | | an `unsupported aggregation |
+| | | | function <agg_supp.html>`_ for pivot. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``plot`` | D | | Performed locally on the client |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
8 changes: 4 additions & 4 deletions docs/source/modin/supported/general_supported.rst
@@ -18,8 +18,8 @@ Data manipulations
| ``concat`` | P | ``levels`` is not supported, | |
| | | ``copy`` is ignored | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``crosstab`` | P | | ``N`` if ``aggfunc`` is not one of |
-| | | | "count", "mean", "min", "max", or "sum", or |
+| ``crosstab`` | P | | ``N`` if ``aggfunc`` is not a `supported |
+| | | | aggregation function <agg_supp.html>`_, |
| | | | margins is True, normalize is "all" or True, |
| | | | and values is passed. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
@@ -50,8 +50,8 @@ Data manipulations
| ``pivot`` | P | | See ``pivot_table`` |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``pivot_table`` | P | ``observed``, ``margins``, | ``N`` if ``index``, ``columns``, or ``values`` is |
-| | | ``sort`` | not str; or MultiIndex; or any ``argfunc`` is not |
-| | | | "count", "mean", "min", "max", or "sum" |
+| | | ``sort`` | not str; or MultiIndex; or any ``aggfunc`` is not a|
+| | | | `supported aggregation function <agg_supp.html>`_ |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``qcut`` | P | | ``N`` if ``labels!=False`` or ``retbins=True``. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
38 changes: 35 additions & 3 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
@@ -265,6 +270,11 @@ class _SnowparkPandasAggregation(NamedTuple):
# sum would be True.
preserves_snowpark_pandas_types: bool

+# Whether Snowflake PIVOT supports this aggregation on axis 0. It seems
+# that Snowflake PIVOT supports any aggregation expressed as a single
+# function call applied to a single column, e.g. MAX(A), BOOLOR_AGG(A).
+supported_in_pivot: bool
+
# This callable takes a single Snowpark column as input and aggregates the
# column on axis=0. If None, Snowpark pandas does not support this
# aggregation on axis=0.
@@ -305,6 +310,12 @@ class SnowflakeAggFunc(NamedTuple):
# sum would be True.
preserves_snowpark_pandas_types: bool

+# Whether Snowflake PIVOT supports this aggregation on axis 0. It seems
+# that Snowflake PIVOT supports any aggregation expressed as a single
+# function call applied to a single column, e.g. MAX(A), BOOLOR_AGG(A).
+# This field only makes sense for axis 0 aggregation.
+supported_in_pivot: bool
+

class AggFuncWithLabel(NamedTuple):
"""
Expand Down Expand Up @@ -523,6 +534,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
axis_0_aggregation=count,
axis_1_aggregation_skipna=_columns_count,
preserves_snowpark_pandas_types=False,
supported_in_pivot=True,
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
(len, "size"),
@@ -532,47 +544,53 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
axis_1_aggregation_keepna=_columns_count_keep_nulls,
axis_1_aggregation_skipna=_columns_count_keep_nulls,
preserves_snowpark_pandas_types=False,
+supported_in_pivot=False,
),
),
"first": _SnowparkPandasAggregation(
axis_0_aggregation=_column_first_value,
axis_1_aggregation_keepna=lambda *cols: cols[0],
axis_1_aggregation_skipna=lambda *cols: coalesce(*cols),
preserves_snowpark_pandas_types=True,
+supported_in_pivot=False,
),
"last": _SnowparkPandasAggregation(
axis_0_aggregation=_column_last_value,
axis_1_aggregation_keepna=lambda *cols: cols[-1],
axis_1_aggregation_skipna=lambda *cols: coalesce(*(cols[::-1])),
preserves_snowpark_pandas_types=True,
+supported_in_pivot=False,
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
("mean", np.mean),
_SnowparkPandasAggregation(
axis_0_aggregation=mean,
preserves_snowpark_pandas_types=True,
+supported_in_pivot=True,
),
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
-("min", np.min),
+("min", np.min, min),
_SnowparkPandasAggregation(
axis_0_aggregation=min_,
axis_1_aggregation_keepna=least,
axis_1_aggregation_skipna=_columns_coalescing_min,
preserves_snowpark_pandas_types=True,
+supported_in_pivot=True,
),
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
-("max", np.max),
+("max", np.max, max),
_SnowparkPandasAggregation(
axis_0_aggregation=max_,
axis_1_aggregation_keepna=greatest,
axis_1_aggregation_skipna=_columns_coalescing_max,
preserves_snowpark_pandas_types=True,
+supported_in_pivot=True,
),
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
-("sum", np.sum),
+("sum", np.sum, sum),
_SnowparkPandasAggregation(
axis_0_aggregation=sum_,
# IMPORTANT: count and sum use python builtin sum to invoke
@@ -581,13 +599,15 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
axis_1_aggregation_keepna=lambda *cols: sum(cols),
axis_1_aggregation_skipna=_columns_coalescing_sum,
preserves_snowpark_pandas_types=True,
+supported_in_pivot=True,
),
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
("median", np.median),
_SnowparkPandasAggregation(
axis_0_aggregation=median,
preserves_snowpark_pandas_types=True,
+supported_in_pivot=True,
),
),
"idxmax": _SnowparkPandasAggregation(
@@ -597,6 +617,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
preserves_snowpark_pandas_types=False,
+supported_in_pivot=False,
),
"idxmin": _SnowparkPandasAggregation(
axis_0_aggregation=functools.partial(
@@ -605,30 +626,35 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
preserves_snowpark_pandas_types=False,
+supported_in_pivot=False,
),
"skew": _SnowparkPandasAggregation(
axis_0_aggregation=skew,
preserves_snowpark_pandas_types=True,
+supported_in_pivot=True,
),
"all": _SnowparkPandasAggregation(
# all() for a column with no non-null values is NULL in Snowflake, but True in pandas.
axis_0_aggregation=lambda c: coalesce(
builtin("booland_agg")(col(c)), pandas_lit(True)
),
preserves_snowpark_pandas_types=False,
+supported_in_pivot=False,
),
"any": _SnowparkPandasAggregation(
# any() for a column with no non-null values is NULL in Snowflake, but False in pandas.
axis_0_aggregation=lambda c: coalesce(
builtin("boolor_agg")(col(c)), pandas_lit(False)
),
preserves_snowpark_pandas_types=False,
+supported_in_pivot=False,
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
("std", np.std),
_SnowparkPandasAggregation(
axis_0_aggregation=stddev,
preserves_snowpark_pandas_types=True,
+supported_in_pivot=True,
),
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
@@ -638,19 +664,23 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
# variance units are the square of the input column units, so
# variance does not preserve types.
preserves_snowpark_pandas_types=False,
+supported_in_pivot=True,
),
),
"array_agg": _SnowparkPandasAggregation(
axis_0_aggregation=array_agg,
preserves_snowpark_pandas_types=False,
+supported_in_pivot=False,
),
"quantile": _SnowparkPandasAggregation(
axis_0_aggregation=column_quantile,
preserves_snowpark_pandas_types=True,
+supported_in_pivot=False,
),
"nunique": _SnowparkPandasAggregation(
axis_0_aggregation=count_distinct,
preserves_snowpark_pandas_types=False,
+supported_in_pivot=False,
),
}
)
@@ -762,6 +792,7 @@ def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn:
return SnowflakeAggFunc(
snowpark_aggregation=snowpark_aggregation,
preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
+supported_in_pivot=snowpark_pandas_aggregation.supported_in_pivot,
)


@@ -800,6 +831,7 @@ def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn:
return SnowflakeAggFunc(
snowpark_aggregation,
preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
+supported_in_pivot=snowpark_pandas_aggregation.supported_in_pivot,
)
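
The aggregation map in this file is keyed by every spelling a caller might pass, so the string `"min"`, `np.min`, and (after this change) the builtin `min` all resolve to the same record. The sketch below illustrates that fan-out under the assumption that `_create_pandas_to_snowpark_pandas_aggregation_map` simply maps each key to the shared record. `create_aggregation_map`, `AGGREGATION_MAP`, and `lookup` are hypothetical names, `statistics.median` stands in for `np.median`, and string labels stand in for the real `_SnowparkPandasAggregation` records.

```python
import statistics
from typing import Any, Dict, Hashable, Iterable, Optional


def create_aggregation_map(keys: Iterable[Hashable], aggregation: Any) -> Dict[Hashable, Any]:
    """Map every key (a string name or a callable) to the same aggregation record."""
    return {key: aggregation for key in keys}


# Hypothetical registry: functions are hashable, so they can be dict keys
# right next to their string names.
AGGREGATION_MAP: Dict[Hashable, Any] = {
    **create_aggregation_map(("median", statistics.median), "MEDIAN"),
    **create_aggregation_map(("min", min), "MIN"),
    **create_aggregation_map(("max", max), "MAX"),
    **create_aggregation_map(("sum", sum), "SUM"),
}


def lookup(aggfunc: Hashable) -> Optional[Any]:
    """Return the shared record for a name or callable, or None if unknown."""
    return AGGREGATION_MAP.get(aggfunc)


print(lookup("median") == lookup(statistics.median))
print(lookup(len))
```

Because callables resolve through the same table as names, a caller-supplied function needs no special-case guard: an unknown callable simply yields no record.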


5 changes: 1 addition & 4 deletions src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py
@@ -45,7 +45,6 @@
extract_pandas_label_from_snowflake_quoted_identifier,
from_pandas_label,
get_distinct_rows,
-is_supported_snowflake_pivot_agg_func,
pandas_lit,
random_name_for_temp_object,
to_pandas_label,
@@ -522,9 +521,7 @@ def single_pivot_helper(
data_column_pandas_labels: new data column pandas labels for this pivot result
"""
snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0)
-if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func(
-snowflake_agg_func.snowpark_aggregation
-):
+if snowflake_agg_func is None or not snowflake_agg_func.supported_in_pivot:
# TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations
raise ErrorMessage.not_implemented(
f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments."
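
The check in `single_pivot_helper` now reads a `supported_in_pivot` flag carried by the aggregation record instead of looking the function up in a separate allow-list. A minimal standalone sketch of that pattern follows. `AggregationSketch`, `AGGREGATIONS`, and `resolve_pivot_aggregation` are hypothetical names, and plain Python callables stand in for the Snowpark column functions used by the real code.

```python
import statistics
from typing import Any, Callable, Dict, List, NamedTuple


class AggregationSketch(NamedTuple):
    """Hypothetical stand-in for the aggregation record in aggregation_utils.py."""

    axis_0_aggregation: Callable[[List[Any]], Any]
    # Whether the pivot path may use this aggregation.
    supported_in_pivot: bool


# Hypothetical registry with one supported and one unsupported entry.
AGGREGATIONS: Dict[str, AggregationSketch] = {
    "median": AggregationSketch(statistics.median, supported_in_pivot=True),
    "nunique": AggregationSketch(lambda values: len(set(values)), supported_in_pivot=False),
}


def resolve_pivot_aggregation(name: str) -> Callable[[List[Any]], Any]:
    """Return the aggregation callable, or raise if pivot cannot use it."""
    aggregation = AGGREGATIONS.get(name)
    if aggregation is None or not aggregation.supported_in_pivot:
        raise NotImplementedError(
            f"pivot_table does not yet support the aggregation {name!r}"
        )
    return aggregation.axis_0_aggregation


print(resolve_pivot_aggregation("median")([1, 2, 3]))
```

Keeping the flag on the record means that adding a new aggregation and declaring its pivot support happen in one place, rather than in two modules that must be kept in sync.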
30 changes: 1 addition & 29 deletions src/snowflake/snowpark/modin/plugin/_internal/utils.py
@@ -8,7 +8,7 @@
import traceback
from collections.abc import Hashable, Iterable, Sequence
from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

import modin.pandas as pd
import numpy as np
@@ -42,14 +42,9 @@
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.functions import (
col,
-count,
equal_nan,
floor,
iff,
-max as max_,
-mean,
-min as min_,
-sum as sum_,
to_char,
to_timestamp_ntz,
to_timestamp_tz,
@@ -1196,29 +1191,6 @@ def is_snowpark_pandas_dataframe_or_series_type(obj: Any) -> bool:
return isinstance(obj, (pd.DataFrame, pd.Series))


-# TODO: (SNOW-853334) Support other agg functions (any, all, prod, median, skew, kurt, sem, var, std, mad, etc)
-snowflake_pivot_agg_func_supported = [
-count,
-mean,
-min_,
-max_,
-sum_,
-]
-
-
-def is_supported_snowflake_pivot_agg_func(agg_func: Callable) -> bool:
-"""
-Check if the aggregation function is supported with snowflake pivot. Current supported
-aggregation functions are the functions that can be mapped to snowflake builtin function.
-Args:
-agg_func: str or Callable. the aggregation function to check
-Returns:
-Whether it is valid to implement with snowflake or not.
-"""
-return agg_func in snowflake_pivot_agg_func_supported
-
-
def convert_snowflake_string_constant_to_python_string(identifier: str) -> str:
"""
Convert a snowflake string constant to a python constant, this removes surrounding single quotes
@@ -9094,12 +9094,6 @@ def pivot_table(
if not sort:
raise NotImplementedError("Not implemented not sorted")

-# TODO: (SNOW-853334) Support callable agg functions
-if aggfunc and callable(aggfunc):
-raise NotImplementedError(
-f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(aggfunc, agg_kwargs={})} with the given arguments."
-)
-
if columns is not None and isinstance(columns, Hashable):
columns = [columns]
