diff --git a/CHANGELOG.md b/CHANGELOG.md index 07e7956961f..adfa193c1a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -112,6 +112,9 @@ - Added support for `DataFrame.pop` and `Series.pop`. - Added support for `first` and `last` in `DataFrameGroupBy.agg` and `SeriesGroupBy.agg`. - Added support for `Index.drop_duplicates`. +- Added support for aggregations `"count"`, `"median"`, `np.median`, + `"skew"`, `"std"`, `np.std` `"var"`, and `np.var` in + `pd.pivot_table()`, `DataFrame.pivot_table()`, and `pd.crosstab()`. #### Bug Fixes diff --git a/docs/source/modin/supported/agg_supp.rst b/docs/source/modin/supported/agg_supp.rst index 1cc58e8d4e5..a16c0be2512 100644 --- a/docs/source/modin/supported/agg_supp.rst +++ b/docs/source/modin/supported/agg_supp.rst @@ -3,13 +3,9 @@ Supported Aggregation Functions ==================================== -This page lists which aggregation functions are supported by ``DataFrame.agg``, -``Series.agg``, ``DataFrameGroupBy.agg``, and ``SeriesGroupBy.agg``. -The following table is structured as follows: The first column contains the aggregation function's name. -The second column is a flag for whether or not the aggregation is supported by ``DataFrame.agg``. The -third column is a flag for whether or not the aggregation is supported by ``Series.agg``. The fourth column -is whether or not the aggregation is supported by ``DataFrameGroupBy.agg``. The fifth column is whether or not -the aggregation is supported by ``SeriesGroupBy.agg``. +This page lists which aggregation functions are supported by ``DataFrame.agg``; +``Series.agg``; ``DataFrameGroupBy.agg``; ``SeriesGroupBy.agg``; and the pivot +methods ``pd.pivot_table``, ``DataFrame.pivot_table``, and ``pd.crosstab``. .. note:: ``Y`` stands for yes (supports distributed implementation), ``N`` stands for no (API simply errors out), @@ -17,50 +13,46 @@ the aggregation is supported by ``SeriesGroupBy.agg``. Both Python builtin and NumPy functions are supported for ``DataFrameGroupBy.agg`` and ``SeriesGroupBy.agg``. -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| Aggregation Function | ``DataFrame.agg`` supports? (Y/N/P) | ``Series.agg`` supports? (Y/N/P) | ``DataFrameGroupBy.agg`` supports? (Y/N/P) | ``SeriesGroupBy.agg`` supports? (Y/N/P) | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``count`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | -| | For ``axis=1``, ``Y`` if index is | | | | -| | not a MultiIndex. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``mean`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | -| | ``N`` for ``axis=1``. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``min`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | -| | For ``axis=1``, ``Y`` if index is | | | | -| | not a MultiIndex. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``max`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | -| | For ``axis=1``, ``Y`` if index is | | | | -| | not a MultiIndex. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``sum`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | -| | For ``axis=1``, ``Y`` if index is | | | | -| | not a MultiIndex. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``median`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | -| | ``N`` for ``axis=1``. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``size`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | -| | ``N`` for ``axis=1``. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``std`` | ``P`` for ``axis=0`` - only when | ``P`` - only when ``ddof=0`` | ``P`` - only when ``ddof=0`` | ``P`` - only when ``ddof=0`` | -| | ``ddof=0`` or ``ddof=1``. | or ``ddof=1``. | or ``ddof=1``. | or ``ddof=1``. | -| | ``N`` for ``axis=1``. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``var`` | ``P`` for ``axis=0`` - only when | ``P`` - only when ``ddof=0`` | ``P`` - only when ``ddof=0`` | ``P`` - only when ``ddof=0`` | -| | ``ddof=0`` or ``ddof=1``. | or ``ddof=1``. | or ``ddof=1``. | or ``ddof=1``. | -| | ``N`` for ``axis=1``. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``quantile`` | ``P`` for ``axis=0`` - only when | ``P`` - only when ``q`` is the | ``P`` - only when ``q`` is the | ``P`` - only when ``q`` is the | -| | ``q`` is the default value or | default value or a scalar. | default value or a scalar. | default value or a scalar. | -| | a scalar. | | | | -| | ``N`` for ``axis=1``. | | | | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``len`` | ``N`` | ``N`` | ``Y`` | ``Y`` | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``first`` | ``N`` | ``N`` | ``Y`` | ``Y`` | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ -| ``last`` | ``N`` | ``N`` | ``Y`` | ``Y`` | -+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| Aggregation Function | ``DataFrame.agg`` supports? (Y/N/P) | ``Series.agg`` supports? (Y/N/P) | ``DataFrameGroupBy.agg`` supports? (Y/N/P) | ``SeriesGroupBy.agg`` supports? (Y/N/P) | pivot methods support? (Y/N/P) | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``count`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``Y`` | +| | For ``axis=1``, ``Y`` if index is | | | | | +| | not a MultiIndex. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``mean`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``Y`` | +| | ``N`` for ``axis=1``. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``min`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``Y`` | +| | For ``axis=1``, ``Y`` if index is | | | | | +| | not a MultiIndex. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``max`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``Y`` | +| | For ``axis=1``, ``Y`` if index is | | | | | +| | not a MultiIndex. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``sum`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``Y`` | +| | For ``axis=1``, ``Y`` if index is | | | | | +| | not a MultiIndex. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``median`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``Y`` | +| | ``N`` for ``axis=1``. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``size`` | ``Y`` for ``axis=0``. | ``Y`` | ``Y`` | ``Y`` | ``N`` | +| | ``N`` for ``axis=1``. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``std`` | ``P`` for ``axis=0`` - only when | ``P`` - only when ``ddof=0`` | ``P`` - only when ``ddof=0`` | ``P`` - only when ``ddof=0`` | ``Y`` | +| | ``ddof=0`` or ``ddof=1``. | or ``ddof=1``. | or ``ddof=1``. | or ``ddof=1``. | | +| | ``N`` for ``axis=1``. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``var`` | ``P`` for ``axis=0`` - only when | ``P`` - only when ``ddof=0`` | ``P`` - only when ``ddof=0`` | ``P`` - only when ``ddof=0`` | ``Y`` | +| | ``ddof=0`` or ``ddof=1``. | or ``ddof=1``. | or ``ddof=1``. | or ``ddof=1``. | | +| | ``N`` for ``axis=1``. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``quantile`` | ``P`` for ``axis=0`` - only when | ``P`` - only when ``q`` is the | ``P`` - only when ``q`` is the | ``P`` - only when ``q`` is the | ``N`` | +| | ``q`` is the default value or | default value or a scalar. | default value or a scalar. | default value or a scalar. | | +| | a scalar. | | | | | +| | ``N`` for ``axis=1``. | | | | | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ +| ``len`` | ``Y`` | ``Y`` | ``Y`` | ``Y`` | ``N`` | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+-----------------------------------------+ \ No newline at end of file diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst index ef243857063..a78e747e173 100644 --- a/docs/source/modin/supported/dataframe_supported.rst +++ b/docs/source/modin/supported/dataframe_supported.rst @@ -305,9 +305,11 @@ Methods | | | | any ``argfunc`` is not "count", "mean", "min", | | | | | "max", or "sum". N if ``index`` is None, | | | | | ``margins`` is True and ``aggfunc`` is "count" | -| | | | or "mean" or a dictionary. N if ``index`` is None | -| | | | and ``aggfunc`` is a dictionary containing | -| | | | lists of aggfuncs to apply. | +| | | | or "mean" or a dictionary. ``N`` if ``index`` is | +| | | | None and ``aggfunc`` is a dictionary containing | +| | | | lists of aggfuncs to apply. ``N`` if ``aggfunc`` is| +| | | | an `unsupported aggregation | +| | | | function `_ for pivot. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``plot`` | D | | Performed locally on the client | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/general_supported.rst b/docs/source/modin/supported/general_supported.rst index 8c7ce5120af..babce3f53d0 100644 --- a/docs/source/modin/supported/general_supported.rst +++ b/docs/source/modin/supported/general_supported.rst @@ -18,8 +18,8 @@ Data manipulations | ``concat`` | P | ``levels`` is not supported, | | | | | ``copy`` is ignored | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``crosstab`` | P | | ``N`` if ``aggfunc`` is not one of | -| | | | "count", "mean", "min", "max", or "sum", or | +| ``crosstab`` | P | | ``N`` if ``aggfunc`` is not a `supported | +| | | | aggregation function `_, | | | | | margins is True, normalize is "all" or True, | | | | | and values is passed. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ @@ -50,8 +50,8 @@ Data manipulations | ``pivot`` | P | | See ``pivot_table`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``pivot_table`` | P | ``observed``, ``margins``, | ``N`` if ``index``, ``columns``, or ``values`` is | -| | | ``sort`` | not str; or MultiIndex; or any ``argfunc`` is not | -| | | | "count", "mean", "min", "max", or "sum" | +| | | ``sort`` | not str; or MultiIndex; or any ``aggfunc`` is not a| +| | | | `supported aggregation function `_ | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``qcut`` | P | | ``N`` if ``labels!=False`` or ``retbins=True``. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 599c0233c68..f6c35df9529 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -265,6 +265,11 @@ class _SnowparkPandasAggregation(NamedTuple): # sum would be True. preserves_snowpark_pandas_types: bool + # Whether Snowflake PIVOT supports this aggregation on axis 0. It seems + # that Snowflake PIVOT supports any aggregation expressed as as single + # function call applied to a single column, e.g. MAX(A), BOOLOR_AND(A) + supported_in_pivot: bool + # This callable takes a single Snowpark column as input and aggregates the # column on axis=0. If None, Snowpark pandas does not support this # aggregation on axis=0. @@ -305,6 +310,12 @@ class SnowflakeAggFunc(NamedTuple): # sum would be True. preserves_snowpark_pandas_types: bool + # Whether Snowflake PIVOT supports this aggregation on axis 0. It seems + # that Snowflake PIVOT supports any aggregation expressed as as single + # function call applied to a single column, e.g. MAX(A), BOOLOR_AND(A). + # This field only makes sense for axis 0 aggregation. + supported_in_pivot: bool + class AggFuncWithLabel(NamedTuple): """ @@ -523,6 +534,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( axis_0_aggregation=count, axis_1_aggregation_skipna=_columns_count, preserves_snowpark_pandas_types=False, + supported_in_pivot=True, ), **_create_pandas_to_snowpark_pandas_aggregation_map( (len, "size"), @@ -532,6 +544,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( axis_1_aggregation_keepna=_columns_count_keep_nulls, axis_1_aggregation_skipna=_columns_count_keep_nulls, preserves_snowpark_pandas_types=False, + supported_in_pivot=False, ), ), "first": _SnowparkPandasAggregation( @@ -539,40 +552,45 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( axis_1_aggregation_keepna=lambda *cols: cols[0], axis_1_aggregation_skipna=lambda *cols: coalesce(*cols), preserves_snowpark_pandas_types=True, + supported_in_pivot=False, ), "last": _SnowparkPandasAggregation( axis_0_aggregation=_column_last_value, axis_1_aggregation_keepna=lambda *cols: cols[-1], axis_1_aggregation_skipna=lambda *cols: coalesce(*(cols[::-1])), preserves_snowpark_pandas_types=True, + supported_in_pivot=False, ), **_create_pandas_to_snowpark_pandas_aggregation_map( ("mean", np.mean), _SnowparkPandasAggregation( axis_0_aggregation=mean, preserves_snowpark_pandas_types=True, + supported_in_pivot=True, ), ), **_create_pandas_to_snowpark_pandas_aggregation_map( - ("min", np.min), + ("min", np.min, min), _SnowparkPandasAggregation( axis_0_aggregation=min_, axis_1_aggregation_keepna=least, axis_1_aggregation_skipna=_columns_coalescing_min, preserves_snowpark_pandas_types=True, + supported_in_pivot=True, ), ), **_create_pandas_to_snowpark_pandas_aggregation_map( - ("max", np.max), + ("max", np.max, max), _SnowparkPandasAggregation( axis_0_aggregation=max_, axis_1_aggregation_keepna=greatest, axis_1_aggregation_skipna=_columns_coalescing_max, preserves_snowpark_pandas_types=True, + supported_in_pivot=True, ), ), **_create_pandas_to_snowpark_pandas_aggregation_map( - ("sum", np.sum), + ("sum", np.sum, sum), _SnowparkPandasAggregation( axis_0_aggregation=sum_, # IMPORTANT: count and sum use python builtin sum to invoke @@ -581,6 +599,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( axis_1_aggregation_keepna=lambda *cols: sum(cols), axis_1_aggregation_skipna=_columns_coalescing_sum, preserves_snowpark_pandas_types=True, + supported_in_pivot=True, ), ), **_create_pandas_to_snowpark_pandas_aggregation_map( @@ -588,6 +607,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( _SnowparkPandasAggregation( axis_0_aggregation=median, preserves_snowpark_pandas_types=True, + supported_in_pivot=True, ), ), "idxmax": _SnowparkPandasAggregation( @@ -597,6 +617,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, preserves_snowpark_pandas_types=False, + supported_in_pivot=False, ), "idxmin": _SnowparkPandasAggregation( axis_0_aggregation=functools.partial( @@ -605,10 +626,12 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, preserves_snowpark_pandas_types=False, + supported_in_pivot=False, ), "skew": _SnowparkPandasAggregation( axis_0_aggregation=skew, preserves_snowpark_pandas_types=True, + supported_in_pivot=True, ), "all": _SnowparkPandasAggregation( # all() for a column with no non-null values is NULL in Snowflake, but True in pandas. @@ -616,6 +639,7 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( builtin("booland_agg")(col(c)), pandas_lit(True) ), preserves_snowpark_pandas_types=False, + supported_in_pivot=False, ), "any": _SnowparkPandasAggregation( # any() for a column with no non-null values is NULL in Snowflake, but False in pandas. @@ -623,12 +647,14 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( builtin("boolor_agg")(col(c)), pandas_lit(False) ), preserves_snowpark_pandas_types=False, + supported_in_pivot=False, ), **_create_pandas_to_snowpark_pandas_aggregation_map( ("std", np.std), _SnowparkPandasAggregation( axis_0_aggregation=stddev, preserves_snowpark_pandas_types=True, + supported_in_pivot=True, ), ), **_create_pandas_to_snowpark_pandas_aggregation_map( @@ -638,19 +664,23 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( # variance units are the square of the input column units, so # variance does not preserve types. preserves_snowpark_pandas_types=False, + supported_in_pivot=True, ), ), "array_agg": _SnowparkPandasAggregation( axis_0_aggregation=array_agg, preserves_snowpark_pandas_types=False, + supported_in_pivot=False, ), "quantile": _SnowparkPandasAggregation( axis_0_aggregation=column_quantile, preserves_snowpark_pandas_types=True, + supported_in_pivot=False, ), "nunique": _SnowparkPandasAggregation( axis_0_aggregation=count_distinct, preserves_snowpark_pandas_types=False, + supported_in_pivot=False, ), } ) @@ -762,6 +792,7 @@ def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: return SnowflakeAggFunc( snowpark_aggregation=snowpark_aggregation, preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + supported_in_pivot=snowpark_pandas_aggregation.supported_in_pivot, ) @@ -800,6 +831,7 @@ def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn: return SnowflakeAggFunc( snowpark_aggregation, preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + supported_in_pivot=snowpark_pandas_aggregation.supported_in_pivot, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py index 98925eb9da0..71cf167b23c 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py @@ -45,7 +45,6 @@ extract_pandas_label_from_snowflake_quoted_identifier, from_pandas_label, get_distinct_rows, - is_supported_snowflake_pivot_agg_func, pandas_lit, random_name_for_temp_object, to_pandas_label, @@ -522,9 +521,7 @@ def single_pivot_helper( data_column_pandas_labels: new data column pandas labels for this pivot result """ snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0) - if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func( - snowflake_agg_func.snowpark_aggregation - ): + if snowflake_agg_func is None or not snowflake_agg_func.supported_in_pivot: # TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations raise ErrorMessage.not_implemented( f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments." diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py index 650e8897df4..35a4fedece4 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py @@ -8,7 +8,7 @@ import traceback from collections.abc import Hashable, Iterable, Sequence from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import modin.pandas as pd import numpy as np @@ -42,14 +42,9 @@ from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.functions import ( col, - count, equal_nan, floor, iff, - max as max_, - mean, - min as min_, - sum as sum_, to_char, to_timestamp_ntz, to_timestamp_tz, @@ -1196,29 +1191,6 @@ def is_snowpark_pandas_dataframe_or_series_type(obj: Any) -> bool: return isinstance(obj, (pd.DataFrame, pd.Series)) -# TODO: (SNOW-853334) Support other agg functions (any, all, prod, median, skew, kurt, sem, var, std, mad, etc) -snowflake_pivot_agg_func_supported = [ - count, - mean, - min_, - max_, - sum_, -] - - -def is_supported_snowflake_pivot_agg_func(agg_func: Callable) -> bool: - """ - Check if the aggregation function is supported with snowflake pivot. Current supported - aggregation functions are the functions that can be mapped to snowflake builtin function. - - Args: - agg_func: str or Callable. the aggregation function to check - Returns: - Whether it is valid to implement with snowflake or not. - """ - return agg_func in snowflake_pivot_agg_func_supported - - def convert_snowflake_string_constant_to_python_string(identifier: str) -> str: """ Convert a snowflake string constant to a python constant, this removes surrounding single quotes diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 55bd416857f..8f6bf210a0b 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -9094,12 +9094,6 @@ def pivot_table( if not sort: raise NotImplementedError("Not implemented not sorted") - # TODO: (SNOW-853334) Support callable agg functions - if aggfunc and callable(aggfunc): - raise NotImplementedError( - f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(aggfunc, agg_kwargs={})} with the given arguments." - ) - if columns is not None and isinstance(columns, Hashable): columns = [columns] diff --git a/tests/integ/modin/crosstab/test_crosstab.py b/tests/integ/modin/crosstab/test_crosstab.py index a2d2f680454..59df70001d3 100644 --- a/tests/integ/modin/crosstab/test_crosstab.py +++ b/tests/integ/modin/crosstab/test_crosstab.py @@ -13,6 +13,24 @@ from tests.integ.modin.utils import eval_snowpark_pandas_result from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker +# See SNOW-1892612 for why we have this list. +AGGFUNCS_THAT_CANNOT_PRODUCE_NAN = ( + "median", + np.median, + "count", + "mean", + np.mean, + min, + "min", + np.min, + max, + "max", + np.max, + "sum", + np.sum, + sum, +) + @pytest.mark.parametrize("dropna", [True, False]) class TestCrosstab: @@ -371,7 +389,7 @@ def test_normalize_and_margins(self, dropna, normalize, a, b, c): ) @pytest.mark.parametrize("normalize", [0, 1, "index", "columns"]) - @pytest.mark.parametrize("aggfunc", ["count", "mean", "min", "max", "sum"]) + @pytest.mark.parametrize("aggfunc", AGGFUNCS_THAT_CANNOT_PRODUCE_NAN) def test_normalize_margins_and_values(self, dropna, normalize, aggfunc, a, b, c): counts = { "columns": [3, 10 if dropna else 13, 3], @@ -398,7 +416,7 @@ def eval_func(lib): dropna=dropna, aggfunc=aggfunc, ) - if aggfunc == "sum": + if aggfunc in (sum, "sum", np.sum, sum): # When normalizing the data, we apply the normalization function to the # entire table (including margins), which requires us to multiply by 2 # (since the function takes the sum over the rows, and the margins row is @@ -419,7 +437,7 @@ def eval_func(lib): eval_func, ) - @pytest.mark.parametrize("aggfunc", ["count", "mean", "min", "max", "sum"]) + @pytest.mark.parametrize("aggfunc", AGGFUNCS_THAT_CANNOT_PRODUCE_NAN) def test_margins_and_values(self, dropna, aggfunc, a, b, c): vals = np.array([12, 10, 9, 4, 3, 49, 19, 20, 21, 34, 0]) @@ -448,7 +466,7 @@ def eval_func(lib): ) @pytest.mark.parametrize("normalize", [0, 1, True, "all", "index", "columns"]) - @pytest.mark.parametrize("aggfunc", ["count", "mean", "min", "max", "sum"]) + @pytest.mark.parametrize("aggfunc", AGGFUNCS_THAT_CANNOT_PRODUCE_NAN) def test_normalize_and_values(self, dropna, normalize, aggfunc, a, b, c): counts = { "columns": [2, 4 if dropna else 6], @@ -474,7 +492,7 @@ def eval_func(lib): dropna=dropna, aggfunc=aggfunc, ) - if aggfunc in ["sum", "max"]: + if aggfunc in ("sum", "max", np.sum, max, np.max, sum): # When normalizing the data, we apply the normalization function to the # entire table (including margins), which requires us to multiply by 2 # (since the function takes the sum over the rows, and the margins row is @@ -495,7 +513,7 @@ def eval_func(lib): ) @pytest.mark.parametrize("normalize", ["all", True]) - @pytest.mark.parametrize("aggfunc", ["count", "mean", "min", "max", "sum"]) + @pytest.mark.parametrize("aggfunc", AGGFUNCS_THAT_CANNOT_PRODUCE_NAN) @sql_count_checker(query_count=0) def test_normalize_margins_and_values_not_supported( self, dropna, normalize, aggfunc, a, b, c @@ -517,7 +535,7 @@ def test_normalize_margins_and_values_not_supported( aggfunc=aggfunc, ) - @pytest.mark.parametrize("aggfunc", ["count", "mean", "min", "max", "sum"]) + @pytest.mark.parametrize("aggfunc", AGGFUNCS_THAT_CANNOT_PRODUCE_NAN) def test_values(self, dropna, aggfunc, basic_crosstab_dfs): query_count = 1 join_count = 2 if dropna else 3 @@ -536,7 +554,7 @@ def test_values(self, dropna, aggfunc, basic_crosstab_dfs): ), ) - @pytest.mark.parametrize("aggfunc", ["count", "mean", "min", "max", "sum"]) + @pytest.mark.parametrize("aggfunc", AGGFUNCS_THAT_CANNOT_PRODUCE_NAN) def test_values_series_like(self, dropna, aggfunc, basic_crosstab_dfs): query_count = 5 join_count = 2 if dropna else 3 @@ -567,6 +585,49 @@ def eval_func(df): eval_func, ) + @pytest.mark.parametrize( + "aggfunc", + ( + # std is NaN for < 2 values + "std", + np.std, + # var is NaN for < 1 values + "var", + np.var, + # skew is NaN for < 3 values + "skew", + ), + ) + def test_aggfuncs_that_may_produce_nan(self, dropna, aggfunc): + """ + Test aggfuncs that may produce NaN. + + We test these aggfuncs separately because when dropna=True and some + aggfuncs produce NaN, pandas has some bugs: + + - https://github.com/pandas-dev/pandas/issues/60768 + - https://github.com/pandas-dev/pandas/issues/60767 + + We design these test cases so that the aggfuncs do not produce NaN, and + we can compare with pandas. + + TODO(SNOW-1892612): Once pandas fixes these bugs, merge these test cases + with the rest of the test suite by adding these aggfuncs to the lists + of aggfuncs that we test in other functions. + """ + with SqlCounter(query_count=1, join_count=(2 if dropna else 3)): + eval_snowpark_pandas_result( + pd, + native_pd, + lambda lib: lib.crosstab( + index=["index1"] * 3 + ["index2"] * 3, + columns=["column1"] * 6, + values=list(range(6)), + dropna=dropna, + aggfunc=aggfunc, + ), + ) + @sql_count_checker(query_count=0) def test_values_unsupported_aggfunc(basic_crosstab_dfs): @@ -574,13 +635,13 @@ def test_values_unsupported_aggfunc(basic_crosstab_dfs): with pytest.raises( NotImplementedError, - match="Snowpark pandas DataFrame.pivot_table does not yet support the aggregation 'median' with the given arguments.", + match="Snowpark pandas DataFrame.pivot_table does not yet support the aggregation 'size' with the given arguments.", ): pd.crosstab( native_df["species"].values, native_df["favorite_food"].values, values=native_df["age"].values, - aggfunc="median", + aggfunc="size", dropna=False, ) @@ -593,13 +654,13 @@ def test_values_series_like_unsupported_aggfunc(basic_crosstab_dfs): with pytest.raises( NotImplementedError, - match="Snowpark pandas DataFrame.pivot_table does not yet support the aggregation 'median' with the given arguments.", + match="Snowpark pandas DataFrame.pivot_table does not yet support the aggregation 'size' with the given arguments.", ): snow_df = pd.crosstab( snow_df["species"], snow_df["favorite_food"], values=snow_df["age"], - aggfunc="median", + aggfunc="size", dropna=False, ) diff --git a/tests/integ/modin/pivot/test_pivot_table_negative.py b/tests/integ/modin/pivot/test_pivot_table_negative.py index aa15304ac58..ea5c6ff456e 100644 --- a/tests/integ/modin/pivot/test_pivot_table_negative.py +++ b/tests/integ/modin/pivot/test_pivot_table_negative.py @@ -14,6 +14,7 @@ pivot_table_test_helper_expects_exception, ) from tests.integ.utils.sql_counter import sql_count_checker +import re @pytest.mark.parametrize( @@ -147,14 +148,42 @@ class Baz: with pytest.raises(NotImplementedError, match="Not implemented non-string"): snow_df2.pivot_table(index="A", columns="B", values=[baz]) - with pytest.raises(NotImplementedError, match="foo"): - snow_df.pivot_table(index="A", columns="C", values="E", aggfunc="foo") +@sql_count_checker(query_count=0) +@pytest.mark.parametrize( + "aggfunc,name_in_error", + [ + ("foo", "'foo'"), + ("kurt", "'kurt'"), + ("prod", "'prod'"), + ("sem", "'sem'"), + (np.argmax, "np.argmax"), + (np.argmin, "np.argmin"), + ( + "all", + "'all'", + ), + (np.all, "np.all"), + ("any", "'any'"), + (np.any, "np.any"), + ("size", "'size'"), + (len, ""), + ("nunique", "'nunique'"), + ("idxmax", "'idxmax'"), + ("idxmin", "'idxmin'"), + ], +) +def test_not_implemented_single_aggfunc(df_data, aggfunc, name_in_error): with pytest.raises( NotImplementedError, - match="median", + match=re.escape( + "Snowpark pandas DataFrame.pivot_table does not yet support " + + f"the aggregation {name_in_error} with the given arguments." + ), ): - snow_df.pivot_table(index="A", columns="C", values="D", aggfunc="median") + pd.DataFrame(df_data).pivot_table( + index="A", columns="C", values="E", aggfunc=aggfunc + ) def sensitive_function_name(col: native_pd.Series) -> int: diff --git a/tests/integ/modin/pivot/test_pivot_table_single.py b/tests/integ/modin/pivot/test_pivot_table_single.py index 4827527604e..9e5dfbfa51f 100644 --- a/tests/integ/modin/pivot/test_pivot_table_single.py +++ b/tests/integ/modin/pivot/test_pivot_table_single.py @@ -7,6 +7,7 @@ # This test file contains tests that execute a single underlying snowpark/snowflake pivot query. import pytest +import numpy as np import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.pivot.pivot_utils import ( @@ -34,13 +35,26 @@ def test_pivot_table_no_index_single_column_single_value(df_data): @pytest.mark.parametrize( "aggfunc", - [ + ( "mean", + np.mean, "sum", + np.sum, "min", + np.min, + min, "max", + np.max, + max, "count", - ], + "median", + np.median, + "skew", + "std", + np.std, + "var", + np.var, + ), ) @sql_count_checker(query_count=1) def test_pivot_table_single_index_single_column_single_value(df_data, aggfunc): diff --git a/tests/unit/modin/test_aggregation_utils.py b/tests/unit/modin/test_aggregation_utils.py index d1e45df0953..fb56d963a76 100644 --- a/tests/unit/modin/test_aggregation_utils.py +++ b/tests/unit/modin/test_aggregation_utils.py @@ -40,8 +40,11 @@ ("count", {}, 1, True), ("size", {}, 0, True), ("size", {}, 1, True), + (sum, {}, 0, True), (len, {}, 0, True), (len, {}, 1, True), + (min, {}, 0, True), + (max, {}, 0, True), ("min", {}, 0, True), ("min", {}, 1, True), ("test", {}, 0, False), @@ -120,8 +123,13 @@ def test_check_aggregation_snowflake_execution_capability_by_args( @pytest.mark.parametrize( "agg_func, agg_kwargs, axis, expected", [ - (np.sum, {}, 0, SnowflakeAggFunc(sum_, True)), - ("max", {"skipna": False}, 1, SnowflakeAggFunc(greatest, True)), + (np.sum, {}, 0, SnowflakeAggFunc(sum_, True, supported_in_pivot=True)), + ( + "max", + {"skipna": False}, + 1, + SnowflakeAggFunc(greatest, True, supported_in_pivot=True), + ), ("test", {}, 0, None), ], ) @@ -147,6 +155,7 @@ def test_get_snowflake_agg_func_with_no_implementation_on_axis_0(): preserves_snowpark_pandas_types=True, axis_1_aggregation_keepna=greatest, axis_1_aggregation_skipna=greatest, + supported_in_pivot=True, ) } ),