Skip to content

Commit 01831a4

Browse files
authored
flox: Properly propagate multiindex (pydata#9649)
* flox: Properly propagate multiindex Closes pydata#9648 * skip test on old pandas * small optimization * fix
1 parent cfaa72f commit 01831a4

File tree

6 files changed

+53
-37
lines changed

6 files changed

+53
-37
lines changed

doc/whats-new.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Bug fixes
6363
the non-missing times could in theory be encoded with integers
6464
(:issue:`9488`, :pull:`9497`). By `Spencer Clark
6565
<https://github.com/spencerkclark>`_.
66-
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
66+
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`, :issue:`9648`).
6767
By `Deepak Cherian <https://github.com/dcherian>`_.
6868
- Fix the safe_chunks validation option on the to_zarr method
6969
(:issue:`5511`, :pull:`9559`). By `Joseph Nowak

xarray/core/coordinates.py

+11
Original file line numberDiff line numberDiff line change
@@ -1116,3 +1116,14 @@ def create_coords_with_default_indexes(
11161116
new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes)
11171117

11181118
return new_coords
1119+
1120+
1121+
def _coordinates_from_variable(variable: Variable) -> Coordinates:
1122+
from xarray.core.indexes import create_default_index_implicit
1123+
1124+
(name,) = variable.dims
1125+
new_index, index_vars = create_default_index_implicit(variable)
1126+
indexes = {k: new_index for k in index_vars}
1127+
new_vars = new_index.create_variables()
1128+
new_vars[name].attrs = variable.attrs
1129+
return Coordinates(new_vars, indexes)

xarray/core/groupby.py

+18-23
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
2222
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
2323
from xarray.core.concat import concat
24-
from xarray.core.coordinates import Coordinates
24+
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2525
from xarray.core.formatting import format_array_flat
2626
from xarray.core.indexes import (
27-
PandasIndex,
2827
PandasMultiIndex,
2928
filter_indexes_from_coords,
3029
)
30+
from xarray.core.merge import merge_coords
3131
from xarray.core.options import OPTIONS, _get_keep_attrs
3232
from xarray.core.types import (
3333
Dims,
@@ -851,7 +851,6 @@ def _flox_reduce(
851851
from flox.xarray import xarray_reduce
852852

853853
from xarray.core.dataset import Dataset
854-
from xarray.groupers import BinGrouper
855854

856855
obj = self._original_obj
857856
variables = (
@@ -901,13 +900,6 @@ def _flox_reduce(
901900
# set explicitly to avoid unnecessarily accumulating count
902901
kwargs["min_count"] = 0
903902

904-
unindexed_dims: tuple[Hashable, ...] = tuple(
905-
grouper.name
906-
for grouper in self.groupers
907-
if isinstance(grouper.group, _DummyGroup)
908-
and not isinstance(grouper.grouper, BinGrouper)
909-
)
910-
911903
parsed_dim: tuple[Hashable, ...]
912904
if isinstance(dim, str):
913905
parsed_dim = (dim,)
@@ -963,26 +955,29 @@ def _flox_reduce(
963955
# we did end up reducing over dimension(s) that are
964956
# in the grouped variable
965957
group_dims = set(grouper.group.dims)
966-
new_coords = {}
958+
new_coords = []
959+
to_drop = []
967960
if group_dims.issubset(set(parsed_dim)):
968-
new_indexes = {}
969961
for grouper in self.groupers:
970962
output_index = grouper.full_index
971963
if isinstance(output_index, pd.RangeIndex):
964+
# flox always assigns an index so we must drop it here if we don't need it.
965+
to_drop.append(grouper.name)
972966
continue
973-
name = grouper.name
974-
new_coords[name] = IndexVariable(
975-
dims=name, data=np.array(output_index), attrs=grouper.codes.attrs
976-
)
977-
index_cls = (
978-
PandasIndex
979-
if not isinstance(output_index, pd.MultiIndex)
980-
else PandasMultiIndex
967+
new_coords.append(
968+
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
969+
# all associated levels properly.
970+
_coordinates_from_variable(
971+
IndexVariable(
972+
dims=grouper.name,
973+
data=output_index,
974+
attrs=grouper.codes.attrs,
975+
)
976+
)
981977
)
982-
new_indexes[name] = index_cls(output_index, dim=name)
983978
result = result.assign_coords(
984-
Coordinates(new_coords, new_indexes)
985-
).drop_vars(unindexed_dims)
979+
Coordinates._construct_direct(*merge_coords(new_coords))
980+
).drop_vars(to_drop)
986981

987982
# broadcast any non-dim coord variables that don't
988983
# share all dimensions with the grouper

xarray/groupers.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
1818
from xarray.core import duck_array_ops
19-
from xarray.core.coordinates import Coordinates
19+
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2020
from xarray.core.dataarray import DataArray
2121
from xarray.core.groupby import T_Group, _DummyGroup
2222
from xarray.core.indexes import safe_cast_to_index
@@ -42,17 +42,6 @@
4242
RESAMPLE_DIM = "__resample_dim__"
4343

4444

45-
def _coordinates_from_variable(variable: Variable) -> Coordinates:
46-
from xarray.core.indexes import create_default_index_implicit
47-
48-
(name,) = variable.dims
49-
new_index, index_vars = create_default_index_implicit(variable)
50-
indexes = {k: new_index for k in index_vars}
51-
new_vars = new_index.create_variables()
52-
new_vars[name].attrs = variable.attrs
53-
return Coordinates(new_vars, indexes)
54-
55-
5645
@dataclass(init=False)
5746
class EncodedGroups:
5847
"""

xarray/tests/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _importorskip(
135135
has_pint, requires_pint = _importorskip("pint")
136136
has_numexpr, requires_numexpr = _importorskip("numexpr")
137137
has_flox, requires_flox = _importorskip("flox")
138-
has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2")
138+
has_pandas_ge_2_2, requires_pandas_ge_2_2 = _importorskip("pandas", "2.2")
139139
has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0")
140140

141141

xarray/tests/test_groupby.py

+21
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
requires_dask,
3636
requires_flox,
3737
requires_flox_0_9_12,
38+
requires_pandas_ge_2_2,
3839
requires_scipy,
3940
)
4041

@@ -145,6 +146,26 @@ def test_multi_index_groupby_sum() -> None:
145146
assert_equal(expected, actual)
146147

147148

149+
@requires_pandas_ge_2_2
150+
def test_multi_index_propagation():
151+
# regression test for GH9648
152+
times = pd.date_range("2023-01-01", periods=4)
153+
locations = ["A", "B"]
154+
data = [[0.5, 0.7], [0.6, 0.5], [0.4, 0.6], [0.4, 0.9]]
155+
156+
da = xr.DataArray(
157+
data, dims=["time", "location"], coords={"time": times, "location": locations}
158+
)
159+
da = da.stack(multiindex=["time", "location"])
160+
grouped = da.groupby("multiindex")
161+
162+
with xr.set_options(use_flox=True):
163+
actual = grouped.sum()
164+
with xr.set_options(use_flox=False):
165+
expected = grouped.first()
166+
assert_identical(actual, expected)
167+
168+
148169
def test_groupby_da_datetime() -> None:
149170
# test groupby with a DataArray of dtype datetime for GH1132
150171
# create test data

0 commit comments

Comments
 (0)