Skip to content

Commit d620357

Browse files
committed
fixups
(cherry picked from commit 48f85ed4e452709607a11a7b526e844fd3e41df3)
1 parent 9db951c commit d620357

File tree

2 files changed

+94
-10
lines changed

2 files changed

+94
-10
lines changed

xarray/backends/zarr.py

+63-10
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,31 @@ def encode_zarr_attr_value(value):
179179
return encoded
180180

181181

182+
def _is_coordinate_variable(zarr_array, name):
    """Return whether ``name`` is one of the dimension names of ``zarr_array``.

    Where the dimension names are read from depends on ``_zarr_v3()`` and on
    the array's ``zarr_format``:

    * ``_zarr_v3()`` false: the ``"_ARRAY_DIMENSIONS"`` entry of
      ``zarr_array.attrs``.
    * ``_zarr_v3()`` true, ``zarr_format == 2``: the ``"_ARRAY_DIMENSIONS"``
      entry of ``zarr_array.metadata.attributes``.
    * ``_zarr_v3()`` true, any other format:
      ``zarr_array.metadata.dimension_names``.

    A missing attribute or a falsy ``dimension_names`` is treated as "no
    dimensions", so the result is ``False`` in that case.
    """
    if not _zarr_v3():
        return name in zarr_array.attrs.get("_ARRAY_DIMENSIONS", [])

    metadata = zarr_array.metadata
    if metadata.zarr_format == 2:
        return name in metadata.attributes.get("_ARRAY_DIMENSIONS", [])

    # ``dimension_names`` may be falsy (e.g. None); fall back to an empty list
    # so the membership test cannot raise.
    return name in (metadata.dimension_names or [])
193+
194+
182195
class ZarrArrayWrapper(BackendArray):
183-
__slots__ = ("_array", "dtype", "shape", "is_coordinate")
196+
__slots__ = ("_array", "coords_buffer_prototype", "dtype", "is_coordinate", "shape")
184197

185-
def __init__(self, zarr_array, is_coordinate: bool):
198+
def __init__(
199+
self, zarr_array, is_coordinate: bool, coords_buffer_prototype: Any | None
200+
):
186201
# some callers attempt to evaluate an array if an `array` property exists on the object.
187202
# we prefix with _ to avoid this inference.
188203
self._array = zarr_array
189204
self.shape = self._array.shape
190205
self.is_coordinate = is_coordinate
206+
self.coords_buffer_prototype = coords_buffer_prototype
191207

