diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 1bccc51bd43..e53c762e982 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -19,7 +19,7 @@
 from xarray.core import duck_array_ops
 from xarray.core.coordinate_transform import CoordinateTransform
-from xarray.core.nputils import NumpyVIndexAdapter
+from xarray.core.nputils import NumpyVIndexAdapter, inverse_permutation
 from xarray.core.options import OPTIONS
 from xarray.core.types import T_Xarray
 from xarray.core.utils import (
@@ -2172,3 +2172,198 @@ def _repr_inline_(self, max_width: int) -> str:
         from xarray.core.formatting import format_array_flat

         return format_array_flat(self._get_array_subset(), max_width)
+
+
+#########
+# These imports need to be removed (or vendored)
+from dask.array.core import chunks_from_arrays, deepfirst, shapelist, slices_from_chunks
+from dask.core import flatten
+from dask.utils import cached_cumsum, ndimlist
+
+
+def squeeze(array: np.ndarray, axes: tuple[int, ...]) -> np.ndarray:
+    squeezer = [slice(None)] * array.ndim
+    for ax in axes:
+        assert array.shape[ax] == 1
+        squeezer[ax] = 0
+    return array[tuple(squeezer)]
+
+
+def atleast_ndim(array: np.ndarray, axis: int, ndim: int) -> np.ndarray:
+    idxr = [None] * ndim
+    idxr[axis] = slice(None)
+    return np.asarray(array)[tuple(idxr)]
+
+
+def concatenate3(arrays):
+    """Recursive np.concatenate
+
+    Input should be a nested list of numpy arrays arranged in the order they
+    should appear in the array itself. Each array should have the same number
+    of dimensions as the desired output and the nesting of the lists.
+
+    >>> x = np.array([[1, 2]])
+    >>> concatenate3([[x, x, x], [x, x, x]])
+    array([[1, 2, 1, 2, 1, 2],
+           [1, 2, 1, 2, 1, 2]])
+
+    >>> concatenate3([[x, x], [x, x], [x, x]])
+    array([[1, 2, 1, 2],
+           [1, 2, 1, 2],
+           [1, 2, 1, 2]])
+    """
+    ndim = ndimlist(arrays)
+    if not ndim:
+        return arrays
+    chunks = chunks_from_arrays(arrays)
+    shape = tuple(map(sum, chunks))
+
+    def dtype(x):
+        try:
+            return x.dtype
+        except AttributeError:
+            return type(x)
+
+    result = np.empty(shape=shape, dtype=dtype(deepfirst(arrays)))
+
+    for idx, arr in zip(
+        slices_from_chunks(chunks),
+        flatten(arrays, container=(list, tuple)),
+        strict=False,
+    ):
+        if hasattr(arr, "ndim"):
+            while arr.ndim < ndim:
+                arr = arr[None, ...]
+        # this `if` is the only change from dask:
+        # it is more relaxed and skips empty arrays.
+        if arr.size > 0:
+            result[idx] = arr
+
+    return result
+
+
+class LazilyConcatenatedArray(ExplicitlyIndexedNDArrayMixin):
+    def __init__(self, arrays):
+        self.chunks = chunks_from_arrays(arrays)
+        self.chunkshape = shapelist(arrays)
+        self.breaks = tuple(
+            cached_cumsum(chunks, initial_zero=True) for chunks in self.chunks
+        )
+        self.shape = tuple(breaks[-1] for breaks in self.breaks)
+
+        self.arrays = np.empty(self.chunkshape, dtype=object)
+        for i, a in enumerate(flatten(arrays)):
+            self.arrays.flat[i] = as_indexable(a)
+
+        self.dtype = self.arrays.flat[0].dtype
+        self._assert_invariants()
+
+    def __array__(self) -> np.ndarray:
+        return np.block(self.arrays.tolist())
+
+    def get_duck_array(self) -> np.ndarray:
+        return np.array(self)
+
+    def __getitem__(self, key: BasicIndexer | OuterIndexer):
+        if isinstance(key, ExplicitIndexer):
+            # unwrap to the underlying tuple of per-axis indexers
+            key = key.tuple
+
+        # numpy array of indexers to apply to each element of self.arrays
+        ordered_indexer = np.empty(self.chunkshape, dtype=object)
+        for i in np.ndindex(ordered_indexer.shape):
+            ordered_indexer[i] = [slice(None)] * self.ndim
+        finalizer = [slice(None)] * self.ndim
+        mask = np.full(self.chunkshape, True)
+
+        assert len(key) == self.ndim
+        squeeze_ax = []
+
+        for ax, idxr in enumerate(key):
+            breaks = self.breaks[ax]
+            if isinstance(idxr, slice) and idxr == slice(None):
+                for i in np.ndindex(ordered_indexer.shape):
+                    ordered_indexer[i][ax] = slice(None)
+                continue
+
+            elif isinstance(idxr, slice):
+                idxr = _normalize_slice(idxr, self.shape[ax])
+                # materialize the normalized slice so np.digitize can bin its positions
+                idxr = np.arange(idxr.start, idxr.stop, idxr.step)
+
+            elif is_scalar(idxr):
+                idxr = np.array([idxr])
+                squeeze_ax.append(ax)
+
+            elif not isinstance(idxr, np.ndarray):
+                idxr = np.array(idxr)
+
+            # chunk along `ax` that each requested position falls into
+            ichunks = np.digitize(idxr, breaks) - 1
+            unique_ichunks = set(ichunks.ravel())
+
+            # TODO: could optimize for already sorted
+            # stable sort, so the permutation matches the per-chunk gather order below
+            sorter = np.argsort(ichunks, kind="stable")
+
+            for loc in np.ndindex(ordered_indexer.shape):
+                value = ordered_indexer[loc]
+                if loc[ax] in unique_ichunks:
+                    value[ax] = atleast_ndim(
+                        [
+                            v - self.breaks[ax][loc[ax]]
+                            for v in idxr[ichunks == loc[ax]]
+                        ],
+                        ax,
+                        ndim=self.ndim,
+                    )
+                    ordered_indexer[loc] = value
+                else:
+                    mask[loc] = False
+
+            # need to sort back
+            invert_sorter = inverse_permutation(sorter, idxr.size)
+            finalizer[ax] = atleast_ndim(invert_sorter, axis=ax, ndim=self.ndim)
+
+        arrays = self.arrays.copy()
+        for i in np.ndindex(ordered_indexer.shape):
+            arrays[i] = arrays[i][tuple(ordered_indexer[i])]
+
+        empties = np.empty((1,) * self.ndim, dtype=object)
+        empties[(0,) * self.ndim] = np.empty((0,) * self.ndim, dtype=self.dtype)
+        empties = np.broadcast_to(empties, mask.shape)
+
+        # rely on concatenate3 handling empty arrays, and complaining if the
+        # output shape is incompatible. I think it is always compatible for
+        # outer indexing?
+        np.putmask(arrays, ~mask, empties)
+
+        # TODO: lazify this.
+        result = squeeze(
+            concatenate3(arrays.tolist())[tuple(finalizer)], axes=squeeze_ax
+        )
+        return result
+
+    def _oindex_get(self, key: OuterIndexer) -> np.ndarray:
+        return self[key]
+
+    def _vindex_get(self, key: VectorizedIndexer) -> np.ndarray:
+        # chunk along each axis that every requested point falls into
+        ichunks = tuple(
+            np.digitize(idxr, breaks) - 1
+            for idxr, breaks in zip(key.tuple, self.breaks, strict=False)
+        )
+        linear_chunks = np.ravel_multi_index(ichunks, self.chunkshape)
+        unique_linear_chunks = np.unique(linear_chunks)
+
+        # numpy array of indexers to apply to each element of self.arrays
+        ordered_indexer = np.empty(self.chunkshape, dtype=object)
+        for i in np.ndindex(ordered_indexer.shape):
+            ordered_indexer[i] = [slice(None)] * self.ndim
+        finalizer = [slice(None)] * self.ndim
+        mask = np.full(self.chunkshape, False)
+
+        out_arrays = []
+        for chunk in unique_linear_chunks:
+            mask = linear_chunks == chunk
+            # INCOMPLETE HERE
+            out_arrays.append(self.arrays.flat[chunk].vindex)
+
+    def _assert_invariants(self):
+        assert self.shape == np.array(self).shape
+        assert self.arrays.shape == self.chunkshape
+        assert all(np.issubdtype(a.dtype, self.dtype) for a in self.arrays.flat)
diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py
index 36043e0c57b..d07db5137dc 100644
--- a/xarray/tests/test_concat.py
+++ b/xarray/tests/test_concat.py
@@ -12,6 +12,7 @@
 from xarray.core import dtypes
 from xarray.core.coordinates import Coordinates
 from xarray.core.indexes import PandasIndex
+from xarray.core.indexing import LazilyConcatenatedArray, LazilyIndexedArray
 from xarray.structure import merge
 from xarray.tests import (
     ConcatenatableArray,
@@ -1381,3 +1382,14 @@ def test_concat_index_not_same_dim() -> None:
         match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*",
     ):
         concat([ds1, ds2], dim="x")
+
+
+def test_lazy_concat():
+    arrays_list = [
+        [np.array([[1, 2, 3]]), np.array([[4, 5, 6]])],
+        [np.array([[7, 8, 9]]), np.array([[10, 11, 12]])],
+        [np.array([[7, 8, 9]]), np.array([[10, 11, 12]])],
+    ]
+    larry = LazilyIndexedArray(LazilyConcatenatedArray(arrays_list))
+    nparray = np.block(arrays_list)
+    np.testing.assert_array_equal(larry.get_duck_array(), nparray)
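
Illustrative note, not part of the patch: the chunk-lookup step shared by
`__getitem__` and `_vindex_get` maps each requested index to the chunk that
holds it by running `np.digitize` against the cumulative chunk boundaries
(`self.breaks`), then rebases it to an offset within that chunk. A minimal
standalone sketch of just that step, with made-up chunk sizes (everything
below is hypothetical and only mimics what the patch does per axis):

    import numpy as np

    chunks = (3, 2, 4)                       # hypothetical chunk sizes along one axis
    breaks = np.cumsum((0, *chunks))         # [0, 3, 5, 9], like cached_cumsum(..., initial_zero=True)
    idxr = np.array([1, 4, 7, 8])            # requested positions along that axis

    ichunks = np.digitize(idxr, breaks) - 1  # chunk holding each position -> [0, 1, 2, 2]
    offsets = idxr - breaks[ichunks]         # position within that chunk  -> [1, 1, 2, 3]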