diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a10a8c8851f..21f4b9584b5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -55,7 +55,9 @@ Bug fixes - Fix DataArray().drop_attrs(deep=False) and add support for attrs to DataArray()._replace(). (:issue:`10027`, :pull:`10030`). By `Jan Haacker `_. -- Fix ``isel`` for multi-coordinate Xarray indexes (:issue:`10063`, :pull:`10066`). +- Fix ``isel`` for multi-coordinate Xarray indexes and slightly improve + performance when repeatedly accessing :py:attr:`Dataset.xindexes` or + :py:attr:`DataArray.xindexes` (:issue:`10063`, :pull:`10066`, :pull:`10074`). By `Benoit Bovy `_. diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 47773ddfbb6..f4cc977557f 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -801,6 +801,8 @@ def _update_coords( original_indexes = dict(self._data.xindexes) original_indexes.update(indexes) self._data._indexes = original_indexes + # invalidate xindexes cache + self._data._xindexes = None def _drop_coords(self, coord_names): # should drop indexed coordinates only @@ -808,10 +810,14 @@ def _drop_coords(self, coord_names): del self._data._variables[name] del self._data._indexes[name] self._data._coord_names.difference_update(coord_names) + # invalidate xindexes cache + self._data._xindexes = None def __delitem__(self, key: Hashable) -> None: if key in self: del self._data[key] + # invalidate xindexes cache + self._data._xindexes = None else: raise KeyError( f"{key!r} is not in coordinate variables {tuple(self.keys())}" @@ -980,6 +986,8 @@ def _update_coords( original_indexes = dict(self._data.xindexes) original_indexes.update(indexes) self._data._indexes = original_indexes + # invalidate xindexes cache + self._data._xindexes = None def _drop_coords(self, coord_names): # should drop indexed coordinates only @@ -987,6 +995,9 @@ def _drop_coords(self, coord_names): del self._data._coords[name] del self._data._indexes[name] + # invalidate xindexes cache + self._data._xindexes = None + @property def variables(self): return Frozen(self._data._coords) @@ -1008,6 +1019,8 @@ def __delitem__(self, key: Hashable) -> None: del self._data._coords[key] if key in self._data._indexes: del self._data._indexes[key] + # invalidate xindexes cache + self._data._xindexes = None def _ipython_key_completions_(self): """Provide method for the key-autocompletions in IPython.""" diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4ba0f0a73a2..65d848e0c74 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -421,6 +421,7 @@ class DataArray( _close: Callable[[], None] | None _indexes: dict[Hashable, Index] _name: Hashable | None + _xindexes: Indexes[Index] | None _variable: Variable __slots__ = ( @@ -431,6 +432,7 @@ class DataArray( "_indexes", "_name", "_variable", + "_xindexes", ) dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor["DataArray"]) @@ -496,6 +498,7 @@ def __init__( self._coords = coords self._name = name self._indexes = dict(indexes) + self._xindexes = None self._close = None @@ -515,6 +518,7 @@ def _construct_direct( obj._coords = coords obj._name = name obj._indexes = indexes + obj._xindexes = None obj._close = None return obj @@ -1004,7 +1008,11 @@ def xindexes(self) -> Indexes[Index]: """Mapping of :py:class:`~xarray.indexes.Index` objects used for label based indexing. """ - return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes}) + if self._xindexes is None: + self._xindexes = Indexes( + self._indexes, {k: self._coords[k] for k in self._indexes} + ) + return self._xindexes @property def coords(self) -> DataArrayCoordinates: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 449f502c43a..19038b963de 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -582,6 +582,7 @@ class Dataset( _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None _indexes: dict[Hashable, Index] + _xindexes: Indexes[Index] | None _variables: dict[Hashable, Variable] __slots__ = ( @@ -594,6 +595,7 @@ class Dataset( "_encoding", "_indexes", "_variables", + "_xindexes", ) def __init__( @@ -629,6 +631,7 @@ def __init__( self._coord_names = coord_names self._dims = dims self._indexes = indexes + self._xindexes = None # TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping # related to https://github.com/python/mypy/issues/9319? @@ -1013,6 +1016,7 @@ def _construct_direct( obj._coord_names = coord_names obj._dims = dims obj._indexes = indexes + obj._xindexes = None obj._attrs = attrs obj._close = close obj._encoding = encoding @@ -1047,6 +1051,8 @@ def _replace( self._attrs = attrs if indexes is not None: self._indexes = indexes + # invalidate xindexes cache + self._xindexes = None if encoding is not _default: self._encoding = encoding obj = self @@ -1909,7 +1915,11 @@ def xindexes(self) -> Indexes[Index]: """Mapping of :py:class:`~xarray.indexes.Index` objects used for label based indexing. """ - return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) + if self._xindexes is None: + self._xindexes = Indexes( + self._indexes, {k: self._variables[k] for k in self._indexes} + ) + return self._xindexes @property def coords(self) -> DatasetCoordinates: diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 61340ac99ad..68835f9cbe7 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -239,6 +239,7 @@ class DatasetView(Dataset): "_encoding", "_indexes", "_variables", + "_xindexes", ) def __init__( @@ -268,6 +269,7 @@ def _constructor( obj._coord_names = coord_names obj._dims = dims obj._indexes = indexes + obj._xindexes = None obj._attrs = attrs obj._close = close obj._encoding = encoding @@ -336,6 +338,7 @@ def _construct_direct( # type: ignore[override] obj._coord_names = coord_names obj._dims = dims obj._indexes = indexes + obj._xindexes = None obj._attrs = attrs obj._close = close obj._encoding = encoding diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 37711275bce..826cc47ae2d 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1568,6 +1568,9 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): _index_type: type[Index] | type[pd.Index] _indexes: dict[Any, T_PandasOrXarrayIndex] + _unique_index_coords: ( + list[tuple[T_PandasOrXarrayIndex, dict[Hashable, Variable]]] | None + ) _variables: dict[Any, Variable] __slots__ = ( @@ -1577,6 +1580,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): "_dims", "_index_type", "_indexes", + "_unique_index_coords", "_variables", ) @@ -1625,6 +1629,7 @@ def __init__( self.__coord_name_id: dict[Any, int] | None = None self.__id_index: dict[int, T_PandasOrXarrayIndex] | None = None self.__id_coord_names: dict[int, tuple[Hashable, ...]] | None = None + self._unique_index_coords = None @property def _coord_name_id(self) -> dict[Any, int]: @@ -1743,12 +1748,14 @@ def group_by_index( ) -> list[tuple[T_PandasOrXarrayIndex, dict[Hashable, Variable]]]: """Returns a list of unique indexes and their corresponding coordinates.""" - index_coords = [] - for i, index in self._id_index.items(): - coords = {k: self._variables[k] for k in self._id_coord_names[i]} - index_coords.append((index, coords)) + if self._unique_index_coords is None: + index_coords = [] + for i, index in self._id_index.items(): + coords = {k: self._variables[k] for k in self._id_coord_names[i]} + index_coords.append((index, coords)) - return index_coords + self._unique_index_coords = index_coords + return self._unique_index_coords def to_pandas_indexes(self) -> Indexes[pd.Index]: """Returns an immutable proxy for Dataset or DataArray pandas indexes. @@ -1933,36 +1940,6 @@ def check_variables(): return not not_equal -def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str): - # This function avoids the call to indexes.group_by_index - # which is really slow when repeatedly iterating through - # an array. However, it fails to return the correct ID for - # multi-index arrays - indexes_fast, coords = indexes._indexes, indexes._variables - - new_indexes: dict[Hashable, Index] = dict(indexes_fast.items()) - new_index_variables: dict[Hashable, Variable] = {} - for name, index in indexes_fast.items(): - coord = coords[name] - if hasattr(coord, "_indexes"): - index_vars = {n: coords[n] for n in coord._indexes} - else: - index_vars = {name: coord} - index_dims = {d for var in index_vars.values() for d in var.dims} - index_args = {k: v for k, v in args.items() if k in index_dims} - - if index_args: - new_index = getattr(index, func)(index_args) - if new_index is not None: - new_indexes.update({k: new_index for k in index_vars}) - new_index_vars = new_index.create_variables(index_vars) - new_index_variables.update(new_index_vars) - else: - for k in index_vars: - new_indexes.pop(k, None) - return new_indexes, new_index_variables - - def _apply_indexes( indexes: Indexes[Index], args: Mapping[Any, Any], @@ -1991,14 +1968,7 @@ def isel_indexes( indexes: Indexes[Index], indexers: Mapping[Any, Any], ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: - # Fast path function _apply_indexes_fast does not work with multi-coordinate - # Xarray indexes (see https://github.com/pydata/xarray/issues/10063). - # -> call it only in the most common case where all indexes are default - # PandasIndex each associated to a single 1-dimensional coordinate. - if any(type(idx) is not PandasIndex for idx in indexes._indexes.values()): - return _apply_indexes(indexes, indexers, "isel") - else: - return _apply_indexes_fast(indexes, indexers, "isel") + return _apply_indexes(indexes, indexers, "isel") def roll_indexes(