From 0707a8b782428c5fc846ffd739da70e570f6d412 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 10:36:22 +0100 Subject: [PATCH 01/25] typing fixes and tweaks --- xarray/core/coordinates.py | 8 ++++---- xarray/core/dataarray.py | 6 +++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 408e9e630ee..9d88cca7fc7 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool: return self.to_dataset().identical(other.to_dataset()) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: # redirect to DatasetCoordinates._update_coords self._data.coords._update_coords(coords, indexes) @@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset: return self._data._copy_listed(names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: variables = self._data._variables.copy() variables.update(coords) @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset: return self._data.dataset._copy_listed(self._names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: from xarray.core.datatree import check_alignment @@ -964,7 +964,7 @@ def __getitem__(self, key: Hashable) -> T_DataArray: return self._data._getitem_coord(key) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: coords_plus_data = coords.copy() coords_plus_data[_THIS_ARRAY] = self._data.variable diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b5f3cd5c200..7156277e2df 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -132,7 +132,11 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset) -def _check_coords_dims(shape, coords, dim): +def _check_coords_dims( + shape: tuple[int, ...], + coords: Coordinates | Mapping[Hashable, Variable], + dim: tuple[Hashable, ...], +): sizes = dict(zip(dim, shape, strict=True)) for k, v in coords.items(): if any(d not in dim for d in v.dims): From 75086ef62391eaff004339148f544277cda4b7bb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 10:36:53 +0100 Subject: [PATCH 02/25] add Index.validate_dataarray_coord() --- xarray/core/indexes.py | 47 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 0b4eee7b21c..fbd2edefe68 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -195,6 +195,53 @@ def create_variables( else: return {} + def validate_dataarray_coord( + self, + name: Hashable, + var: Variable, + dims: set[Hashable], + ): + """Validate an index coordinate to be included in a DataArray. + + This method is called repeatedly for each coordinate associated with + this index when creating a new DataArray (via its constructor or from a + Dataset) or updating an existing one. + + By default raises an error if the dimensions of the coordinate variable + do conflict with the array dimensions (strict DataArray model). + + This method may be overridden in Index subclasses, e.g., to include index + coordinates that does not conform with the strict DataArray model. This + is useful for example to include (n+1)-dimensional cell boundary + coordinates attached to an index. + + When a DataArray is constructed from a Dataset, instead of raising the + error Xarray will drop the index and propagate the index coordinates + according to the default rules for DataArray (i.e., depending on their + dimensions). + + Parameters + ---------- + name : Hashable + Name of a coordinate associated to this index. + var : Variable + Coordinate variable object. + dims: tuple + Dataarray's dimensions. + + Raises + ------ + ValueError + When validation fails. + + """ + if any(d not in dims for d in var.dims): + raise ValueError( + f"coordinate {name} has dimensions {var.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dims}" + ) + def to_pandas_index(self) -> pd.Index: """Cast this xarray index to a pandas.Index object or raise a ``TypeError`` if this is not supported. From 8aaf2b800b490781f54b4dad2205ea2c4fca9b59 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 10:40:50 +0100 Subject: [PATCH 03/25] Dataset._construct_dataarray: validate index coord --- xarray/core/dataset.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 67d9b6642d1..2db8229be31 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1212,7 +1212,18 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: coords: dict[Hashable, Variable] = {} # preserve ordering for k in self._variables: - if k in self._coord_names and set(self._variables[k].dims) <= needed_dims: + var_dims = set(self._variables[k].dims) + if k in self._indexes: + try: + self._indexes[k].validate_dataarray_coord( + k, self._variables[k], needed_dims + ) + coords[k] = self._variables[k] + except ValueError: + # failback to strict DataArray model check (index may be dropped later) + if var_dims <= needed_dims: + coords[k] = self._variables[k] + elif k in self._coord_names and var_dims <= needed_dims: coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) From c9b4baae99a3a69973c22ecf1bfbb278cf97c2d5 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 10:44:35 +0100 Subject: [PATCH 04/25] DataArray init: validate index coord --- xarray/core/dataarray.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7156277e2df..f2de2e8dc71 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -138,8 +138,19 @@ def _check_coords_dims( dim: tuple[Hashable, ...], ): sizes = dict(zip(dim, shape, strict=True)) + + indexes: Mapping[Hashable, Index] + if isinstance(coords, Coordinates): + indexes = coords.xindexes + else: + indexes = {} + + dim_set = set(dim) + for k, v in coords.items(): - if any(d not in dim for d in v.dims): + if k in indexes: + indexes[k].validate_dataarray_coord(k, v, dim_set) + elif any(d not in dim for d in v.dims): raise ValueError( f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " @@ -147,7 +158,7 @@ def _check_coords_dims( ) for d, s in v.sizes.items(): - if s != sizes[d]: + if d in sizes and s != sizes[d]: raise ValueError( f"conflicting sizes for dimension {d!r}: " f"length {sizes[d]} on the data but length {s} on " From a47523fc4cea059e08ea7831151b9dd57aba01c6 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 10:45:01 +0100 Subject: [PATCH 05/25] clean-up old TODO --- xarray/core/coordinates.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 9d88cca7fc7..e33f3cab907 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -974,12 +974,7 @@ def _update_coords( "cannot add coordinates with new dimensions to a DataArray" ) self._data._coords = coords - - # TODO(shoyer): once ._indexes is always populated by a dict, modify - # it to update inplace instead. - original_indexes = dict(self._data.xindexes) - original_indexes.update(indexes) - self._data._indexes = original_indexes + self._data._indexes = indexes def _drop_coords(self, coord_names): # should drop indexed coordinates only From 551808a518ce16bdd2cde2462467a720c40307f3 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 11:52:59 +0100 Subject: [PATCH 06/25] refactor dataarray coord update --- xarray/core/coordinates.py | 13 ++++++------- xarray/core/dataarray.py | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index e33f3cab907..220321e9442 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -966,13 +966,12 @@ def __getitem__(self, key: Hashable) -> T_DataArray: def _update_coords( self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: - coords_plus_data = coords.copy() - coords_plus_data[_THIS_ARRAY] = self._data.variable - dims = calculate_dimensions(coords_plus_data) - if not set(dims) <= set(self.dims): - raise ValueError( - "cannot add coordinates with new dimensions to a DataArray" - ) + from xarray.core.dataarray import check_dataarray_coords + + check_dataarray_coords( + self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims + ) + self._data._coords = coords self._data._indexes = indexes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f2de2e8dc71..6e1d21fdb45 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -132,7 +132,7 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset) -def _check_coords_dims( +def check_dataarray_coords( shape: tuple[int, ...], coords: Coordinates | Mapping[Hashable, Variable], dim: tuple[Hashable, ...], @@ -229,7 +229,7 @@ def _infer_coords_and_dims( var.dims = (dim,) new_coords[dim] = var.to_index_variable() - _check_coords_dims(shape, new_coords, dims_tuple) + check_dataarray_coords(shape, new_coords, dims_tuple) return new_coords, dims_tuple From 818b7f5d2d2cb6d8edaa09d085acd08e2c42e979 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 11:53:58 +0100 Subject: [PATCH 07/25] docstring tweaks --- xarray/core/indexes.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index fbd2edefe68..f68db1bbbff 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -208,17 +208,17 @@ def validate_dataarray_coord( Dataset) or updating an existing one. By default raises an error if the dimensions of the coordinate variable - do conflict with the array dimensions (strict DataArray model). + do conflict with the array dimensions (DataArray model). This method may be overridden in Index subclasses, e.g., to include index - coordinates that does not conform with the strict DataArray model. This + coordinates that does not strictly conform with the DataArray model. This is useful for example to include (n+1)-dimensional cell boundary coordinates attached to an index. - When a DataArray is constructed from a Dataset, instead of raising the - error Xarray will drop the index and propagate the index coordinates - according to the default rules for DataArray (i.e., depending on their - dimensions). + When a DataArray is constructed from a Dataset, if the validation fails + Xarray will fail back to propagating the coordinate according to the + default rules for DataArray (i.e., depending on its dimensions), which + may drop this index. Parameters ---------- From e8df9b50aa5189d9cab0a25f903a2076237506bc Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Mar 2025 09:20:56 +0100 Subject: [PATCH 08/25] add tests --- xarray/tests/test_dataarray.py | 49 ++++++++++++++++++++++++++++++++++ xarray/tests/test_dataset.py | 25 +++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 75d6d919e19..a44bff66409 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -529,6 +529,30 @@ class CustomIndex(Index): ... # test coordinate variables copied assert da.coords["x"] is not coords.variables["x"] + def test_constructor_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + actual = DataArray([1.0, 2.0], coords=coords, dims="x") + + # cannot use `assert_identical()` test utility function here yet + # (indexes invariant check is still based on IndexVariable, which + # doesn't work with AnyIndex coordinate variables here) + assert actual.coords.to_dataset().equals(coords.to_dataset()) + assert list(actual.coords.xindexes) == list(coords.xindexes) + assert "x_bnds" not in actual.dims + def test_equals_and_identical(self) -> None: orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") @@ -1634,6 +1658,31 @@ def test_assign_coords_no_default_index(self) -> None: assert_identical(actual.coords, coords, check_default_indexes=False) assert "y" not in actual.xindexes + def test_assign_coords_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + da = DataArray([1.0, 2.0], dims="x") + actual = da.assign_coords(coords) + + # cannot use `assert_identical()` test utility function here yet + # (indexes invariant check is still based on IndexVariable, which + # doesn't work with AnyIndex coordinate variables here) + assert actual.coords.to_dataset().equals(coords.to_dataset()) + assert list(actual.coords.xindexes) == list(coords.xindexes) + assert "x_bnds" not in actual.dims + def test_coords_alignment(self) -> None: lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b273b7d1a0d..7004d2a83ec 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4206,6 +4206,31 @@ def test_getitem_multiple_dtype(self) -> None: dataset = Dataset({key: ("dim0", range(1)) for key in keys}) assert_identical(dataset, dataset[keys]) + def test_getitem_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords) + actual = ds["foo"] + + # cannot use `assert_identical()` test utility function here yet + # (indexes invariant check is still based on IndexVariable, which + # doesn't work with AnyIndex coordinate variables here) + assert actual.coords.to_dataset().equals(coords.to_dataset()) + assert list(actual.coords.xindexes) == list(coords.xindexes) + assert "x_bnds" not in actual.dims + def test_virtual_variables_default_coords(self) -> None: dataset = Dataset({"foo": ("x", range(10))}) expected1 = DataArray(range(10), dims="x", name="x") From 678c013f3dea163674f8dba058fc65fe41d816ab Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 14 Mar 2025 11:48:46 +0100 Subject: [PATCH 09/25] assert invariants: skip check IndexVariable ... ... when check_default_indexes=False. --- xarray/testing/assertions.py | 22 +++++++++++++++------- xarray/tests/test_dataarray.py | 12 ++---------- xarray/tests/test_dataset.py | 6 +----- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 8a2dba9261f..ec7b4fdd410 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -330,10 +330,13 @@ def _assert_indexes_invariants_checks( k: type(v) for k, v in indexes.items() } - index_vars = { - k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable) - } - assert indexes.keys() <= index_vars, (set(indexes), index_vars) + if check_default: + index_vars = { + k + for k, v in possible_coord_variables.items() + if isinstance(v, IndexVariable) + } + assert indexes.keys() <= index_vars, (set(indexes), index_vars) # check pandas index wrappers vs. coordinate data adapters for k, index in indexes.items(): @@ -399,9 +402,14 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): da.dims, {k: v.dims for k, v in da._coords.items()}, ) - assert all( - isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,) - ), {k: type(v) for k, v in da._coords.items()} + + if check_default_indexes: + assert all( + isinstance(v, IndexVariable) + for (k, v) in da._coords.items() + if v.dims == (k,) + ), {k: type(v) for k, v in da._coords.items()} + for k, v in da._coords.items(): _assert_variable_invariants(v, k) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a44bff66409..1883b9eb407 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -546,11 +546,7 @@ class AnyIndex(Index): actual = DataArray([1.0, 2.0], coords=coords, dims="x") - # cannot use `assert_identical()` test utility function here yet - # (indexes invariant check is still based on IndexVariable, which - # doesn't work with AnyIndex coordinate variables here) - assert actual.coords.to_dataset().equals(coords.to_dataset()) - assert list(actual.coords.xindexes) == list(coords.xindexes) + assert_identical(actual.coords, coords, check_default_indexes=False) assert "x_bnds" not in actual.dims def test_equals_and_identical(self) -> None: @@ -1676,11 +1672,7 @@ class AnyIndex(Index): da = DataArray([1.0, 2.0], dims="x") actual = da.assign_coords(coords) - # cannot use `assert_identical()` test utility function here yet - # (indexes invariant check is still based on IndexVariable, which - # doesn't work with AnyIndex coordinate variables here) - assert actual.coords.to_dataset().equals(coords.to_dataset()) - assert list(actual.coords.xindexes) == list(coords.xindexes) + assert_identical(actual.coords, coords, check_default_indexes=False) assert "x_bnds" not in actual.dims def test_coords_alignment(self) -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7004d2a83ec..6ee81f02c85 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4224,11 +4224,7 @@ class AnyIndex(Index): ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords) actual = ds["foo"] - # cannot use `assert_identical()` test utility function here yet - # (indexes invariant check is still based on IndexVariable, which - # doesn't work with AnyIndex coordinate variables here) - assert actual.coords.to_dataset().equals(coords.to_dataset()) - assert list(actual.coords.xindexes) == list(coords.xindexes) + assert_identical(actual.coords, coords, check_default_indexes=False) assert "x_bnds" not in actual.dims def test_virtual_variables_default_coords(self) -> None: From 0f822b550bfda45ece0bfeb58030e1ac3a22276a Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 11:58:59 +0100 Subject: [PATCH 10/25] update cherry-picked tests --- xarray/tests/test_dataarray.py | 12 ++++++------ xarray/tests/test_dataset.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 1883b9eb407..bdfb34851f0 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -531,9 +531,9 @@ class CustomIndex(Index): ... def test_constructor_extra_dim_index_coord(self) -> None: class AnyIndex(Index): - # This test only requires that the coordinates to assign have an - # index, whatever its type. - pass + def validate_dataarray_coord(self, name, var, dims): + # pass all index coordinates + pass idx = AnyIndex() coords = Coordinates( @@ -1656,9 +1656,9 @@ def test_assign_coords_no_default_index(self) -> None: def test_assign_coords_extra_dim_index_coord(self) -> None: class AnyIndex(Index): - # This test only requires that the coordinates to assign have an - # index, whatever its type. - pass + def validate_dataarray_coord(self, name, var, dims): + # pass all index coordinates + pass idx = AnyIndex() coords = Coordinates( diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 6ee81f02c85..d3a51c404e4 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4208,9 +4208,9 @@ def test_getitem_multiple_dtype(self) -> None: def test_getitem_extra_dim_index_coord(self) -> None: class AnyIndex(Index): - # This test only requires that the coordinates to assign have an - # index, whatever its type. - pass + def validate_dataarray_coord(self, name, var, dims): + # pass all index coordinates + pass idx = AnyIndex() coords = Coordinates( From 43c44eade449c521da430d4fa1d8a3792c7e79bb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 12:06:51 +0100 Subject: [PATCH 11/25] update assert datarray invariants --- xarray/testing/assertions.py | 8 ++++---- xarray/tests/test_dataarray.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index ec7b4fdd410..15a239894fb 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -398,12 +398,12 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): assert isinstance(da._coords, dict), da._coords assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords - assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), ( - da.dims, - {k: v.dims for k, v in da._coords.items()}, - ) if check_default_indexes: + assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), ( + da.dims, + {k: v.dims for k, v in da._coords.items()}, + ) assert all( isinstance(v, IndexVariable) for (k, v) in da._coords.items() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index bdfb34851f0..50208a59b7c 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1671,8 +1671,9 @@ def validate_dataarray_coord(self, name, var, dims): da = DataArray([1.0, 2.0], dims="x") actual = da.assign_coords(coords) + expected = DataArray([1.0, 2.0], coords=coords, dims="x") - assert_identical(actual.coords, coords, check_default_indexes=False) + assert_identical(actual, expected, check_default_indexes=False) assert "x_bnds" not in actual.dims def test_coords_alignment(self) -> None: From 3b332630bdcc8c2317304e39741ad7d3bb1f495e Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 12:07:13 +0100 Subject: [PATCH 12/25] doc: add Index.validate_dataarray_coords to API --- doc/api-hidden.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index ac8290b3d1b..8a225b524d5 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -520,6 +520,7 @@ Index.stack Index.unstack Index.create_variables + Index.validate_dataarray_coords Index.to_pandas_index Index.isel Index.sel From a8e6e20b3f149c13920c7c6da8a882cc30d3c2c6 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 12:46:46 +0100 Subject: [PATCH 13/25] typo --- doc/api-hidden.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 8a225b524d5..d28103d0cb7 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -520,7 +520,7 @@ Index.stack Index.unstack Index.create_variables - Index.validate_dataarray_coords + Index.validate_dataarray_coord Index.to_pandas_index Index.isel Index.sel From f1440c449a082e2791bbe13e74a0fff8a7873297 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 Mar 2025 12:50:42 +0100 Subject: [PATCH 14/25] update whats new --- doc/whats-new.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 701d4583512..9971fc428c8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,10 @@ New Features By `Benoit Bovy `_. - Support reading to `GPU memory with Zarr `_ (:pull:`10078`). By `Deepak Cherian `_. +- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding + :py:meth:`Index.validate_dataarray_coord`. For example, this enables support for CF boundaries coordinate (e.g., + ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`). + By `Benoit Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~ From 5da014e260576b8c0579b312e0bde42a6461e5eb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 18 Mar 2025 12:09:25 +0100 Subject: [PATCH 15/25] add CoordinateValidationError --- doc/api.rst | 1 + xarray/__init__.py | 3 ++- xarray/core/dataarray.py | 5 +++-- xarray/core/dataset.py | 3 ++- xarray/core/indexes.py | 8 ++++++-- xarray/tests/test_dataarray.py | 21 +++++++++++++++------ 6 files changed, 29 insertions(+), 12 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index d7c2370d348..465da6985ee 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1639,6 +1639,7 @@ Exceptions .. autosummary:: :toctree: generated/ + CoordinateValidationError MergeError SerializationWarning diff --git a/xarray/__init__.py b/xarray/__init__.py index 3fd57e1a335..dd781c72a17 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -39,7 +39,7 @@ register_dataset_accessor, register_datatree_accessor, ) -from xarray.core.indexes import Index +from xarray.core.indexes import CoordinateValidationError, Index from xarray.core.indexing import IndexSelResult from xarray.core.options import get_options, set_options from xarray.core.parallel import map_blocks @@ -128,6 +128,7 @@ "NamedArray", "Variable", # Exceptions + "CoordinateValidationError", "InvalidTreeError", "MergeError", "NotFoundInTreeError", diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6e1d21fdb45..0ee92911960 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -47,6 +47,7 @@ from xarray.core.extension_array import PandasExtensionArray from xarray.core.formatting import format_item from xarray.core.indexes import ( + CoordinateValidationError, Index, Indexes, PandasMultiIndex, @@ -151,7 +152,7 @@ def check_dataarray_coords( if k in indexes: indexes[k].validate_dataarray_coord(k, v, dim_set) elif any(d not in dim for d in v.dims): - raise ValueError( + raise CoordinateValidationError( f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " f"dimensions {dim}" @@ -159,7 +160,7 @@ def check_dataarray_coords( for d, s in v.sizes.items(): if d in sizes and s != sizes[d]: - raise ValueError( + raise CoordinateValidationError( f"conflicting sizes for dimension {d!r}: " f"length {sizes[d]} on the data but length {s} on " f"coordinate {k!r}" diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2db8229be31..691983dcc4e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -65,6 +65,7 @@ ) from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.indexes import ( + CoordinateValidationError, Index, Indexes, PandasIndex, @@ -1219,7 +1220,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: k, self._variables[k], needed_dims ) coords[k] = self._variables[k] - except ValueError: + except CoordinateValidationError: # failback to strict DataArray model check (index may be dropped later) if var_dims <= needed_dims: coords[k] = self._variables[k] diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f68db1bbbff..414a825e94b 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -33,6 +33,10 @@ IndexVars = dict[Any, "Variable"] +class CoordinateValidationError(ValueError): + """Error class for Xarray coordinate validation failures.""" + + class Index: """ Base class inherited by all xarray-compatible indexes. @@ -231,12 +235,12 @@ def validate_dataarray_coord( Raises ------ - ValueError + CoordinateValidationError When validation fails. """ if any(d not in dims for d in var.dims): - raise ValueError( + raise CoordinateValidationError( f"coordinate {name} has dimensions {var.dims}, but these " "are not a subset of the DataArray " f"dimensions {dims}" diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 50208a59b7c..77e923bcca2 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -34,7 +34,12 @@ from xarray.core import dtypes from xarray.core.common import full_like from xarray.core.coordinates import Coordinates -from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords +from xarray.core.indexes import ( + CoordinateValidationError, + Index, + PandasIndex, + filter_indexes_from_coords, +) from xarray.core.types import QueryEngineOptions, QueryParserOptions from xarray.core.utils import is_scalar from xarray.testing import _assert_internal_invariants @@ -418,9 +423,13 @@ def test_constructor_invalid(self) -> None: with pytest.raises(TypeError, match=r"is not hashable"): DataArray(data, dims=["x", []]) # type: ignore[list-item] - with pytest.raises(ValueError, match=r"conflicting sizes for dim"): + with pytest.raises( + CoordinateValidationError, match=r"conflicting sizes for dim" + ): DataArray([1, 2, 3], coords=[("x", [0, 1])]) - with pytest.raises(ValueError, match=r"conflicting sizes for dim"): + with pytest.raises( + CoordinateValidationError, match=r"conflicting sizes for dim" + ): DataArray([1, 2], coords={"x": [0, 1], "y": ("x", [1])}, dims="x") with pytest.raises(ValueError, match=r"conflicting MultiIndex"): @@ -1622,11 +1631,11 @@ def test_assign_coords(self) -> None: # GH: 2112 da = xr.DataArray([0, 1, 2], dims="x") - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da["x"] = [0, 1, 2, 3] # size conflict - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da.coords["x"] = [0, 1, 2, 3] # size conflict - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray def test_assign_coords_existing_multiindex(self) -> None: From 602665655153e554326eef105c0038a83830bfe0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 18 Mar 2025 12:09:42 +0100 Subject: [PATCH 16/25] docstrings tweaks --- xarray/core/indexes.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 414a825e94b..a5c7b7ca273 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -205,29 +205,33 @@ def validate_dataarray_coord( var: Variable, dims: set[Hashable], ): - """Validate an index coordinate to be included in a DataArray. + """Validate an index coordinate variable to include in a DataArray. This method is called repeatedly for each coordinate associated with this index when creating a new DataArray (via its constructor or from a Dataset) or updating an existing one. - By default raises an error if the dimensions of the coordinate variable - do conflict with the array dimensions (DataArray model). + By default raises a :py:class:`CoordinateValidationError` if the + dimensions of the coordinate variable do conflict with the array + dimensions (DataArray model). - This method may be overridden in Index subclasses, e.g., to include index - coordinates that does not strictly conform with the DataArray model. This - is useful for example to include (n+1)-dimensional cell boundary - coordinates attached to an index. + This method may be overridden in Index subclasses, e.g., to validate + index coordinates even when they do not strictly conform with the + DataArray model. This is useful for example to include (n+1)-dimensional + cell boundary coordinates attached to an index. - When a DataArray is constructed from a Dataset, if the validation fails - Xarray will fail back to propagating the coordinate according to the - default rules for DataArray (i.e., depending on its dimensions), which - may drop this index. + If the validation passes (i.e., no error raised), the coordinate will be + included in the DataArray regardless of its dimensions. + + When a DataArray is constructed from a Dataset (variable access), if the + validation fails Xarray will fail back to propagating the coordinate + according to the default rules for DataArray (i.e., depending on its + dimensions), which may drop this index. Parameters ---------- name : Hashable - Name of a coordinate associated to this index. + Name of a coordinate variable associated to this index. var : Variable Coordinate variable object. dims: tuple From 1eeec9cbb076a75038cc8491019c0810cee1d971 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 18 Mar 2025 12:29:42 +0100 Subject: [PATCH 17/25] nit refactor Functions with a leading underscore are marked by pyright as unused if they are not used from within the module in which they are defined. Also remove unneeded nested import. --- xarray/core/coordinates.py | 3 +-- xarray/core/groupby.py | 4 ++-- xarray/groupers.py | 12 ++++++------ 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 220321e9442..046a1bbcfc9 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1148,9 +1148,8 @@ def create_coords_with_default_indexes( return new_coords -def _coordinates_from_variable(variable: Variable) -> Coordinates: - from xarray.core.indexes import create_default_index_implicit +def coordinates_from_variable(variable: Variable) -> Coordinates: (name,) = variable.dims new_index, index_vars = create_default_index_implicit(variable) indexes = {k: new_index for k in index_vars} diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a0540d3a1b2..b081c802d2e 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -23,7 +23,7 @@ DatasetGroupByAggregations, ) from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce -from xarray.core.coordinates import Coordinates, _coordinates_from_variable +from xarray.core.coordinates import Coordinates, coordinates_from_variable from xarray.core.duck_array_ops import where from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( @@ -1119,7 +1119,7 @@ def _flox_reduce( new_coords.append( # Using IndexVariable here ensures we reconstruct PandasMultiIndex with # all associated levels properly. - _coordinates_from_variable( + coordinates_from_variable( IndexVariable( dims=grouper.name, data=output_index, diff --git a/xarray/groupers.py b/xarray/groupers.py index 234c9f1398a..82e3db0466d 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -18,7 +18,7 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.computation.computation import apply_ufunc -from xarray.core.coordinates import Coordinates, _coordinates_from_variable +from xarray.core.coordinates import Coordinates, coordinates_from_variable from xarray.core.dataarray import DataArray from xarray.core.duck_array_ops import array_all, isnull from xarray.core.groupby import T_Group, _DummyGroup @@ -115,7 +115,7 @@ def __init__( if coords is None: assert not isinstance(self.unique_coord, _DummyGroup) - self.coords = _coordinates_from_variable(self.unique_coord) + self.coords = coordinates_from_variable(self.unique_coord) else: self.coords = coords @@ -248,7 +248,7 @@ def _factorize_unique(self) -> EncodedGroups: codes=codes, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) def _factorize_dummy(self) -> EncodedGroups: @@ -276,7 +276,7 @@ def _factorize_dummy(self) -> EncodedGroups: else: if TYPE_CHECKING: assert isinstance(unique_coord, Variable) - coords = _coordinates_from_variable(unique_coord) + coords = coordinates_from_variable(unique_coord) return EncodedGroups( codes=codes, @@ -405,7 +405,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: codes=codes, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) @@ -539,7 +539,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: group_indices=group_indices, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) From 426ddce6fd4ab5946212f4dad743af4d2fe4c2fd Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 18 Mar 2025 12:54:21 +0100 Subject: [PATCH 18/25] small refactor Move check_dataarray_coords in xarray.core.coordinates module and rename it to validate_dataarray_coords (name consistent with Index.validate_dataarray_coord). Move CoordinateValidationError from xarray.core.indexes to xarray.core.coordinates module. --- xarray/__init__.py | 4 +-- xarray/core/coordinates.py | 51 ++++++++++++++++++++++++++++++++-- xarray/core/dataarray.py | 38 ++----------------------- xarray/core/dataset.py | 2 +- xarray/core/indexes.py | 6 ++-- xarray/tests/test_dataarray.py | 3 +- 6 files changed, 56 insertions(+), 48 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index dd781c72a17..387127437b9 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -29,7 +29,7 @@ ) from xarray.conventions import SerializationWarning, decode_cf from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like -from xarray.core.coordinates import Coordinates +from xarray.core.coordinates import Coordinates, CoordinateValidationError from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree @@ -39,7 +39,7 @@ register_dataset_accessor, register_datatree_accessor, ) -from xarray.core.indexes import CoordinateValidationError, Index +from xarray.core.indexes import Index from xarray.core.indexing import IndexSelResult from xarray.core.options import get_options, set_options from xarray.core.parallel import map_blocks diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 046a1bbcfc9..686314ab7dc 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -966,9 +966,7 @@ def __getitem__(self, key: Hashable) -> T_DataArray: def _update_coords( self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: - from xarray.core.dataarray import check_dataarray_coords - - check_dataarray_coords( + validate_dataarray_coords( self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims ) @@ -1148,6 +1146,53 @@ def create_coords_with_default_indexes( return new_coords +class CoordinateValidationError(ValueError): + """Error class for Xarray coordinate validation failures.""" + + +def validate_dataarray_coords( + shape: tuple[int, ...], + coords: Coordinates | Mapping[Hashable, Variable], + dim: tuple[Hashable, ...], +): + """Validate coordinates ``coords`` to include in a DataArray defined by + ``shape`` and dimensions ``dim``. + + If a coordinate is associated with an index, the validation is performed by + the index. By default the coordinate dimensions must match (a subset of) the + array dimensions (in any order) to conform to the DataArray model. The index + may override this behavior with other validation rules, though. + + Non-index coordinates must all conform to the DataArray model. Scalar + coordinates are always valid. + """ + sizes = dict(zip(dim, shape, strict=True)) + dim_set = set(dim) + + indexes: Mapping[Hashable, Index] + if isinstance(coords, Coordinates): + indexes = coords.xindexes + else: + indexes = {} + + for k, v in coords.items(): + if k in indexes: + indexes[k].validate_dataarray_coord(k, v, dim_set) + elif any(d not in dim for d in v.dims): + raise CoordinateValidationError( + f"coordinate {k} has dimensions {v.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dim}" + ) + + for d, s in v.sizes.items(): + if d in sizes and s != sizes[d]: + raise CoordinateValidationError( + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" + ) + def coordinates_from_variable(variable: Variable) -> Coordinates: (name,) = variable.dims diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0ee92911960..0c0e8732a7d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -42,12 +42,12 @@ DataArrayCoordinates, assert_coordinate_consistent, create_coords_with_default_indexes, + validate_dataarray_coords, ) from xarray.core.dataset import Dataset from xarray.core.extension_array import PandasExtensionArray from xarray.core.formatting import format_item from xarray.core.indexes import ( - CoordinateValidationError, Index, Indexes, PandasMultiIndex, @@ -133,40 +133,6 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset) -def check_dataarray_coords( - shape: tuple[int, ...], - coords: Coordinates | Mapping[Hashable, Variable], - dim: tuple[Hashable, ...], -): - sizes = dict(zip(dim, shape, strict=True)) - - indexes: Mapping[Hashable, Index] - if isinstance(coords, Coordinates): - indexes = coords.xindexes - else: - indexes = {} - - dim_set = set(dim) - - for k, v in coords.items(): - if k in indexes: - indexes[k].validate_dataarray_coord(k, v, dim_set) - elif any(d not in dim for d in v.dims): - raise CoordinateValidationError( - f"coordinate {k} has dimensions {v.dims}, but these " - "are not a subset of the DataArray " - f"dimensions {dim}" - ) - - for d, s in v.sizes.items(): - if d in sizes and s != sizes[d]: - raise CoordinateValidationError( - f"conflicting sizes for dimension {d!r}: " - f"length {sizes[d]} on the data but length {s} on " - f"coordinate {k!r}" - ) - - def _infer_coords_and_dims( shape: tuple[int, ...], coords: ( @@ -230,7 +196,7 @@ def _infer_coords_and_dims( var.dims = (dim,) new_coords[dim] = var.to_index_variable() - check_dataarray_coords(shape, new_coords, dims_tuple) + validate_dataarray_coords(shape, new_coords, dims_tuple) return new_coords, dims_tuple diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 691983dcc4e..805e1277abc 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -60,12 +60,12 @@ ) from xarray.core.coordinates import ( Coordinates, + CoordinateValidationError, DatasetCoordinates, assert_coordinate_consistent, ) from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.indexes import ( - CoordinateValidationError, Index, Indexes, PandasIndex, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a5c7b7ca273..9ce8d85a66a 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -33,10 +33,6 @@ IndexVars = dict[Any, "Variable"] -class CoordinateValidationError(ValueError): - """Error class for Xarray coordinate validation failures.""" - - class Index: """ Base class inherited by all xarray-compatible indexes. @@ -243,6 +239,8 @@ def validate_dataarray_coord( When validation fails. """ + from xarray.core.coordinates import CoordinateValidationError + if any(d not in dims for d in var.dims): raise CoordinateValidationError( f"coordinate {name} has dimensions {var.dims}, but these " diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 77e923bcca2..a236b7c17f1 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -33,9 +33,8 @@ from xarray.coders import CFDatetimeCoder from xarray.core import dtypes from xarray.core.common import full_like -from xarray.core.coordinates import Coordinates +from xarray.core.coordinates import Coordinates, CoordinateValidationError from xarray.core.indexes import ( - CoordinateValidationError, Index, PandasIndex, filter_indexes_from_coords, From 43990364b5db03a73eeeaea2ddf43cabe4a3ac0d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 31 Mar 2025 09:14:22 +0200 Subject: [PATCH 19/25] docstrings improvements Co-authored-by: Deepak Cherian --- xarray/core/indexes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9ce8d85a66a..0ddc486d16a 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -203,7 +203,7 @@ def validate_dataarray_coord( ): """Validate an index coordinate variable to include in a DataArray. - This method is called repeatedly for each coordinate associated with + This method is called repeatedly for each Variable associated with this index when creating a new DataArray (via its constructor or from a Dataset) or updating an existing one. From 828a4cc082202a144e4e1f257a7829bf4421d479 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 31 Mar 2025 09:14:32 +0200 Subject: [PATCH 20/25] docstrings improvements Co-authored-by: Deepak Cherian --- xarray/core/indexes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 0ddc486d16a..3c5771d70d6 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -219,10 +219,10 @@ def validate_dataarray_coord( If the validation passes (i.e., no error raised), the coordinate will be included in the DataArray regardless of its dimensions. - When a DataArray is constructed from a Dataset (variable access), if the - validation fails Xarray will fail back to propagating the coordinate - according to the default rules for DataArray (i.e., depending on its - dimensions), which may drop this index. + If this method raises when a DataArray is constructed from a Dataset, + Xarray will fail back to propagating the coordinate + according to the default rules for DataArray --- i.e., the dimensions of every + coordinate variable must be a subset of DataArray.dims --- which may drop this index. Parameters ---------- From 273d70c21782388f4ba2d8ebd4518a44c422f303 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 07:15:09 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/indexes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 3c5771d70d6..ee5976199ed 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -219,9 +219,9 @@ def validate_dataarray_coord( If the validation passes (i.e., no error raised), the coordinate will be included in the DataArray regardless of its dimensions. - If this method raises when a DataArray is constructed from a Dataset, + If this method raises when a DataArray is constructed from a Dataset, Xarray will fail back to propagating the coordinate - according to the default rules for DataArray --- i.e., the dimensions of every + according to the default rules for DataArray --- i.e., the dimensions of every coordinate variable must be a subset of DataArray.dims --- which may drop this index. Parameters From 3e55af003bbd4b73402a2bd6ae201e7da6972afd Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 24 Apr 2025 11:35:47 +0200 Subject: [PATCH 22/25] refactor index check method Instead of adding the ``Index.validate_dataarray_coord`` method that may raise an error, add ``Index.should_add_coord_in_datarray`` method that returns a boolean. --- xarray/core/coordinates.py | 7 +++-- xarray/core/dataset.py | 18 +++++-------- xarray/core/indexes.py | 47 ++++++++++++++-------------------- xarray/tests/test_dataarray.py | 10 +++----- xarray/tests/test_dataset.py | 5 ++-- 5 files changed, 36 insertions(+), 51 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index be5c80145e9..88fac53c781 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1177,8 +1177,11 @@ def validate_dataarray_coords( for k, v in coords.items(): if k in indexes: - indexes[k].validate_dataarray_coord(k, v, dim_set) - elif any(d not in dim for d in v.dims): + invalid = not indexes[k].should_add_coord_in_dataarray(k, v, dim_set) + else: + invalid = any(d not in dim for d in v.dims) + + if invalid: raise CoordinateValidationError( f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d2ea86871f8..bde5c5c1885 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -48,7 +48,6 @@ ) from xarray.core.coordinates import ( Coordinates, - CoordinateValidationError, DatasetCoordinates, assert_coordinate_consistent, ) @@ -1161,17 +1160,12 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: # preserve ordering for k in self._variables: var_dims = set(self._variables[k].dims) - if k in self._indexes: - try: - self._indexes[k].validate_dataarray_coord( - k, self._variables[k], needed_dims - ) - coords[k] = self._variables[k] - except CoordinateValidationError: - # failback to strict DataArray model check (index may be dropped later) - if var_dims <= needed_dims: - coords[k] = self._variables[k] - elif k in self._coord_names and var_dims <= needed_dims: + if ( + k in self._indexes + and self._indexes[k].should_add_coord_in_dataarray( + k, self._variables[k], needed_dims + ) + ) or (k in self._coord_names and var_dims <= needed_dims): coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5605531fdb1..18987846b76 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -195,34 +195,34 @@ def create_variables( else: return {} - def validate_dataarray_coord( + def should_add_coord_in_dataarray( self, name: Hashable, var: Variable, dims: set[Hashable], - ): - """Validate an index coordinate variable to include in a DataArray. + ) -> bool: + """Define whether or not an index coordinate variable should be added in + a new DataArray. This method is called repeatedly for each Variable associated with this index when creating a new DataArray (via its constructor or from a Dataset) or updating an existing one. - By default raises a :py:class:`CoordinateValidationError` if the - dimensions of the coordinate variable do conflict with the array - dimensions (DataArray model). + By default returns ``True`` if the dimensions of the coordinate variable + are a subset of the array dimensions and ``False`` otherwise (DataArray + model). This default behavior may be overridden in Index subclasses to + bypass strict conformance with the DataArray model. This is useful for + example to include the (n+1)-dimensional cell boundary coordinate + associated with an interval index. - This method may be overridden in Index subclasses, e.g., to validate - index coordinates even when they do not strictly conform with the - DataArray model. This is useful for example to include (n+1)-dimensional - cell boundary coordinates attached to an index. + Returning ``False`` will either: - If the validation passes (i.e., no error raised), the coordinate will be - included in the DataArray regardless of its dimensions. + - raise a :py:class:`CoordinateValidationError` when passing the + coordinate directly to a new or an existing DataArray, e.g., via + ``DataArray.__init__()`` or ``DataArray.assign_coords()`` - If this method raises when a DataArray is constructed from a Dataset, - Xarray will fail back to propagating the coordinate - according to the default rules for DataArray --- i.e., the dimensions of every - coordinate variable must be a subset of DataArray.dims --- which may drop this index. + - drop the coordinate --- and maybe drop the index too --- when a new + DataArray is constructed by indexing a Dataset Parameters ---------- @@ -233,20 +233,11 @@ def validate_dataarray_coord( dims: tuple Dataarray's dimensions. - Raises - ------ - CoordinateValidationError - When validation fails. - """ - from xarray.core.coordinates import CoordinateValidationError - if any(d not in dims for d in var.dims): - raise CoordinateValidationError( - f"coordinate {name} has dimensions {var.dims}, but these " - "are not a subset of the DataArray " - f"dimensions {dims}" - ) + return False + else: + return True def to_pandas_index(self) -> pd.Index: """Cast this xarray index to a pandas.Index object or raise a diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 1de788f99b0..810e02f7b99 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -539,9 +539,8 @@ class CustomIndex(Index): ... def test_constructor_extra_dim_index_coord(self) -> None: class AnyIndex(Index): - def validate_dataarray_coord(self, name, var, dims): - # pass all index coordinates - pass + def should_add_coord_in_dataarray(self, name, var, dims): + return True idx = AnyIndex() coords = Coordinates( @@ -1664,9 +1663,8 @@ def test_assign_coords_no_default_index(self) -> None: def test_assign_coords_extra_dim_index_coord(self) -> None: class AnyIndex(Index): - def validate_dataarray_coord(self, name, var, dims): - # pass all index coordinates - pass + def should_add_coord_in_dataarray(self, name, var, dims): + return True idx = AnyIndex() coords = Coordinates( diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2ee5f9a0d20..0b59d9c0cbc 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4235,9 +4235,8 @@ def test_getitem_multiple_dtype(self) -> None: def test_getitem_extra_dim_index_coord(self) -> None: class AnyIndex(Index): - def validate_dataarray_coord(self, name, var, dims): - # pass all index coordinates - pass + def should_add_coord_in_dataarray(self, name, var, dims): + return True idx = AnyIndex() coords = Coordinates( From 073c0a2baf04225402f05ff85a9e7d289f812f8f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 24 Apr 2025 11:53:53 +0200 Subject: [PATCH 23/25] small refactor Also do not add a coordinate in a DataArray indexed from a Dataset if ``index.should_add_coord_in_datarray`` returns False, even if the coordinate variable dimensions are compatible with the DataArray dimensions (note: there's no test for this case yet, I don't know if it represents any real use case). --- xarray/core/dataset.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bde5c5c1885..7c8ce1de5e1 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1159,13 +1159,15 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: coords: dict[Hashable, Variable] = {} # preserve ordering for k in self._variables: - var_dims = set(self._variables[k].dims) - if ( - k in self._indexes - and self._indexes[k].should_add_coord_in_dataarray( + if k in self._indexes: + add_coord = self._indexes[k].should_add_coord_in_dataarray( k, self._variables[k], needed_dims ) - ) or (k in self._coord_names and var_dims <= needed_dims): + else: + var_dims = set(self._variables[k].dims) + add_coord = k in self._coord_names and var_dims <= needed_dims + + if add_coord: coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) From 8d43dcc3457efebd6cfd9da605c96335cc54892f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 24 Apr 2025 12:02:13 +0200 Subject: [PATCH 24/25] forgot updating API docs and whats new --- doc/api-hidden.rst | 2 +- doc/whats-new.rst | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 3a067b0321f..7bc3cd44f09 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -520,7 +520,7 @@ Index.stack Index.unstack Index.create_variables - Index.validate_dataarray_coord + Index.should_add_coord_in_dataarray Index.to_pandas_index Index.isel Index.sel diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b356d46a02e..076d9211818 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -27,6 +27,10 @@ New Features - Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This includes ``datatree`` support, and removing slashes from dimension names. By `Miguel Jimenez-Urias `_. +- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding + :py:meth:`Index.should_add_coord_in_dataarray`. For example, this enables support for CF boundaries coordinate (e.g., + ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`). + By `Benoit Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -136,10 +140,6 @@ New Features (:pull:`9498`). By `Spencer Clark `_. - Support reading to `GPU memory with Zarr `_ (:pull:`10078`). By `Deepak Cherian `_. -- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding - :py:meth:`Index.validate_dataarray_coord`. For example, this enables support for CF boundaries coordinate (e.g., - ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`). - By `Benoit Bovy `_. Performance ~~~~~~~~~~~ From 4e7c70a9486290c310e735fa0e4a732e1ea3abad Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 24 Apr 2025 12:30:07 +0200 Subject: [PATCH 25/25] nit docstrings --- xarray/core/indexes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 18987846b76..471fe75d398 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -221,7 +221,7 @@ def should_add_coord_in_dataarray( coordinate directly to a new or an existing DataArray, e.g., via ``DataArray.__init__()`` or ``DataArray.assign_coords()`` - - drop the coordinate --- and maybe drop the index too --- when a new + - drop the coordinate (and therefore drop the index) when a new DataArray is constructed by indexing a Dataset Parameters