Skip to content

Draft: Extend idxmin and idxmax to accept multiple dimensions #10125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 91 additions & 36 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5894,26 +5894,30 @@ def pad(

def idxmin(
self,
dim: Hashable | None = None,
dim: Dims = None,
*,
skipna: bool | None = None,
fill_value: Any = dtypes.NA,
keep_attrs: bool | None = None,
) -> Self:
"""Return the coordinate label of the minimum value along a dimension.
) -> Self | dict[Hashable, Self]:
"""Return the coordinate label of the minimum value along one or more dimensions.

Returns a new `DataArray` named after the dimension with the values of
the coordinate labels along that dimension corresponding to minimum
values along that dimension.

If a sequence is passed to 'dim', then result is returned as dict of DataArrays.
If a single str is passed to 'dim' then returns a DataArray.

In comparison to :py:meth:`~DataArray.argmin`, this returns the
coordinate label while :py:meth:`~DataArray.argmin` returns the index.
Internally, this method uses argmin to locate the minimum values.

Parameters
----------
dim : str, optional
Dimension over which to apply `idxmin`. This is optional for 1D
arrays, but required for arrays with 2 or more dimensions.
dim : "...", str, Iterable of Hashable or None, optional
The dimensions over which to find the minimum. By default, finds minimum over
all dimensions if array is 1D, otherwise requires explicit specification.
skipna : bool or None, default: None
If True, skip missing values (as marked by NaN). By default, only
skips missing values for ``float``, ``complex``, and ``object``
Expand All @@ -5932,9 +5936,9 @@ def idxmin(

Returns
-------
reduced : DataArray
New `DataArray` object with `idxmin` applied to its data and the
indicated dimension removed.
reduced : DataArray or dict of DataArray
New `DataArray` object(s) with `idxmin` applied to its data and the
indicated dimension(s) removed.

See Also
--------
Expand Down Expand Up @@ -5979,38 +5983,70 @@ def idxmin(
array([16., 0., 4.])
Coordinates:
* y (y) int64 24B -1 0 1
"""
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs),
dim=dim,
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
)
>>> array.idxmin(dim=["x", "y"])
{'x': <xarray.DataArray 'x' ()> Size: 8B
array(0.),
'y': <xarray.DataArray 'y' ()> Size: 8B
array(0)}
Comment on lines +5986 to +5990
Copy link
Member

@shoyer shoyer Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example looks right to me, but does the code actually implement this behavior?

The unit tests don't include anything like this. 2d idxmin looks like it reduces over each dimension independently, and would return a dict of 1D results.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I think my doctest run didn't work, you're right.

Moving this to draft

"""
# Check if dim is multi-dimensional
# TODO: resolve how we're checking for singular dims; we seem to have a few
# different ways through the codebase. This is reflected in requiring some mypy
# ignores for some corner cases.
if (
dim is ... or (isinstance(dim, Iterable) and not isinstance(dim, str))
) and not isinstance(dim, tuple):
# For multiple dimensions, process each dimension separately
# This matches the behavior pattern of argmin when given multiple dimensions,
# but returns coordinate labels instead of indices
result = {}
for k in dim if dim is not ... else self.dims:
result[k] = self.idxmin(
dim=k, # type: ignore[arg-type] # k is Hashable from self.dims
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
)
Comment on lines +6004 to +6009
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be a recursive call to idxmin of some kind?

return result # type: ignore[return-value]
else:
# Use the existing implementation for single dimension
# This wraps argmin through the _calc_idxminmax helper function
# which converts indices to coordinate values
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs),
dim=dim,
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
)

def idxmax(
self,
dim: Hashable = None,
dim: Dims = None,
*,
skipna: bool | None = None,
fill_value: Any = dtypes.NA,
keep_attrs: bool | None = None,
) -> Self:
"""Return the coordinate label of the maximum value along a dimension.
) -> Self | dict[Hashable, Self]:
"""Return the coordinate label of the maximum value along one or more dimensions.

