diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py
index 2e98a43f94..37a5b76bba 100644
--- a/src/zarr/api/asynchronous.py
+++ b/src/zarr/api/asynchronous.py
@@ -188,7 +188,6 @@ async def consolidate_metadata(
     group.store_path.store._check_writable()
 
     members_metadata = {k: v.metadata async for k, v in group.members(max_depth=None)}
-
     # While consolidating, we want to be explicit about when child groups
     # are empty by inserting an empty dict for consolidated_metadata.metadata
     for k, v in members_metadata.items():
diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py
index 915158cb5a..e0aad8b6ad 100644
--- a/src/zarr/core/array.py
+++ b/src/zarr/core/array.py
@@ -1995,10 +1995,11 @@ def path(self) -> str:
 
     @property
     def name(self) -> str:
+        """Array name following h5py convention."""
         return self._async_array.name
 
     @property
-    def basename(self) -> str | None:
+    def basename(self) -> str:
         """Final component of name."""
         return self._async_array.basename
 
diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py
index ebdc63364e..82970e4b7f 100644
--- a/src/zarr/core/group.py
+++ b/src/zarr/core/group.py
@@ -31,7 +31,7 @@
     create_array,
 )
 from zarr.core.attributes import Attributes
-from zarr.core.buffer import default_buffer_prototype
+from zarr.core.buffer import Buffer, default_buffer_prototype
 from zarr.core.common import (
     JSON,
     ZARR_JSON,
@@ -662,6 +662,7 @@ async def getitem(
         """
         store_path = self.store_path / key
         logger.debug("key=%s, store_path=%s", key, store_path)
+        metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata
 
         # Consolidated metadata lets us avoid some I/O operations so try that first.
         if self.metadata.consolidated_metadata is not None:
@@ -678,12 +679,9 @@ async def getitem(
                 raise KeyError(key)
             else:
                 zarr_json = json.loads(zarr_json_bytes.to_bytes())
-            if zarr_json["node_type"] == "group":
-                return type(self).from_dict(store_path, zarr_json)
-            elif zarr_json["node_type"] == "array":
-                return AsyncArray.from_dict(store_path, zarr_json)
-            else:
-                raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
+                metadata = _build_metadata_v3(zarr_json)
+                return _build_node_v3(metadata, store_path)
+
         elif self.metadata.zarr_format == 2:
             # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
             # This guarantees that we will always make at least one extra request to the store
@@ -698,21 +696,19 @@ async def getitem(
 
             # unpack the zarray, if this is None then we must be opening a group
             zarray = json.loads(zarray_bytes.to_bytes()) if zarray_bytes else None
+            zgroup = json.loads(zgroup_bytes.to_bytes()) if zgroup_bytes else None
             # unpack the zattrs, this can be None if no attrs were written
             zattrs = json.loads(zattrs_bytes.to_bytes()) if zattrs_bytes is not None else {}
 
             if zarray is not None:
-                # TODO: update this once the V2 array support is part of the primary array class
-                zarr_json = {**zarray, "attributes": zattrs}
-                return AsyncArray.from_dict(store_path, zarr_json)
+                metadata = _build_metadata_v2(zarray, zattrs)
+                return _build_node_v2(metadata=metadata, store_path=store_path)
             else:
-                zgroup = (
-                    json.loads(zgroup_bytes.to_bytes())
-                    if zgroup_bytes is not None
-                    else {"zarr_format": self.metadata.zarr_format}
-                )
-                zarr_json = {**zgroup, "attributes": zattrs}
-                return type(self).from_dict(store_path, zarr_json)
+                # this is just for mypy
+                if TYPE_CHECKING:
+                    assert zgroup is not None
+                metadata = _build_metadata_v2(zgroup, zattrs)
+                return _build_node_v2(metadata=metadata, store_path=store_path)
         else:
             raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
 
@@ -1346,18 +1342,50 @@ async def members(
         """
         if max_depth is not None and max_depth < 0:
             raise ValueError(f"max_depth must be None or >= 0. Got '{max_depth}' instead")
-        async for item in self._members(max_depth=max_depth, current_depth=0):
+        async for item in self._members(max_depth=max_depth):
             yield item
 
-    async def _members(
-        self, max_depth: int | None, current_depth: int
-    ) -> AsyncGenerator[
+    def _members_consolidated(
+        self, max_depth: int | None, prefix: str = ""
+    ) -> Generator[
         tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
         None,
     ]:
+        consolidated_metadata = self.metadata.consolidated_metadata
+
+        do_recursion = max_depth is None or max_depth > 0
+
+        # we kind of just want the top-level keys.
+        if consolidated_metadata is not None:
+            for key in consolidated_metadata.metadata:
+                obj = self._getitem_consolidated(
+                    self.store_path, key, prefix=self.name
+                )  # Metadata -> Group/Array
+                key = f"{prefix}/{key}".lstrip("/")
+                yield key, obj
+
+                if do_recursion and isinstance(obj, AsyncGroup):
+                    if max_depth is None:
+                        new_depth = None
+                    else:
+                        new_depth = max_depth - 1
+                    yield from obj._members_consolidated(new_depth, prefix=key)
+
+    async def _members(
+        self, max_depth: int | None
+    ) -> AsyncGenerator[
+        tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
+    ]:
+        skip_keys: tuple[str, ...]
+        if self.metadata.zarr_format == 2:
+            skip_keys = (".zattrs", ".zgroup", ".zarray", ".zmetadata")
+        elif self.metadata.zarr_format == 3:
+            skip_keys = ("zarr.json",)
+        else:
+            raise ValueError(f"Unknown Zarr format: {self.metadata.zarr_format}")
+
         if self.metadata.consolidated_metadata is not None:
-            # we should be able to do members without any additional I/O
-            members = self._members_consolidated(max_depth, current_depth)
+            members = self._members_consolidated(max_depth=max_depth)
             for member in members:
                 yield member
             return
@@ -1371,66 +1399,12 @@ async def _members(
             )
 
             raise ValueError(msg)
-        # would be nice to make these special keys accessible programmatically,
-        # and scoped to specific zarr versions
-        # especially true for `.zmetadata` which is configurable
-        _skip_keys = ("zarr.json", ".zgroup", ".zattrs", ".zmetadata")
-
-        # hmm lots of I/O and logic interleaved here.
-        # We *could* have an async gen over self.metadata.consolidated_metadata.metadata.keys()
-        # and plug in here. `getitem` will skip I/O.
-        # Kinda a shame to have all the asyncio task overhead though, when it isn't needed.
-
-        async for key in self.store_path.store.list_dir(self.store_path.path):
-            if key in _skip_keys:
-                continue
-            try:
-                obj = await self.getitem(key)
-                yield (key, obj)
-
-                if (
-                    ((max_depth is None) or (current_depth < max_depth))
-                    and hasattr(obj.metadata, "node_type")
-                    and obj.metadata.node_type == "group"
-                ):
-                    # the assert is just for mypy to know that `obj.metadata.node_type`
-                    # implies an AsyncGroup, not an AsyncArray
-                    assert isinstance(obj, AsyncGroup)
-                    async for child_key, val in obj._members(
-                        max_depth=max_depth, current_depth=current_depth + 1
-                    ):
-                        yield f"{key}/{child_key}", val
-            except KeyError:
-                # keyerror is raised when `key` names an object (in the object storage sense),
-                # as opposed to a prefix, in the store under the prefix associated with this group
-                # in which case `key` cannot be the name of a sub-array or sub-group.
-                warnings.warn(
-                    f"Object at {key} is not recognized as a component of a Zarr hierarchy.",
-                    UserWarning,
-                    stacklevel=1,
-                )
-
-    def _members_consolidated(
-        self, max_depth: int | None, current_depth: int, prefix: str = ""
-    ) -> Generator[
-        tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
-        None,
-    ]:
-        consolidated_metadata = self.metadata.consolidated_metadata
-
-        # we kind of just want the top-level keys.
-        if consolidated_metadata is not None:
-            for key in consolidated_metadata.metadata:
-                obj = self._getitem_consolidated(
-                    self.store_path, key, prefix=self.name
-                )  # Metadata -> Group/Array
-                key = f"{prefix}/{key}".lstrip("/")
-                yield key, obj
-
-                if ((max_depth is None) or (current_depth < max_depth)) and isinstance(
-                    obj, AsyncGroup
-                ):
-                    yield from obj._members_consolidated(max_depth, current_depth + 1, prefix=key)
+        # enforce a concurrency limit by passing a semaphore to all the recursive functions
+        semaphore = asyncio.Semaphore(config.get("async.concurrency"))
+        async for member in _iter_members_deep(
+            self, max_depth=max_depth, skip_keys=skip_keys, semaphore=semaphore
+        ):
+            yield member
 
     async def keys(self) -> AsyncGenerator[str, None]:
         """Iterate over member names."""
@@ -2783,3 +2757,193 @@ def array(
                 )
             )
         )
+
+
+async def _getitem_semaphore(
+    node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None
+) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup:
+    """
+    Combine node.getitem with an optional semaphore. If the semaphore parameter is an
+    asyncio.Semaphore instance, then the getitem operation is performed inside an async context
+    manager provided by that semaphore. If the semaphore parameter is None, then getitem is invoked
+    without a context manager.
+    """
+    if semaphore is not None:
+        async with semaphore:
+            return await node.getitem(key)
+    else:
+        return await node.getitem(key)
+
+
+async def _iter_members(
+    node: AsyncGroup,
+    skip_keys: tuple[str, ...],
+    semaphore: asyncio.Semaphore | None,
+) -> AsyncGenerator[
+    tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
+]:
+    """
+    Iterate over the arrays and groups contained in a group.
+
+    Parameters
+    ----------
+    node : AsyncGroup
+        The group to traverse.
+    skip_keys : tuple[str, ...]
+        A tuple of keys to skip when iterating over the possible members of the group.
+    semaphore : asyncio.Semaphore | None
+        An optional semaphore to use for concurrency control.
+
+    Yields
+    ------
+    tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup]
+    """
+
+    # retrieve keys from storage
+    keys = [key async for key in node.store.list_dir(node.path)]
+    keys_filtered = tuple(key for key in keys if key not in skip_keys)
+
+    node_tasks = tuple(
+        asyncio.create_task(_getitem_semaphore(node, key, semaphore), name=key)
+        for key in keys_filtered
+    )
+
+    for fetched_node_coro in asyncio.as_completed(node_tasks):
+        try:
+            fetched_node = await fetched_node_coro
+        except KeyError as e:
+            # keyerror is raised when `key` names an object (in the object storage sense),
+            # as opposed to a prefix, in the store under the prefix associated with this group
+            # in which case `key` cannot be the name of a sub-array or sub-group.
+            warnings.warn(
+                f"Object at {e.args[0]} is not recognized as a component of a Zarr hierarchy.",
+                UserWarning,
+                stacklevel=1,
+            )
+            continue
+        match fetched_node:
+            case AsyncArray() | AsyncGroup():
+                yield fetched_node.basename, fetched_node
+            case _:
+                raise ValueError(f"Unexpected type: {type(fetched_node)}")
+
+
+async def _iter_members_deep(
+    group: AsyncGroup,
+    *,
+    max_depth: int | None,
+    skip_keys: tuple[str, ...],
+    semaphore: asyncio.Semaphore | None = None,
+) -> AsyncGenerator[
+    tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup], None
+]:
+    """
+    Iterate over the arrays and groups contained in a group, and optionally the
+    arrays and groups contained in those groups.
+
+    Parameters
+    ----------
+    group : AsyncGroup
+        The group to traverse.
+    max_depth : int | None
+        The maximum depth of recursion. None means no limit; 0 yields only direct members.
+    skip_keys : tuple[str, ...]
+        A tuple of keys to skip when iterating over the possible members of the group.
+    semaphore : asyncio.Semaphore | None
+        An optional semaphore to use for concurrency control.
+
+    Yields
+    ------
+    tuple[str, AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup]
+    """
+
+    to_recurse = {}
+    do_recursion = max_depth is None or max_depth > 0
+
+    if max_depth is None:
+        new_depth = None
+    else:
+        new_depth = max_depth - 1
+    async for name, node in _iter_members(group, skip_keys=skip_keys, semaphore=semaphore):
+        yield name, node
+        if isinstance(node, AsyncGroup) and do_recursion:
+            to_recurse[name] = _iter_members_deep(
+                node, max_depth=new_depth, skip_keys=skip_keys, semaphore=semaphore
+            )
+
+    for prefix, subgroup_iter in to_recurse.items():
+        async for name, node in subgroup_iter:
+            key = f"{prefix}/{name}".lstrip("/")
+            yield key, node
+
+
+def _resolve_metadata_v2(
+    blobs: tuple[str | bytes | bytearray, str | bytes | bytearray],
+) -> ArrayV2Metadata | GroupMetadata:
+    """
+    Parse a pair of JSON documents (the array or group metadata, then the attributes)
+    and convert them into the correct Zarr V2 metadata type.
+    """
+    zarr_metadata = json.loads(blobs[0])
+    attrs = json.loads(blobs[1])
+    # Delegate to _build_metadata_v2 so that attributes are stored under the "attributes" key.
+    return _build_metadata_v2(zarr_metadata, attrs)
+
+
+def _build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata:
+    """
+    Take a dict and convert it into the correct metadata type.
+    """
+    if "node_type" not in zarr_json:
+        raise KeyError("missing `node_type` key in metadata document.")
+    match zarr_json:
+        case {"node_type": "array"}:
+            return ArrayV3Metadata.from_dict(zarr_json)
+        case {"node_type": "group"}:
+            return GroupMetadata.from_dict(zarr_json)
+        case _:
+            raise ValueError("invalid value for `node_type` key in metadata document")
+
+
+def _build_metadata_v2(
+    zarr_json: dict[str, Any], attrs_json: dict[str, Any]
+) -> ArrayV2Metadata | GroupMetadata:
+    """
+    Take a dict and convert it into the correct metadata type.
+    """
+    match zarr_json:
+        case {"shape": _}:
+            return ArrayV2Metadata.from_dict(zarr_json | {"attributes": attrs_json})
+        case _:
+            return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json})
+
+
+def _build_node_v3(
+    metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath
+) -> AsyncArray[ArrayV3Metadata] | AsyncGroup:
+    """
+    Take a metadata object and return a node (AsyncArray or AsyncGroup).
+    """
+    match metadata:
+        case ArrayV3Metadata():
+            return AsyncArray(metadata, store_path=store_path)
+        case GroupMetadata():
+            return AsyncGroup(metadata, store_path=store_path)
+        case _:
+            raise ValueError(f"Unexpected metadata type: {type(metadata)}")
+
+
+def _build_node_v2(
+    metadata: ArrayV2Metadata | GroupMetadata, store_path: StorePath
+) -> AsyncArray[ArrayV2Metadata] | AsyncGroup:
+    """
+    Take a metadata object and return a node (AsyncArray or AsyncGroup).
+    """
+
+    match metadata:
+        case ArrayV2Metadata():
+            return AsyncArray(metadata, store_path=store_path)
+        case GroupMetadata():
+            return AsyncGroup(metadata, store_path=store_path)
+        case _:
+            raise ValueError(f"Unexpected metadata type: {type(metadata)}")
diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py
index 450913e9d3..45ddeef40c 100644
--- a/src/zarr/storage/_logging.py
+++ b/src/zarr/storage/_logging.py
@@ -11,7 +11,7 @@
 from zarr.storage._wrapper import WrapperStore
 
 if TYPE_CHECKING:
