Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions webknossos/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import webknossos as wk
from webknossos.client._upload_dataset import _cached_get_upload_datastore
from webknossos.client.context import _clear_all_context_caches
from webknossos.webknossos.dataset._array import _clear_tensorstore_context

from .constants import TESTDATA_DIR, TESTOUTPUT_DIR

Expand Down Expand Up @@ -101,6 +102,7 @@ def clear_testoutput() -> Generator:
def clear_caches() -> Generator:
    """Reset cached state, then hand control to the consumer of this generator.

    Clears all context caches, the cached upload-datastore lookup and the
    tensorstore context before yielding. Nothing runs after the ``yield``,
    so there is no teardown step.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    above this hunk — confirm.
    """
    _clear_all_context_caches()
    _cached_get_upload_datastore.cache_clear()
    # Swap in a fresh tensorstore context so state held by the old one is not reused.
    _clear_tensorstore_context()
    yield


Expand Down
33 changes: 33 additions & 0 deletions webknossos/tests/dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Iterator
from pathlib import Path
from typing import cast
from uuid import uuid4

import numpy as np
import pytest
Expand Down Expand Up @@ -3345,6 +3346,38 @@ def test_fs_copy_dataset_with_attachments(input_path: Path, output_path: Path) -
assert (new_ds_path / "segmentation" / "meshes" / "meshfile" / "zarr.json").exists()


@pytest.mark.parametrize("input_protocol", ["memory", "file"])
@pytest.mark.parametrize("output_path", OUTPUT_PATHS)
def test_fs_copy_dataset_memory(input_protocol: str, output_path: Path) -> None:
    """Check ``fs_copy_dataset`` with a source dataset on a memory or file path.

    A dataset with one segmentation layer and a single mag filled with ones is
    created at the source, copied to ``output_path`` and read back.
    """
    # The memory variant gets a unique dataset name per invocation via uuid4.
    source_path = (
        UPath(TESTOUTPUT_DIR / "test_dataset")
        if input_protocol == "file"
        else UPath(f"memory:///test_dataset-{uuid4()}")
    )
    target_path = prepare_dataset_path(DEFAULT_DATA_FORMAT, output_path, "copied")
    data_shape = (1, 10, 10, 10)

    # Build the source dataset: one segmentation layer with one mag of ones.
    source_ds = Dataset(source_path, (1, 1, 1))
    segmentation = source_ds.add_layer(
        "segmentation",
        SEGMENTATION_CATEGORY,
        largest_segment_id=999,
        bounding_box=BoundingBox((0, 0, 0), (10, 10, 10)),
    ).as_segmentation_layer()
    source_mag = segmentation.add_mag(1)
    source_mag.write(data=np.ones(data_shape, dtype=np.uint8))

    copied_ds = source_ds.fs_copy_dataset(target_path)

    # The copied mag must contain zarr.json and read back as all ones.
    assert (target_path / "segmentation" / "1" / "zarr.json").exists()
    copied_mag = copied_ds.get_segmentation_layer("segmentation").get_mag("1")
    np.testing.assert_array_equal(
        copied_mag.read(absolute_offset=(0, 0, 0), size=(10, 10, 10)),
        np.ones(data_shape, dtype=np.uint8),
    )


def test_wkw_copy_to_remote_dataset() -> None:
ds_path = prepare_dataset_path(DataFormat.WKW, REMOTE_TESTOUTPUT_DIR, "copied")
wkw_ds = Dataset.open(TESTDATA_DIR / "simple_wkw_dataset")
Expand Down
5 changes: 5 additions & 0 deletions webknossos/webknossos/dataset/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class ArrayException(Exception):
ReturnType = TypeVar("ReturnType")


