Skip to content

Commit 5e26b5e

Browse files
committed
add _open_zarr_store
1 parent d97a1d2 commit 5e26b5e

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

src/spatialdata/_core/spatialdata.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@
3131
)
3232
from spatialdata._logging import logger
3333
from spatialdata._types import ArrayLike, Raster_T
34-
from spatialdata._utils import (
35-
_deprecation_alias,
36-
_error_message_add_element,
37-
)
34+
from spatialdata._utils import _deprecation_alias, _error_message_add_element
3835
from spatialdata.models import (
3936
Image2DModel,
4037
Image3DModel,
@@ -600,7 +597,7 @@ def path(self, value: Path | None) -> None:
600597
)
601598

602599
def _get_groups_for_element(
603-
self, zarr_path: Path, element_type: str, element_name: str
600+
self, zarr_path: UPath, element_type: str, element_name: str
604601
) -> tuple[zarr.Group, zarr.Group, zarr.Group]:
605602
"""
606603
Get the Zarr groups for the root, element_type and element for a specific element.
@@ -620,9 +617,9 @@ def _get_groups_for_element(
620617
-------
621618
either the existing Zarr subgroup or a new one.
622619
"""
623-
if not isinstance(zarr_path, Path):
624-
raise ValueError("zarr_path should be a Path object")
625-
store = parse_url(zarr_path, mode="r+").store
620+
from spatialdata._io._utils import _open_zarr_store
621+
622+
store = _open_zarr_store(zarr_path, mode="r+")
626623
root = zarr.group(store=store)
627624
if element_type not in ["images", "labels", "points", "polygons", "shapes", "tables"]:
628625
raise ValueError(f"Unknown element type {element_type}")
@@ -1375,7 +1372,7 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None:
13751372
self.delete_element_from_disk(name)
13761373
return
13771374

1378-
from spatialdata._io._utils import _backed_elements_contained_in_path
1375+
from spatialdata._io._utils import _backed_elements_contained_in_path, _open_zarr_store
13791376

13801377
if self.path is None:
13811378
raise ValueError("The SpatialData object is not backed by a Zarr store.")
@@ -1416,7 +1413,7 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None:
14161413
)
14171414

14181415
# delete the element
1419-
store = parse_url(self.path, mode="r+").store
1416+
store = _open_zarr_store(self.path)
14201417
root = zarr.group(store=store)
14211418
root[element_type].pop(element_name)
14221419
store.close()
@@ -1437,7 +1434,9 @@ def _check_element_not_on_disk_with_different_type(self, element_type: str, elem
14371434
)
14381435

14391436
def write_consolidated_metadata(self) -> None:
1440-
store = parse_url(self.path, mode="r+").store
1437+
from spatialdata._io._utils import _open_zarr_store
1438+
1439+
store = _open_zarr_store(self.path)
14411440
# consolidate metadata to more easily support remote reading bug in zarr. In reality, 'zmetadata' is written
14421441
# instead of '.zmetadata' see discussion https://github.com/zarr-developers/zarr-python/issues/1121
14431442
zarr.consolidate_metadata(store, metadata_key=".zmetadata")
@@ -1574,15 +1573,11 @@ def write_transformations(self, element_name: str | None = None) -> None:
15741573
)
15751574
axes = get_axes_names(element)
15761575
if isinstance(element, DataArray | DataTree):
1577-
from spatialdata._io._utils import (
1578-
overwrite_coordinate_transformations_raster,
1579-
)
1576+
from spatialdata._io._utils import overwrite_coordinate_transformations_raster
15801577

15811578
overwrite_coordinate_transformations_raster(group=element_group, axes=axes, transformations=transformations)
15821579
elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData):
1583-
from spatialdata._io._utils import (
1584-
overwrite_coordinate_transformations_non_raster,
1585-
)
1580+
from spatialdata._io._utils import overwrite_coordinate_transformations_non_raster
15861581

15871582
overwrite_coordinate_transformations_non_raster(
15881583
group=element_group, axes=axes, transformations=transformations

src/spatialdata/_io/_utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from dask.array import Array as DaskArray
1818
from dask.dataframe import DataFrame as DaskDataFrame
1919
from geopandas import GeoDataFrame
20+
from upath import UPath
21+
from upath.implementations.local import PosixUPath, WindowsUPath
2022
from xarray import DataArray, DataTree
23+
from zarr.storage import FSStore
2124

2225
from spatialdata._core.spatialdata import SpatialData
2326
from spatialdata._utils import get_pyramid_levels
@@ -28,10 +31,7 @@
2831
_validate_mapping_to_coordinate_system_type,
2932
)
3033
from spatialdata.transformations.ngff.ngff_transformations import NgffBaseTransformation
31-
from spatialdata.transformations.transformations import (
32-
BaseTransformation,
33-
_get_current_output_axes,
34-
)
34+
from spatialdata.transformations.transformations import BaseTransformation, _get_current_output_axes
3535

3636

3737
# suppress logger debug from ome_zarr with context manager
@@ -383,3 +383,12 @@ def save_transformations(sdata: SpatialData) -> None:
383383
stacklevel=2,
384384
)
385385
sdata.write_transformations()
386+
387+
388+
def _open_zarr_store(path: str | UPath, **kwargs) -> zarr.storage.BaseStore:
389+
if isinstance(path, str | Path):
390+
path = UPath(path)
391+
if isinstance(path, PosixUPath | WindowsUPath):
392+
return zarr.storage.DirectoryStore(path.path)
393+
else:
394+
return FSStore(path.path, fs=path.fs, **kwargs)

tests/io/test_remote_mock.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ def test_local_sdata_remote_image(self, upath: UPath, images: SpatialData) -> No
263263
local_sdata = SpatialData.read(sdata_path) # noqa: F841
264264
remote_path = upath / "full_sdata.zarr" # noqa: F841
265265

266+
remote_sdata = SpatialData.read(remote_path)
267+
assert_spatial_data_objects_are_identical(local_sdata, remote_sdata)
268+
266269
# TODO: read a single remote image from the S3 data and add it to the local SpatialData object
267270
# for a in remote_path.glob('**/*'):
268271
# print(a)

0 commit comments

Comments
 (0)