Skip to content

Commit 9db951c

Browse files
committed
[WIP]: Force Zarr coordinate reads to be on the host
zarr-python 3.x supports reading data to host (CPU) memory or device (GPU) memory. Because coordinates are small and really do need to be on the host (IIUC because putting them in an Index) then there's no benefit to reading them to device. zarr-python includes a global config for whether to use host or device memory for reads, with `zarr.config.enable_gpu()`. But you can override that on a per-read basis by passing `prototype` to the getitem call. This does that for arrays that are coordinates.
1 parent 2475d49 commit 9db951c

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

xarray/backends/zarr.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,14 @@ def encode_zarr_attr_value(value):
180180

181181

182182
class ZarrArrayWrapper(BackendArray):
183-
__slots__ = ("_array", "dtype", "shape")
183+
__slots__ = ("_array", "dtype", "shape", "is_coordinate")
184184

185-
def __init__(self, zarr_array):
185+
def __init__(self, zarr_array, is_coordinate: bool):
186186
# some callers attempt to evaluate an array if an `array` property exists on the object.
187187
# we prefix with _ to avoid this inference.
188188
self._array = zarr_array
189189
self.shape = self._array.shape
190+
self.is_coordinate = is_coordinate
190191

191192
# preserve vlen string object dtype (GH 7328)
192193
if (
@@ -210,7 +211,12 @@ def _vindex(self, key):
210211
return self._array.vindex[key]
211212

212213
def _getitem(self, key):
213-
return self._array[key]
214+
from zarr.core.buffer.cpu import buffer_prototype
215+
if self.is_coordinate:
216+
prototype = buffer_prototype
217+
else:
218+
prototype = None
219+
return self._array.get_basic_selection(key, prototype=prototype)
214220

215221
def __getitem__(self, key):
216222
array = self._array
@@ -809,7 +815,8 @@ def ds(self):
809815

810816
def open_store_variable(self, name):
811817
zarr_array = self.members[name]
812-
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
818+
is_coordinate = name in zarr_array.metadata.dimension_names
819+
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array, is_coordinate=is_coordinate))
813820
try_nczarr = self._mode == "r"
814821
dimensions, attributes = _get_zarr_dims_and_attrs(
815822
zarr_array, DIMENSION_KEY, try_nczarr

0 commit comments

Comments
 (0)