# Import through the same package path as the other webknossos imports in this
# file (`webknossos.client...`). The doubled `webknossos.webknossos...` path
# either fails to resolve or loads a second copy of `_array`, whose TS_CONTEXT
# is not the one the library code actually reads.
from webknossos.dataset._array import _clear_tensorstore_context


def clear_caches() -> Generator:
    """Reset process-wide caches, then hand control over at the ``yield``.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    that sits outside the visible hunk -- confirm it is still in place.

    Everything before the ``yield`` is setup work:

    * the cached webknossos client contexts are dropped,
    * the memoized upload-datastore lookup is cleared,
    * the shared tensorstore context is replaced with a fresh one.
    """
    _clear_all_context_caches()
    _cached_get_upload_datastore.cache_clear()
    # Swap in a new tensorstore context so state bound to the previous one is
    # not carried over (presumably this includes in-memory stores -- confirm).
    _clear_tensorstore_context()
    yield
def _clear_tensorstore_context() -> None:
    """Rebind the module-level ``TS_CONTEXT`` to a brand-new tensorstore context.

    Only the name in this module is rebound. Code that looks ``TS_CONTEXT`` up
    at call time sees the new context; a module that copied the old object via
    ``from ... import TS_CONTEXT`` at import time keeps the old one.

    NOTE(review): presumably this also discards data held by in-memory
    kvstores tied to the previous context -- confirm against tensorstore docs.
    """
    global TS_CONTEXT
    fresh_context = tensorstore.Context()
    TS_CONTEXT = fresh_context
def copytree_tensorstore(
    in_path: Path,
    out_path: Path,
    *,
    threads: int | None = 10,
    progress_desc: str | None = None,
) -> None:
    """Copy all entries below ``in_path`` to ``out_path`` via tensorstore kvstores.

    Both locations are opened as tensorstore key-value stores that share the
    module-wide ``TS_CONTEXT``. Every selected key is read from the source
    store and written to the destination store under the same key.

    Args:
        in_path: Source location; converted to a kvstore spec with
            ``TensorStoreArray._make_kvstore``.
        out_path: Destination location; converted the same way.
        threads: Worker count handed to ``ThreadPool``. ``None`` lets
            ``ThreadPool`` choose its default.
        progress_desc: When truthy, a rich progress bar with this description
            is advanced once per copied key. Otherwise the copy runs silently.

    Raises:
        Whatever a read or write raises; worker exceptions surface while the
        result iterator is drained.
    """
    # Function-scope imports, kept as in the original.
    # NOTE(review): presumably this avoids a circular import between `utils`
    # and `dataset._array` -- confirm. It also means TS_CONTEXT is looked up
    # on every call, so a context replaced via `_clear_tensorstore_context`
    # is picked up here.
    from tensorstore import KvStore

    from .dataset._array import TS_CONTEXT, TensorStoreArray

    in_kv = KvStore.open(
        TensorStoreArray._make_kvstore(in_path), context=TS_CONTEXT
    ).result()
    out_kv = KvStore.open(
        TensorStoreArray._make_kvstore(out_path), context=TS_CONTEXT
    ).result()

    def _copy(key: bytes) -> None:
        # Keys are relative to each store, so the same key addresses the
        # matching entry on both sides.
        data = in_kv.read(key).result().value
        out_kv.write(key, data).result()

    # Only keys starting with "/" are copied.
    # NOTE(review): presumably the store path has no trailing slash, so this
    # filter drops siblings that merely share the prefix (e.g. ".../10" when
    # copying ".../1") -- confirm against `_make_kvstore`.
    keys_to_copy: list[bytes] = [
        key for key in in_kv.list().result() if key.startswith(b"/")
    ]
    with ThreadPool(threads) as pool:
        iterator = pool.imap_unordered(_copy, keys_to_copy)

        # The iterator must be drained before the pool context exits, both to
        # wait for all copies and to re-raise worker errors.
        if progress_desc:
            with get_rich_progress() as progress:
                task = progress.add_task(progress_desc, total=len(keys_to_copy))
                for _ in iterator:
                    progress.update(task, advance=1)
        else:
            # Previously this loop also ran after the progress branch, over an
            # already exhausted iterator.
            for _ in iterator:
                pass