Skip to content

Commit de94764

Browse files
K-Meech and d-v-b authored
Prevent creation of arrays/groups under a parent array (#3407)
* restrict arrays as parents of other arrays
* use common save_metadata function for groups and arrays
* add test for creation under a parent array
* fix failing tests
* fix failing doctest
* remove dependency on AsyncGroup
* document changes

---------

Co-authored-by: Davis Bennett <[email protected]>
1 parent 62551c7 commit de94764

File tree

8 files changed

+151
-88
lines changed

8 files changed

+151
-88
lines changed

changes/2582.bugfix.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Prevents creation of groups (.create_group) or arrays (.create_array) as children
2+
of an existing array.

docs/user-guide/arrays.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,11 +567,12 @@ Any combination of integer and slice can be used for block indexing::
567567
>>>
568568
>>> root = zarr.create_group('data/example-19.zarr')
569569
>>> foo = root.create_array(name='foo', shape=(1000, 100), chunks=(10, 10), dtype='float32')
570-
>>> bar = root.create_array(name='foo/bar', shape=(100,), dtype='int32')
570+
>>> bar = root.create_array(name='bar', shape=(100,), dtype='int32')
571571
>>> foo[:, :] = np.random.random((1000, 100))
572572
>>> bar[:] = np.arange(100)
573573
>>> root.tree()
574574
/
575+
├── bar (100,) int32
575576
└── foo (1000, 100) float32
576577
<BLANKLINE>
577578

src/zarr/core/array.py

Lines changed: 3 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import zarr
2626
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
2727
from zarr.abc.numcodec import Numcodec, _is_numcodec
28-
from zarr.abc.store import Store, set_or_delete
2928
from zarr.codecs._v2 import V2Codec
3029
from zarr.codecs.bytes import BytesCodec
3130
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
@@ -110,6 +109,7 @@
110109
ArrayV3MetadataDict,
111110
T_ArrayMetadata,
112111
)
112+
from zarr.core.metadata.io import save_metadata
113113
from zarr.core.metadata.v2 import (
114114
CompressorLikev2,
115115
get_object_codec_id,
@@ -140,9 +140,9 @@
140140
import numpy.typing as npt
141141

142142
from zarr.abc.codec import CodecPipeline
143+
from zarr.abc.store import Store
143144
from zarr.codecs.sharding import ShardingCodecIndexLocation
144145
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar
145-
from zarr.core.group import AsyncGroup
146146
from zarr.storage import StoreLike
147147

148148

@@ -1639,24 +1639,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
16391639
"""
16401640
Asynchronously save the array metadata.
16411641
"""
1642-
to_save = metadata.to_buffer_dict(cpu_buffer_prototype)
1643-
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
1644-
1645-
if ensure_parents:
1646-
# To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups.
1647-
parents = _build_parents(self)
1648-
1649-
for parent in parents:
1650-
awaitables.extend(
1651-
[
1652-
(parent.store_path / key).set_if_not_exists(value)
1653-
for key, value in parent.metadata.to_buffer_dict(
1654-
cpu_buffer_prototype
1655-
).items()
1656-
]
1657-
)
1658-
1659-
await gather(*awaitables)
1642+
await save_metadata(self.store_path, metadata, ensure_parents=ensure_parents)
16601643

16611644
async def _set_selection(
16621645
self,
@@ -4121,37 +4104,6 @@ async def _shards_initialized(
41214104
)
41224105

41234106

4124-
def _build_parents(
4125-
node: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup,
4126-
) -> list[AsyncGroup]:
4127-
from zarr.core.group import AsyncGroup, GroupMetadata
4128-
4129-
store = node.store_path.store
4130-
path = node.store_path.path
4131-
if not path:
4132-
return []
4133-
4134-
required_parts = path.split("/")[:-1]
4135-
parents = [
4136-
# the root group
4137-
AsyncGroup(
4138-
metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
4139-
store_path=StorePath(store=store, path=""),
4140-
)
4141-
]
4142-
4143-
for i, part in enumerate(required_parts):
4144-
p = "/".join(required_parts[:i] + [part])
4145-
parents.append(
4146-
AsyncGroup(
4147-
metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
4148-
store_path=StorePath(store=store, path=p),
4149-
)
4150-
)
4151-
4152-
return parents
4153-
4154-
41554107
FiltersLike: TypeAlias = (
41564108
Iterable[dict[str, JSON] | ArrayArrayCodec | Numcodec]
41574109
| ArrayArrayCodec

src/zarr/core/group.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
FiltersLike,
2929
SerializerLike,
3030
ShardsLike,
31-
_build_parents,
3231
_parse_deprecated_compressor,
3332
create_array,
3433
)
@@ -49,6 +48,7 @@
4948
)
5049
from zarr.core.config import config
5150
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
51+
from zarr.core.metadata.io import save_metadata
5252
from zarr.core.sync import SyncMixin, sync
5353
from zarr.errors import (
5454
ContainsArrayError,
@@ -818,22 +818,7 @@ async def get(
818818
return default
819819

820820
async def _save_metadata(self, ensure_parents: bool = False) -> None:
821-
to_save = self.metadata.to_buffer_dict(default_buffer_prototype())
822-
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
823-
824-
if ensure_parents:
825-
parents = _build_parents(self)
826-
for parent in parents:
827-
awaitables.extend(
828-
[
829-
(parent.store_path / key).set_if_not_exists(value)
830-
for key, value in parent.metadata.to_buffer_dict(
831-
default_buffer_prototype()
832-
).items()
833-
]
834-
)
835-
836-
await asyncio.gather(*awaitables)
821+
await save_metadata(self.store_path, self.metadata, ensure_parents=ensure_parents)
837822

838823
@property
839824
def path(self) -> str:

src/zarr/core/metadata/io.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import TYPE_CHECKING
5+
6+
from zarr.abc.store import set_or_delete
7+
from zarr.core.buffer.core import default_buffer_prototype
8+
from zarr.errors import ContainsArrayError
9+
from zarr.storage._common import StorePath, ensure_no_existing_node
10+
11+
if TYPE_CHECKING:
12+
from zarr.core.common import ZarrFormat
13+
from zarr.core.group import GroupMetadata
14+
from zarr.core.metadata import ArrayMetadata
15+
16+
17+
def _build_parents(store_path: StorePath, zarr_format: ZarrFormat) -> dict[str, GroupMetadata]:
    """Map every ancestor path of ``store_path`` to fresh group metadata.

    Parameters
    ----------
    store_path : StorePath
        Location of the node whose ancestors are wanted. Only ``store_path.path``
        is read.
    zarr_format : ZarrFormat
        Zarr format passed to each ``GroupMetadata`` that is constructed.

    Returns
    -------
    dict[str, GroupMetadata]
        Empty when ``store_path.path`` is empty. Otherwise the key ``""`` (the
        root group) followed by each cumulative ``/``-joined prefix of the path,
        excluding the final component.
    """
    # Imported here rather than at module level, as in the original code
    # (presumably to avoid an import cycle - TODO confirm).
    from zarr.core.group import GroupMetadata

    node_path = store_path.path
    if not node_path:
        return {}

    # Drop the last component: it names the node itself, not an ancestor.
    ancestors = node_path.split("/")[:-1]

    # The root group first, then "a", "a/b", ... in increasing depth.
    prefixes = [""] + ["/".join(ancestors[: depth + 1]) for depth in range(len(ancestors))]

    return {prefix: GroupMetadata(zarr_format=zarr_format) for prefix in prefixes}
34+
35+
36+
async def save_metadata(
    store_path: StorePath, metadata: ArrayMetadata | GroupMetadata, ensure_parents: bool = False
) -> None:
    """Asynchronously save the array or group metadata.

    Parameters
    ----------
    store_path : StorePath
        Location to save metadata.
    metadata : ArrayMetadata | GroupMetadata
        Metadata to save.
    ensure_parents : bool, optional
        Create any missing parent groups, and check no existing parents are arrays.

    Raises
    ------
    ValueError
        If ``ensure_parents`` is True and an array already exists at any parent
        location of ``store_path``. Nothing is written in that case.
    """
    prototype = default_buffer_prototype()
    to_save = metadata.to_buffer_dict(prototype)

    # Maps each parent path to (its StorePath, the group metadata to write there).
    parent_nodes: dict[str, tuple[StorePath, GroupMetadata]] = {}

    if ensure_parents:
        # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups.
        parent_nodes = {
            parent_path: (StorePath(store_path.store, parent_path), parent_metadata)
            for parent_path, parent_metadata in _build_parents(
                store_path, metadata.zarr_format
            ).items()
        }

        # Checks for parent arrays must happen first, before any metadata is modified.
        # Error if an array already exists at any parent location. Only groups can have child nodes.
        try:
            await asyncio.gather(
                *(
                    ensure_no_existing_node(
                        parent_store_path, parent_metadata.zarr_format, node_type="array"
                    )
                    for parent_store_path, parent_metadata in parent_nodes.values()
                )
            )
        except ContainsArrayError as e:
            raise ValueError(
                f"A parent of {store_path} is an array - only groups may have child nodes."
            ) from e

    # The write coroutines are only created once every check has passed. If a
    # check fails (with any exception type) there is nothing left un-awaited,
    # so no "coroutine was never awaited" RuntimeWarning and no manual cleanup.
    set_awaitables = [set_or_delete(store_path / key, value) for key, value in to_save.items()]

    for parent_store_path, parent_metadata in parent_nodes.values():
        # set_if_not_exists: never overwrite metadata of a group that already exists.
        set_awaitables.extend(
            (parent_store_path / key).set_if_not_exists(value)
            for key, value in parent_metadata.to_buffer_dict(prototype).items()
        )

    await asyncio.gather(*set_awaitables)

src/zarr/storage/_common.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,11 @@ def _is_fsspec_uri(uri: str) -> bool:
435435
return "://" in uri or ("::" in uri and "local://" not in uri)
436436

437437

438-
async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat) -> None:
438+
async def ensure_no_existing_node(
439+
store_path: StorePath,
440+
zarr_format: ZarrFormat,
441+
node_type: Literal["array", "group"] | None = None,
442+
) -> None:
439443
"""
440444
Check if a store_path is safe for array / group creation.
441445
Returns `None` or raises an exception.
@@ -446,6 +450,8 @@ async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat
446450
The storage location to check.
447451
zarr_format : ZarrFormat
448452
The Zarr format to check.
453+
node_type : str | None, optional
454+
Raise an error if an "array", or "group" exists. By default (when None), raises an error for either.
449455
450456
Raises
451457
------
@@ -456,16 +462,23 @@ async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat
456462
elif zarr_format == 3:
457463
extant_node = await _contains_node_v3(store_path)
458464

459-
if extant_node == "array":
460-
msg = f"An array exists in store {store_path.store!r} at path {store_path.path!r}."
461-
raise ContainsArrayError(msg)
462-
elif extant_node == "group":
463-
msg = f"An array exists in store {store_path.store!r} at path {store_path.path!r}."
464-
raise ContainsGroupError(msg)
465-
elif extant_node == "nothing":
466-
return
467-
msg = f"Invalid value for extant_node: {extant_node}" # type: ignore[unreachable]
468-
raise ValueError(msg)
465+
match extant_node:
466+
case "array":
467+
if node_type != "group":
468+
msg = f"An array exists in store {store_path.store!r} at path {store_path.path!r}."
469+
raise ContainsArrayError(msg)
470+
471+
case "group":
472+
if node_type != "array":
473+
msg = f"A group exists in store {store_path.store!r} at path {store_path.path!r}."
474+
raise ContainsGroupError(msg)
475+
476+
case "nothing":
477+
return
478+
479+
case _:
480+
msg = f"Invalid value for extant_node: {extant_node}" # type: ignore[unreachable]
481+
raise ValueError(msg)
469482

470483

471484
async def _contains_node_v3(store_path: StorePath) -> Literal["array", "group", "nothing"]:

tests/test_group.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,24 @@ def test_group_create_array(
761761
assert np.array_equal(array[:], data)
762762

763763

764+
@pytest.mark.parametrize("method", ["create_array", "create_group"])
def test_create_with_parent_array(store: Store, zarr_format: ZarrFormat, method: str):
    """Creating an array or a group beneath an existing array must raise ValueError."""

    # A root group holding a single array, "arr_1", which then acts as the illegal parent.
    root = Group.from_store(store, zarr_format=zarr_format)
    root.create_array(name="arr_1", shape=(10, 10), dtype="uint8")

    expected = r"A parent of .* is an array - only groups may have child nodes."

    # Both creation methods are expected to fail the same way, so they share one context.
    with pytest.raises(ValueError, match=expected):
        if method == "create_array":
            root.create_array("arr_1/group_1/group_2/arr_2", shape=(10, 10), dtype="uint8")
        else:
            root.create_group("arr_1/group_1/group_2/group_3")
780+
781+
764782
LikeMethodName = Literal["zeros_like", "ones_like", "empty_like", "full_like"]
765783

766784

tests/test_indexing.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2110,14 +2110,14 @@ async def test_async_oindex(self, store, indexer, expected):
21102110

21112111
@pytest.mark.asyncio
21122112
async def test_async_oindex_with_zarr_array(self, store):
2113-
z1 = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
2113+
group = zarr.create_group(store=store, zarr_format=3)
2114+
2115+
z1 = group.create_array(name="z1", shape=(2, 2), chunks=(1, 1), dtype="i8")
21142116
z1[...] = np.array([[1, 2], [3, 4]])
21152117
async_zarr = z1._async_array
21162118

21172119
# create boolean zarr array to index with
2118-
z2 = zarr.create_array(
2119-
store=store, name="z2", shape=(2,), chunks=(1,), zarr_format=3, dtype="?"
2120-
)
2120+
z2 = group.create_array(name="z2", shape=(2,), chunks=(1,), dtype="?")
21212121
z2[...] = np.array([True, False])
21222122

21232123
result = await async_zarr.oindex.getitem(z2)
@@ -2143,14 +2143,14 @@ async def test_async_vindex(self, store, indexer, expected):
21432143

21442144
@pytest.mark.asyncio
21452145
async def test_async_vindex_with_zarr_array(self, store):
2146-
z1 = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
2146+
group = zarr.create_group(store=store, zarr_format=3)
2147+
2148+
z1 = group.create_array(name="z1", shape=(2, 2), chunks=(1, 1), dtype="i8")
21472149
z1[...] = np.array([[1, 2], [3, 4]])
21482150
async_zarr = z1._async_array
21492151

21502152
# create boolean zarr array to index with
2151-
z2 = zarr.create_array(
2152-
store=store, name="z2", shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="?"
2153-
)
2153+
z2 = group.create_array(name="z2", shape=(2, 2), chunks=(1, 1), dtype="?")
21542154
z2[...] = np.array([[False, True], [False, True]])
21552155

21562156
result = await async_zarr.vindex.getitem(z2)

0 commit comments

Comments
 (0)