@@ -179,15 +179,31 @@ def encode_zarr_attr_value(value):
179
179
return encoded
180
180
181
181
182
def _is_coordinate_variable(zarr_array, name):
    """Return whether ``name`` is listed among the dimension names of ``zarr_array``.

    Where the dimension names are read from depends on ``_zarr_v3()`` and on
    the array's format:

    * ``_zarr_v3()`` false: the ``"_ARRAY_DIMENSIONS"`` entry of
      ``zarr_array.attrs``.
    * ``_zarr_v3()`` true and ``metadata.zarr_format == 2``: the
      ``"_ARRAY_DIMENSIONS"`` entry of ``zarr_array.metadata.attributes``.
    * ``_zarr_v3()`` true, any other format: ``zarr_array.metadata.dimension_names``.

    A missing attribute entry, or a falsy ``dimension_names``, is treated as
    an empty list, so the result is ``False`` in that case.

    Parameters
    ----------
    zarr_array
        The zarr array whose dimension names are inspected.
    name
        The variable name to look up.

    Returns
    -------
    bool
        ``True`` if ``name`` is one of the array's dimension names.
    """
    if not _zarr_v3():
        return name in zarr_array.attrs.get("_ARRAY_DIMENSIONS", [])

    metadata = zarr_array.metadata
    if metadata.zarr_format == 2:
        # Format-2 arrays keep dimension names in the attributes mapping.
        dimension_names = metadata.attributes.get("_ARRAY_DIMENSIONS", [])
    else:
        # ``dimension_names`` may be falsy (e.g. None); fall back to empty.
        dimension_names = metadata.dimension_names or []
    return name in dimension_names