-    from collections.abc import AsyncIterator, Generator, Iterable
+    from collections.abc import AsyncGenerator, Generator, Iterable
 
     from zarr.abc.store import ByteRangeRequest
     from zarr.core.buffer import Buffer, BufferPrototype
@@ -205,19 +205,19 @@ async def set_partial_values(
         with self.log(keys):
             return await self._store.set_partial_values(key_start_values=key_start_values)
 
-    async def list(self) -> AsyncIterator[str]:
+    async def list(self) -> AsyncGenerator[str, None]:
         # docstring inherited
         with self.log():
             async for key in self._store.list():
                 yield key
 
-    async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
+    async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
         # docstring inherited
         with self.log(prefix):
             async for key in self._store.list_prefix(prefix=prefix):
                 yield key
 
-    async def list_dir(self, prefix: str) -> AsyncIterator[str]:
+    async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
         # docstring inherited
         with self.log(prefix):
             async for key in self._store.list_dir(prefix=prefix):
diff --git a/tests/test_group.py b/tests/test_group.py
index 19a9f9c9bb..c2a5f751f3 100644
--- a/tests/test_group.py
+++ b/tests/test_group.py
@@ -3,6 +3,7 @@
 import contextlib
 import operator
 import pickle
+import time
 import warnings
 from typing import TYPE_CHECKING, Any, Literal
 
@@ -22,6 +23,7 @@
 from zarr.core.sync import sync
 from zarr.errors import ContainsArrayError, ContainsGroupError
 from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore, make_store_path
+from zarr.testing.store import LatencyStore
 
 from .conftest import parse_store
 
@@ -1440,6 +1442,71 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None
         g1["0/0"]
 
 
+@pytest.mark.parametrize("store", ["memory"], indirect=True)
+def test_group_members_performance(store: MemoryStore) -> None:
+    """
+    Test that the execution time of Group.members is less than the number of members times the
+    latency for accessing each member.
+    """
+    get_latency = 0.1
+
+    # use the input store to create some groups
+    group_create = zarr.group(store=store)
+    num_groups = 10
+
+    # Create some groups
+    for i in range(num_groups):
+        group_create.create_group(f"group{i}")
+
+    latency_store = LatencyStore(store, get_latency=get_latency)
+    # create a group with some latency on get operations
+    group_read = zarr.group(store=latency_store)
+
+    # check how long it takes to iterate over the groups
+    # if .members is sensitive to IO latency,
+    # this should take (num_groups * get_latency) seconds
+    # otherwise, it should take only marginally more than get_latency seconds
+    start = time.perf_counter()
+    _ = group_read.members()
+    elapsed = time.perf_counter() - start
+
+    assert elapsed < (num_groups * get_latency)
+
+
+@pytest.mark.parametrize("store", ["memory"], indirect=True)
+def test_group_members_concurrency_limit(store: MemoryStore) -> None:
+    """
+    Test that the execution time of Group.members can be constrained by the async concurrency
+    configuration setting.
+    """
+    get_latency = 0.02
+
+    # use the input store to create some groups
+    group_create = zarr.group(store=store)
+    num_groups = 10
+
+    # Create some groups
+    for i in range(num_groups):
+        group_create.create_group(f"group{i}")
+
+    latency_store = LatencyStore(store, get_latency=get_latency)
+    # create a group with some latency on get operations
+    group_read = zarr.group(store=latency_store)
+
+    # check how long it takes to iterate over the groups
+    # with a concurrency limit of 1 the member lookups cannot overlap,
+    # so iterating should take more than (num_groups * get_latency) seconds
+    # even though the same lookups run concurrently under the default limit
+    from zarr.core.config import config
+
+    with config.set({"async.concurrency": 1}):
+        start = time.perf_counter()
+        _ = group_read.members()
+        elapsed = time.perf_counter() - start
+
+    assert elapsed > num_groups * get_latency
+
+
 @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"])
 def test_deprecated_compressor(store: Store) -> None:
     g = zarr.group(store=store, zarr_format=2)