Returns a new `DataArray` named after the dimension with the values of
the coordinate labels along that dimension corresponding to maximum
values along that dimension.

If a sequence is passed to 'dim', then result is returned as dict of DataArrays.
If a single str is passed to 'dim' then returns a DataArray.

In comparison to :py:meth:`~DataArray.argmax`, this returns the
coordinate label while :py:meth:`~DataArray.argmax` returns the index.
Internally, this method uses argmax to locate the maximum values.

Parameters
----------
dim : Hashable, optional
Dimension over which to apply `idxmax`. This is optional for 1D
arrays, but required for arrays with 2 or more dimensions.
dim : "...", str, Iterable of Hashable or None, optional
The dimensions over which to find the maximum. By default, finds maximum over
all dimensions if array is 1D, otherwise requires explicit specification.
skipna : bool or None, default: None
If True, skip missing values (as marked by NaN). By default, only
skips missing values for ``float``, ``complex``, and ``object``
Expand All @@ -6029,9 +6065,9 @@ def idxmax(

Returns
-------
reduced : DataArray
New `DataArray` object with `idxmax` applied to its data and the
indicated dimension removed.
reduced : DataArray or dict of DataArray
New `DataArray` object(s) with `idxmax` applied to its data and the
indicated dimension(s) removed.

See Also
--------
Expand Down Expand Up @@ -6076,15 +6112,34 @@ def idxmax(
array([0., 4., 4.])
Coordinates:
* y (y) int64 24B -1 0 1
"""
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs),
dim=dim,
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
)
>>> array.idxmax(dim=["x", "y"])
{'x': <xarray.DataArray 'x' ()> Size: 8B
array(0.),
'y': <xarray.DataArray 'y' ()> Size: 8B
array(-1)}
"""
# Copy/paste from idxmin; comments are there
if (
dim is ... or (isinstance(dim, Iterable) and not isinstance(dim, str))
) and not isinstance(dim, tuple):
result = {}
for k in dim if dim is not ... else self.dims:
result[k] = self.idxmax(
dim=k, # type: ignore[arg-type] # k is Hashable from self.dims
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
)
return result # type: ignore[return-value]
else:
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs),
dim=dim,
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
)

def argmin(
self,
Expand Down
32 changes: 32 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5576,6 +5576,22 @@ def test_idxmin(
else:
ar0 = ar0_raw

# Test multi-dimensional idxmin
if not np.isnan(minindex).any():
# Get the indices for the minimum values along both dimensions
idx_y = ar0.argmin(dim="x")
idx_x = ar0.argmin(dim="y")

# Get the coordinate labels for these minimum values using idxmin
idxmin_result = ar0.idxmin(dim=["x", "y"])

# Verify we have the expected keys in the result dictionary
assert set(idxmin_result.keys()) == {"x", "y"}

# Verify the values match what we'd expect from individually applying idxmin
assert_identical(idxmin_result["x"], ar0.idxmin(dim="x"))
assert_identical(idxmin_result["y"], ar0.idxmin(dim="y"))

assert_identical(ar0, ar0)

# No dimension specified
Expand Down Expand Up @@ -5717,6 +5733,22 @@ def test_idxmax(
else:
ar0 = ar0_raw

# Test multi-dimensional idxmax
if not np.isnan(maxindex).any():
# Get the indices for the maximum values along both dimensions
idx_y = ar0.argmax(dim="x")
idx_x = ar0.argmax(dim="y")

# Get the coordinate labels for these maximum values using idxmax
idxmax_result = ar0.idxmax(dim=["x", "y"])

# Verify we have the expected keys in the result dictionary
assert set(idxmax_result.keys()) == {"x", "y"}

# Verify the values match what we'd expect from individually applying idxmax
assert_identical(idxmax_result["x"], ar0.idxmax(dim="x"))
assert_identical(idxmax_result["y"], ar0.idxmax(dim="y"))

# No dimension specified
with pytest.raises(ValueError):
ar0.idxmax()
Expand Down
Loading