193
+
194
+
182
195
class ZarrArrayWrapper (BackendArray ):
183
- __slots__ = ("_array" , "dtype " , "shape " , "is_coordinate" )
196
+ __slots__ = ("_array" , "coords_buffer_prototype " , "dtype " , "is_coordinate" , "shape " )
184
197
185
- def __init__ (self , zarr_array , is_coordinate : bool ):
198
+ def __init__ (
199
+ self , zarr_array , is_coordinate : bool , coords_buffer_prototype : Any | None
200
+ ):
186
201
# some callers attempt to evaluate an array if an `array` property exists on the object.
187
202
# we prefix with _ to avoid this inference.
188
203
self ._array = zarr_array
189
204
self .shape = self ._array .shape
190
205
self .is_coordinate = is_coordinate
206
+ self .coords_buffer_prototype = coords_buffer_prototype
191
207
192
208
# preserve vlen string object dtype (GH 7328)
193
209
if (
@@ -211,12 +227,14 @@ def _vindex(self, key):
211
227
return self ._array .vindex [key ]
212
228
213
229
def _getitem(self, key):
    """Read ``key`` from the wrapped array via ``get_basic_selection``.

    When ``_zarr_v3()`` is true, a ``prototype`` keyword is forwarded:
    ``self.coords_buffer_prototype`` if this wrapper is flagged as a
    coordinate, otherwise ``None``. When ``_zarr_v3()`` is false, no
    ``prototype`` keyword is passed at all.
    """
    if not _zarr_v3():
        return self._array.get_basic_selection(key)

    # Only coordinate arrays use the dedicated coordinate buffer prototype.
    prototype = self.coords_buffer_prototype if self.is_coordinate else None
    return self._array.get_basic_selection(key, prototype=prototype)
220
238
221
239
def __getitem__ (self , key ):
222
240
array = self ._array
@@ -611,6 +629,7 @@ class ZarrStore(AbstractWritableDataStore):
611
629
"_cache_members" ,
612
630
"_close_store_on_close" ,
613
631
"_consolidate_on_close" ,
632
+ "_coords_buffer_prototype" ,
614
633
"_group" ,
615
634
"_members" ,
616
635
"_mode" ,
@@ -642,6 +661,7 @@ def open_store(
642
661
use_zarr_fill_value_as_mask = None ,
643
662
write_empty : bool | None = None ,
644
663
cache_members : bool = True ,
664
+ coords_buffer_prototype : Any | None = None ,
645
665
):
646
666
(
647
667
zarr_group ,
@@ -674,6 +694,7 @@ def open_store(
674
694
close_store_on_close ,
675
695
use_zarr_fill_value_as_mask ,
676
696
cache_members = cache_members ,
697
+ coords_buffer_prototype = coords_buffer_prototype ,
677
698
)
678
699
for group in group_paths
679
700
}
@@ -697,6 +718,7 @@ def open_group(
697
718
use_zarr_fill_value_as_mask = None ,
698
719
write_empty : bool | None = None ,
699
720
cache_members : bool = True ,
721
+ coords_buffer_prototype : Any | None = None ,
700
722
):
701
723
(
702
724
zarr_group ,
@@ -728,6 +750,7 @@ def open_group(
728
750
close_store_on_close ,
729
751
use_zarr_fill_value_as_mask ,
730
752
cache_members ,
753
+ coords_buffer_prototype ,
731
754
)
732
755
733
756
def __init__ (
@@ -742,6 +765,7 @@ def __init__(
742
765
close_store_on_close : bool = False ,
743
766
use_zarr_fill_value_as_mask = None ,
744
767
cache_members : bool = True ,
768
+ coords_buffer_prototype : Any | None = None ,
745
769
):
746
770
self .zarr_group = zarr_group
747
771
self ._read_only = self .zarr_group .read_only
@@ -757,6 +781,14 @@ def __init__(
757
781
self ._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask
758
782
self ._cache_members : bool = cache_members
759
783
self ._members : dict [str , ZarrArray | ZarrGroup ] = {}
784
+ if _zarr_v3 () and coords_buffer_prototype is None :
785
+ # Once zarr-v3 is required we can just have this as the default
786
+ # https://github.com/zarr-developers/zarr-python/issues/2871
787
+ # Use the public API once available
788
+ from zarr .core .buffer .cpu import buffer_prototype
789
+
790
+ coords_buffer_prototype = buffer_prototype
791
+ self ._coords_buffer_prototype = coords_buffer_prototype
760
792
761
793
if self ._cache_members :
762
794
# initialize the cache
@@ -815,8 +847,15 @@ def ds(self):
815
847
816
848
def open_store_variable (self , name ):
817
849
zarr_array = self .members [name ]
818
- is_coordinate = name in zarr_array .metadata .dimension_names
819
- data = indexing .LazilyIndexedArray (ZarrArrayWrapper (zarr_array , is_coordinate = is_coordinate ))
850
+ is_coordinate = _is_coordinate_variable (zarr_array , name )
851
+
852
+ data = indexing .LazilyIndexedArray (
853
+ ZarrArrayWrapper (
854
+ zarr_array ,
855
+ is_coordinate = is_coordinate ,
856
+ coords_buffer_prototype = self ._coords_buffer_prototype ,
857
+ )
858
+ )
820
859
try_nczarr = self ._mode == "r"
821
860
dimensions , attributes = _get_zarr_dims_and_attrs (
822
861
zarr_array , DIMENSION_KEY , try_nczarr
@@ -1339,6 +1378,7 @@ def open_zarr(
1339
1378
use_zarr_fill_value_as_mask = None ,
1340
1379
chunked_array_type : str | None = None ,
1341
1380
from_array_kwargs : dict [str , Any ] | None = None ,
1381
+ coords_buffer_prototype : Any | None = None ,
1342
1382
** kwargs ,
1343
1383
):
1344
1384
"""Load and decode a dataset from a Zarr store.
@@ -1449,6 +1489,12 @@ def open_zarr(
1449
1489
chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg.
1450
1490
Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to
1451
1491
:py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
1492
+ coords_buffer_prototype : zarr.buffer.BufferPrototype, optional
1493
+ The buffer prototype to use for loading coordinate arrays. Zarr offers control over
1494
+ which device's memory buffers are read into. By default, xarray will always load
1495
+ *coordinate* buffers into host (CPU) memory, regardless of the global zarr
1496
+ configuration. To override this behavior, explicitly pass the buffer prototype
1497
+ to use for coordinates here.
1452
1498
1453
1499
Returns
1454
1500
-------
@@ -1492,6 +1538,7 @@ def open_zarr(
1492
1538
"storage_options" : storage_options ,
1493
1539
"zarr_version" : zarr_version ,
1494
1540
"zarr_format" : zarr_format ,
1541
+ "coords_buffer_prototype" : coords_buffer_prototype ,
1495
1542
}
1496
1543
1497
1544
ds = open_dataset (
@@ -1564,6 +1611,7 @@ def open_dataset(
1564
1611
engine = None ,
1565
1612
use_zarr_fill_value_as_mask = None ,
1566
1613
cache_members : bool = True ,
1614
+ coords_buffer_prototype : Any | None = None ,
1567
1615
) -> Dataset :
1568
1616
filename_or_obj = _normalize_path (filename_or_obj )
1569
1617
if not store :
@@ -1580,6 +1628,7 @@ def open_dataset(
1580
1628
use_zarr_fill_value_as_mask = None ,
1581
1629
zarr_format = zarr_format ,
1582
1630
cache_members = cache_members ,
1631
+ coords_buffer_prototype = coords_buffer_prototype ,
1583
1632
)
1584
1633
1585
1634
store_entrypoint = StoreBackendEntrypoint ()
@@ -1615,6 +1664,7 @@ def open_datatree(
1615
1664
storage_options = None ,
1616
1665
zarr_version = None ,
1617
1666
zarr_format = None ,
1667
+ coords_buffer_prototype : Any | None = None ,
1618
1668
) -> DataTree :
1619
1669
filename_or_obj = _normalize_path (filename_or_obj )
1620
1670
groups_dict = self .open_groups_as_dict (
@@ -1634,6 +1684,7 @@ def open_datatree(
1634
1684
storage_options = storage_options ,
1635
1685
zarr_version = zarr_version ,
1636
1686
zarr_format = zarr_format ,
1687
+ coords_buffer_prototype = coords_buffer_prototype ,
1637
1688
)
1638
1689
1639
1690
return datatree_from_dict_with_io_cleanup (groups_dict )
@@ -1657,6 +1708,7 @@ def open_groups_as_dict(
1657
1708
storage_options = None ,
1658
1709
zarr_version = None ,
1659
1710
zarr_format = None ,
1711
+ coords_buffer_prototype : Any | None = None ,
1660
1712
) -> dict [str , Dataset ]:
1661
1713
from xarray .core .treenode import NodePath
1662
1714
@@ -1679,6 +1731,7 @@ def open_groups_as_dict(
1679
1731
storage_options = storage_options ,
1680
1732
zarr_version = zarr_version ,
1681
1733
zarr_format = zarr_format ,
1734
+ coords_buffer_prototype = coords_buffer_prototype ,
1682
1735
)
1683
1736
1684
1737
groups_dict = {}
0 commit comments