192208
# preserve vlen string object dtype (GH 7328)
193209
if (
@@ -211,12 +227,14 @@ def _vindex(self, key):
211227
return self._array.vindex[key]
212228

213229
def _getitem(self, key):
    """Read ``key`` from the wrapped array with ``get_basic_selection``.

    When ``_zarr_v3()`` is true a ``prototype`` keyword is always passed:
    ``self.coords_buffer_prototype`` if this wrapper was created with
    ``is_coordinate=True``, otherwise ``None``. When ``_zarr_v3()`` is false
    no ``prototype`` keyword is passed at all.
    """
    if not _zarr_v3():
        # Do not forward a ``prototype`` keyword in this case.
        return self._array.get_basic_selection(key)

    buffer_prototype = self.coords_buffer_prototype if self.is_coordinate else None
    return self._array.get_basic_selection(key, prototype=buffer_prototype)
220238

221239
def __getitem__(self, key):
222240
array = self._array
@@ -611,6 +629,7 @@ class ZarrStore(AbstractWritableDataStore):
611629
"_cache_members",
612630
"_close_store_on_close",
613631
"_consolidate_on_close",
632+
"_coords_buffer_prototype",
614633
"_group",
615634
"_members",
616635
"_mode",
@@ -642,6 +661,7 @@ def open_store(
642661
use_zarr_fill_value_as_mask=None,
643662
write_empty: bool | None = None,
644663
cache_members: bool = True,
664+
coords_buffer_prototype: Any | None = None,
645665
):
646666
(
647667
zarr_group,
@@ -674,6 +694,7 @@ def open_store(
674694
close_store_on_close,
675695
use_zarr_fill_value_as_mask,
676696
cache_members=cache_members,
697+
coords_buffer_prototype=coords_buffer_prototype,
677698
)
678699
for group in group_paths
679700
}
@@ -697,6 +718,7 @@ def open_group(
697718
use_zarr_fill_value_as_mask=None,
698719
write_empty: bool | None = None,
699720
cache_members: bool = True,
721+
coords_buffer_prototype: Any | None = None,
700722
):
701723
(
702724
zarr_group,
@@ -728,6 +750,7 @@ def open_group(
728750
close_store_on_close,
729751
use_zarr_fill_value_as_mask,
730752
cache_members,
753+
coords_buffer_prototype,
731754
)
732755

733756
def __init__(
@@ -742,6 +765,7 @@ def __init__(
742765
close_store_on_close: bool = False,
743766
use_zarr_fill_value_as_mask=None,
744767
cache_members: bool = True,
768+
coords_buffer_prototype: Any | None = None,
745769
):
746770
self.zarr_group = zarr_group
747771
self._read_only = self.zarr_group.read_only
@@ -757,6 +781,14 @@ def __init__(
757781
self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask
758782
self._cache_members: bool = cache_members
759783
self._members: dict[str, ZarrArray | ZarrGroup] = {}
784+
if _zarr_v3() and coords_buffer_prototype is None:
785+
# Once zarr-v3 is required we can just have this as the default
786+
# https://github.com/zarr-developers/zarr-python/issues/2871
787+
# Use the public API once available
788+
from zarr.core.buffer.cpu import buffer_prototype
789+
790+
coords_buffer_prototype = buffer_prototype
791+
self._coords_buffer_prototype = coords_buffer_prototype
760792

761793
if self._cache_members:
762794
# initialize the cache
@@ -815,8 +847,15 @@ def ds(self):
815847

816848
def open_store_variable(self, name):
817849
zarr_array = self.members[name]
818-
is_coordinate = name in zarr_array.metadata.dimension_names
819-
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array, is_coordinate=is_coordinate))
850+
is_coordinate = _is_coordinate_variable(zarr_array, name)
851+
852+
data = indexing.LazilyIndexedArray(
853+
ZarrArrayWrapper(
854+
zarr_array,
855+
is_coordinate=is_coordinate,
856+
coords_buffer_prototype=self._coords_buffer_prototype,
857+
)
858+
)
820859
try_nczarr = self._mode == "r"
821860
dimensions, attributes = _get_zarr_dims_and_attrs(
822861
zarr_array, DIMENSION_KEY, try_nczarr
@@ -1339,6 +1378,7 @@ def open_zarr(
13391378
use_zarr_fill_value_as_mask=None,
13401379
chunked_array_type: str | None = None,
13411380
from_array_kwargs: dict[str, Any] | None = None,
1381+
coords_buffer_prototype: Any | None = None,
13421382
**kwargs,
13431383
):
13441384
"""Load and decode a dataset from a Zarr store.
@@ -1449,6 +1489,12 @@ def open_zarr(
14491489
chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg.
14501490
Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to
14511491
:py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
1492+
coords_buffer_prototype : zarr.buffer.BufferPrototype, optional
1493+
The buffer prototype to use for loading coordinate arrays. Zarr offers control over
1494+
which device's memory buffers are read into. By default, xarray will always load
1495+
*coordinate* buffers into host (CPU) memory, regardless of the global zarr
1496+
configuration. To override this behavior, explicitly pass the buffer prototype
1497+
to use for coordinates here.
14521498
14531499
Returns
14541500
-------
@@ -1492,6 +1538,7 @@ def open_zarr(
14921538
"storage_options": storage_options,
14931539
"zarr_version": zarr_version,
14941540
"zarr_format": zarr_format,
1541+
"coords_buffer_prototype": coords_buffer_prototype,
14951542
}
14961543

14971544
ds = open_dataset(
@@ -1564,6 +1611,7 @@ def open_dataset(
15641611
engine=None,
15651612
use_zarr_fill_value_as_mask=None,
15661613
cache_members: bool = True,
1614+
coords_buffer_prototype: Any | None = None,
15671615
) -> Dataset:
15681616
filename_or_obj = _normalize_path(filename_or_obj)
15691617
if not store:
@@ -1580,6 +1628,7 @@ def open_dataset(
15801628
use_zarr_fill_value_as_mask=None,
15811629
zarr_format=zarr_format,
15821630
cache_members=cache_members,
1631+
coords_buffer_prototype=coords_buffer_prototype,
15831632
)
15841633

15851634
store_entrypoint = StoreBackendEntrypoint()
@@ -1615,6 +1664,7 @@ def open_datatree(
16151664
storage_options=None,
16161665
zarr_version=None,
16171666
zarr_format=None,
1667+
coords_buffer_prototype: Any | None = None,
16181668
) -> DataTree:
16191669
filename_or_obj = _normalize_path(filename_or_obj)
16201670
groups_dict = self.open_groups_as_dict(
@@ -1634,6 +1684,7 @@ def open_datatree(
16341684
storage_options=storage_options,
16351685
zarr_version=zarr_version,
16361686
zarr_format=zarr_format,
1687+
coords_buffer_prototype=coords_buffer_prototype,
16371688
)
16381689

16391690
return datatree_from_dict_with_io_cleanup(groups_dict)
@@ -1657,6 +1708,7 @@ def open_groups_as_dict(
16571708
storage_options=None,
16581709
zarr_version=None,
16591710
zarr_format=None,
1711+
coords_buffer_prototype: Any | None = None,
16601712
) -> dict[str, Dataset]:
16611713
from xarray.core.treenode import NodePath
16621714

@@ -1679,6 +1731,7 @@ def open_groups_as_dict(
16791731
storage_options=storage_options,
16801732
zarr_version=zarr_version,
16811733
zarr_format=zarr_format,
1734+
coords_buffer_prototype=coords_buffer_prototype,
16821735
)
16831736

16841737
groups_dict = {}

xarray/tests/test_backends.py

+31
Original file line numberDiff line numberDiff line change
@@ -3766,6 +3766,37 @@ def test_zarr_version_deprecated() -> None:
37663766
xr.open_zarr(store=store, zarr_version=2, zarr_format=3)
37673767

37683768

3769+
@requires_zarr
def test_coords_buffer_prototype() -> None:
    """Check that a user-supplied ``coords_buffer_prototype`` is actually used.

    The test builds a ``BufferPrototype`` from counting subclasses of the CPU
    ``Buffer`` and ``NDBuffer``, writes the test dataset to a zarr-format-3
    store, opens it with that prototype, and asserts that at least one
    counting buffer was instantiated.
    """
    pytest.importorskip("zarr", minversion="3")

    from zarr.core.buffer import cpu
    from zarr.core.buffer.core import BufferPrototype

    counter = 0

    def _record_instantiation() -> None:
        # Shared by both buffer subclasses so either kind counts.
        nonlocal counter
        counter += 1

    class Buffer(cpu.Buffer):
        def __init__(self, *args, **kwargs):
            _record_instantiation()
            super().__init__(*args, **kwargs)

    class NDBuffer(cpu.NDBuffer):
        def __init__(self, *args, **kwargs):
            _record_instantiation()
            super().__init__(*args, **kwargs)

    prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)

    store = KVStore()
    create_test_data().to_zarr(store=store, zarr_format=3)

    xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype)
    assert counter > 0
3798+
3799+
37693800
@requires_scipy
37703801
class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only):
37713802
engine: T_NetcdfEngine = "scipy"

0 commit comments

Comments
 (0)