def _clear_tensorstore_context() -> None:
    """Rebind the module-level ``TS_CONTEXT`` to a new ``tensorstore.Context``.

    NOTE(review): only this module's global is rebound. Code that bound the
    previous object via ``from ... import TS_CONTEXT`` at import time keeps
    referencing the old context — confirm that all users look the name up
    at call time.
    """
    global TS_CONTEXT
    TS_CONTEXT = tensorstore.Context()


def call_with_retries(
fn: Callable[[], ReturnType],
num_retries: int = DEFAULT_NUM_RETRIES,
Expand Down
14 changes: 13 additions & 1 deletion webknossos/webknossos/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@

from ..utils import (
copytree,
copytree_tensorstore,
count_defined_values,
get_executor_for_args,
infer_metadata_type,
Expand Down Expand Up @@ -2701,14 +2702,25 @@ def fs_copy_dataset(
if layers_to_ignore is not None and layer.name in layers_to_ignore:
continue
new_layer = new_dataset.add_layer_like(layer, layer.name)

for mag_view in layer.mags.values():
new_mag = new_layer.add_mag(
mag_view.mag,
chunk_shape=mag_view.info.chunk_shape,
shard_shape=mag_view.info.shard_shape,
compress=mag_view.info.compression_mode,
)
copytree(mag_view.path, new_mag.path)
if (
isinstance(mag_view.path, UPath)
and mag_view.path.protocol == "memory"
) or (
isinstance(new_mag.path, UPath)
and new_mag.path.protocol == "memory"
):
# We need to special-case memory paths because tensorstore has its own memory namespace
copytree_tensorstore(mag_view.path, new_mag.path)
else:
copytree(mag_view.path, new_mag.path)
if isinstance(layer, SegmentationLayer) and isinstance(
new_layer, SegmentationLayer
):
Expand Down
37 changes: 37 additions & 0 deletions webknossos/webknossos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,43 @@ def _copy(args: tuple[Path, Path, tuple[str, ...]]) -> None:
pass


def copytree_tensorstore(
    in_path: Path,
    out_path: Path,
    *,
    threads: int | None = 10,
    progress_desc: str | None = None,
) -> None:
    """Copy the entries below ``in_path`` to ``out_path`` via tensorstore kvstores.

    Both locations are opened as ``KvStore`` objects with the shared
    ``TS_CONTEXT``. Every listed key that starts with ``b"/"`` is read from the
    source store and written under the same key in the destination store,
    using a thread pool.

    Args:
        in_path: Source location, turned into a kvstore spec by
            ``TensorStoreArray._make_kvstore``.
        out_path: Destination location, converted the same way.
        threads: Worker count handed to ``ThreadPool``.
        progress_desc: When truthy, a rich progress task with this description
            is advanced once per copied key.
    """
    from tensorstore import KvStore

    from .dataset._array import TS_CONTEXT, TensorStoreArray

    def _open_store(path: Path) -> KvStore:
        # Blocks until the store has been opened.
        spec = TensorStoreArray._make_kvstore(path)
        return KvStore.open(spec, context=TS_CONTEXT).result()

    source_store = _open_store(in_path)
    target_store = _open_store(out_path)

    def _transfer(key: bytes) -> None:
        # Wait for both the read and the write so failures surface in the worker.
        payload = source_store.read(key).result().value
        target_store.write(key, payload).result()

    listed_keys = source_store.list().result()
    pending_keys: list[bytes] = [
        entry for entry in listed_keys if entry.startswith(b"/")
    ]

    with ThreadPool(threads) as pool:
        completed = pool.imap_unordered(_transfer, pending_keys)

        if progress_desc:
            with get_rich_progress() as progress:
                task = progress.add_task(progress_desc, total=len(pending_keys))
                for _ in completed:
                    progress.update(task, advance=1)
        else:
            # Drain the iterator so every copy finishes before the pool exits.
            for _ in completed:
                pass


def movetree(in_path: Path, out_path: Path) -> None:
    """Move ``in_path`` to ``out_path`` by delegating to ``move``."""
    move(in_path, out_path)

Expand Down