Skip to content

Commit df87f69

Browse files
authored
Fix multiple grouping with missing groups (pydata#9650)
* Fix multiple grouping with missing groups (Closes pydata#9360)
* Small repr improvement
* Small optimization in mask
* Add whats-new
* Fix doctests
1 parent 01831a4 commit df87f69

File tree

5 files changed: +49 / -20 lines changed

5 files changed

+49
-20
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ Bug fixes
6868
- Fix the safe_chunks validation option on the to_zarr method
6969
(:issue:`5511`, :pull:`9559`). By `Joseph Nowak
7070
<https://github.com/josephnowak>`_.
71+
- Fix binning by multiple variables where some bins have no observations. (:issue:`9630`).
72+
By `Deepak Cherian <https://github.com/dcherian>`_.
7173

7274
Documentation
7375
~~~~~~~~~~~~~

xarray/core/dataarray.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6800,7 +6800,7 @@ def groupby(
68006800
68016801
>>> da.groupby("letters")
68026802
<DataArrayGroupBy, grouped over 1 grouper(s), 2 groups in total:
6803-
'letters': 2 groups with labels 'a', 'b'>
6803+
'letters': 2/2 groups present with labels 'a', 'b'>
68046804
68056805
Execute a reduction
68066806
@@ -6816,8 +6816,8 @@ def groupby(
68166816
68176817
>>> da.groupby(["letters", "x"])
68186818
<DataArrayGroupBy, grouped over 2 grouper(s), 8 groups in total:
6819-
'letters': 2 groups with labels 'a', 'b'
6820-
'x': 4 groups with labels 10, 20, 30, 40>
6819+
'letters': 2/2 groups present with labels 'a', 'b'
6820+
'x': 4/4 groups present with labels 10, 20, 30, 40>
68216821
68226822
Use Grouper objects to express more complicated GroupBy operations
68236823

xarray/core/dataset.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10403,7 +10403,7 @@ def groupby(
1040310403
1040410404
>>> ds.groupby("letters")
1040510405
<DatasetGroupBy, grouped over 1 grouper(s), 2 groups in total:
10406-
'letters': 2 groups with labels 'a', 'b'>
10406+
'letters': 2/2 groups present with labels 'a', 'b'>
1040710407
1040810408
Execute a reduction
1040910409
@@ -10420,8 +10420,8 @@ def groupby(
1042010420
1042110421
>>> ds.groupby(["letters", "x"])
1042210422
<DatasetGroupBy, grouped over 2 grouper(s), 8 groups in total:
10423-
'letters': 2 groups with labels 'a', 'b'
10424-
'x': 4 groups with labels 10, 20, 30, 40>
10423+
'letters': 2/2 groups present with labels 'a', 'b'
10424+
'x': 4/4 groups present with labels 10, 20, 30, 40>
1042510425
1042610426
Use Grouper objects to express more complicated GroupBy operations
1042710427

xarray/core/groupby.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ def factorize(self) -> EncodedGroups:
454454
# At this point all arrays have been factorized.
455455
codes = tuple(grouper.codes for grouper in groupers)
456456
shape = tuple(grouper.size for grouper in groupers)
457+
masks = tuple((code == -1) for code in codes)
457458
# We broadcast the codes against each other
458459
broadcasted_codes = broadcast(*codes)
459460
# This fully broadcasted DataArray is used as a template later
@@ -464,24 +465,18 @@ def factorize(self) -> EncodedGroups:
464465
)
465466
# NaNs; as well as values outside the bins are coded by -1
466467
# Restore these after the raveling
467-
mask = functools.reduce(
468-
np.logical_or, # type: ignore[arg-type]
469-
[(code == -1) for code in broadcasted_codes],
470-
)
468+
broadcasted_masks = broadcast(*masks)
469+
mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type]
471470
_flatcodes[mask] = -1
472471

473-
midx = pd.MultiIndex.from_product(
474-
(grouper.unique_coord.data for grouper in groupers),
472+
full_index = pd.MultiIndex.from_product(
473+
(grouper.full_index.values for grouper in groupers),
475474
names=tuple(grouper.name for grouper in groupers),
476475
)
477476
# Constructing an index from the product is wrong when there are missing groups
478477
# (e.g. binning, resampling). Account for that now.
479-
midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
478+
midx = full_index[np.sort(pd.unique(_flatcodes[~mask]))]
480479

481-
full_index = pd.MultiIndex.from_product(
482-
(grouper.full_index.values for grouper in groupers),
483-
names=tuple(grouper.name for grouper in groupers),
484-
)
485480
dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers)
486481

487482
coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name)
@@ -684,7 +679,7 @@ def __repr__(self) -> str:
684679
for grouper in self.groupers:
685680
coord = grouper.unique_coord
686681
labels = ", ".join(format_array_flat(coord, 30).split())
687-
text += f"\n {grouper.name!r}: {coord.size} groups with labels {labels}"
682+
text += f"\n {grouper.name!r}: {coord.size}/{grouper.full_index.size} groups present with labels {labels}"
688683
return text + ">"
689684

690685
def _iter_grouped(self) -> Iterator[T_Xarray]:

xarray/tests/test_groupby.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def test_groupby_repr(obj, dim) -> None:
606606
N = len(np.unique(obj[dim]))
607607
expected = f"<{obj.__class__.__name__}GroupBy"
608608
expected += f", grouped over 1 grouper(s), {N} groups in total:"
609-
expected += f"\n {dim!r}: {N} groups with labels "
609+
expected += f"\n {dim!r}: {N}/{N} groups present with labels "
610610
if dim == "x":
611611
expected += "1, 2, 3, 4, 5>"
612612
elif dim == "y":
@@ -623,7 +623,7 @@ def test_groupby_repr_datetime(obj) -> None:
623623
actual = repr(obj.groupby("t.month"))
624624
expected = f"<{obj.__class__.__name__}GroupBy"
625625
expected += ", grouped over 1 grouper(s), 12 groups in total:\n"
626-
expected += " 'month': 12 groups with labels "
626+
expected += " 'month': 12/12 groups present with labels "
627627
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>"
628628
assert actual == expected
629629

@@ -2953,3 +2953,35 @@ def test_groupby_transpose():
29532953
second = data.groupby("x").sum()
29542954

29552955
assert_identical(first, second.transpose(*first.dims))
2956+
2957+
2958+
def test_groupby_multiple_bin_grouper_missing_groups():
2959+
from numpy import nan
2960+
2961+
ds = xr.Dataset(
2962+
{"foo": (("z"), np.arange(12))},
2963+
coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))},
2964+
)
2965+
2966+
actual = ds.groupby(
2967+
x=BinGrouper(np.arange(0, 13, 4)), y=BinGrouper(bins=np.arange(0, 16, 2))
2968+
).count()
2969+
expected = Dataset(
2970+
{
2971+
"foo": (
2972+
("x_bins", "y_bins"),
2973+
np.array(
2974+
[
2975+
[2.0, 2.0, nan, nan, nan, nan, nan],
2976+
[nan, nan, 2.0, 2.0, nan, nan, nan],
2977+
[nan, nan, nan, nan, 2.0, 1.0, nan],
2978+
]
2979+
),
2980+
)
2981+
},
2982+
coords={
2983+
"x_bins": ("x_bins", pd.IntervalIndex.from_breaks(np.arange(0, 13, 4))),
2984+
"y_bins": ("y_bins", pd.IntervalIndex.from_breaks(np.arange(0, 16, 2))),
2985+
},
2986+
)
2987+
assert_identical(actual, expected)

0 commit comments

Comments (0)