From 77115ffa2c5205e165814061f350bdf28fa78eae Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 16 Oct 2024 22:06:51 -0700 Subject: [PATCH 1/6] feature(group): add group setitem api --- src/zarr/api/asynchronous.py | 8 ++++++-- src/zarr/core/group.py | 10 ++++++++-- src/zarr/storage/zip.py | 5 ++++- tests/v3/test_group.py | 25 +++++++++++++++++++++++-- 4 files changed, 41 insertions(+), 7 deletions(-) diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index e500562c4c..a8144dfe2e 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -396,12 +396,16 @@ async def save_array( mode = kwargs.pop("mode", None) store_path = await make_store_path(store, path=path, mode=mode, storage_options=storage_options) + if np.isscalar(arr): + arr = np.array(arr) + shape = arr.shape + chunks = getattr(arr, "chunks", shape) # for array-likes with chunks attribute new = await AsyncArray.create( store_path, zarr_format=zarr_format, - shape=arr.shape, + shape=shape, dtype=arr.dtype, - chunks=arr.shape, + chunks=chunks, **kwargs, ) await new.setitem(slice(None), arr) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index e25f70eef6..8ad54b79b2 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -596,6 +596,12 @@ def from_dict( store_path=store_path, ) + async def setitem(self, key: str, value: Any) -> None: + path = self.store_path / key + await async_api.save_array( + store=path, arr=value, zarr_format=self.metadata.zarr_format, exists_ok=True + ) + async def getitem( self, key: str, @@ -1369,8 +1375,8 @@ def __len__(self) -> int: return self.nmembers() def __setitem__(self, key: str, value: Any) -> None: - """__setitem__ is not supported in v3""" - raise NotImplementedError + """Create a new array""" + self._sync(self._async_group.setitem(key, value)) def __repr__(self) -> str: return f"" diff --git a/src/zarr/storage/zip.py b/src/zarr/storage/zip.py index c9cb579586..c235c723f0 100644 --- a/src/zarr/storage/zip.py +++ b/src/zarr/storage/zip.py @@ -219,7 +219,10 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: async def delete(self, key: str) -> None: # docstring inherited - raise NotImplementedError + # we choose to only raise NotImplementedError here if the key exists + # this allows the array/group APIs to avoid the overhead of existence checks + if await self.exists(key): + raise NotImplementedError async def exists(self, key: str) -> bool: # docstring inherited diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 4f062d5316..a1f0142124 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -391,8 +391,29 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: Test the `Group.__setitem__` method. """ group = Group.from_store(store, zarr_format=zarr_format) - with pytest.raises(NotImplementedError): - group["key"] = 10 + arr = np.ones((2, 4)) + group["key"] = arr + assert group["key"].shape == (2, 4) + np.testing.assert_array_equal(group["key"][:], arr) + + if store.supports_deletes: + key = "key" + else: + # overwriting with another array requires deletes + # for stores that don't support this, we just use a new key + key = "key2" + + # overwrite with another array + arr = np.zeros((3, 5)) + group[key] = arr + assert group[key].shape == (3, 5) + np.testing.assert_array_equal(group[key], arr) + + # overwrite with a scalar + # separate bug! + # group["key"] = 1.5 + # assert group["key"].shape == () + # assert group["key"][:] == 1 def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None: From ed67cc327e77d79aa7700381142a59c8d67952de Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Thu, 17 Oct 2024 06:46:28 -0700 Subject: [PATCH 2/6] arrays proxy --- src/zarr/core/group.py | 56 +++++++++++++++++++++++++++++++++++++----- tests/v3/test_group.py | 31 +++++++++++++++++++++-- 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 8ad54b79b2..f897eb696b 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -294,6 +294,36 @@ def flatten( return metadata +class ArraysProxy: + """ + Proxy for arrays in a group. + + Used to implement the `Group.arrays` property + """ + + def __init__(self, group: Group) -> None: + self._group = group + + def __getitem__(self, key: str) -> Array: + obj = self._group[key] + if isinstance(obj, Array): + return obj + raise KeyError(key) + + def __setitem__(self, key: str, value: npt.ArrayLike) -> None: + """ + Set an array in the group. + """ + self._group._sync(self._group._async_group.set_array(key, value)) + + def __iter__(self) -> Generator[tuple[str, Array], None]: + for name, async_array in self._group._sync_iter(self._group._async_group.arrays()): + yield name, Array(async_array) + + def __call__(self) -> Generator[tuple[str, Array], None]: + return iter(self) + + @dataclass(frozen=True) class GroupMetadata(Metadata): attributes: dict[str, Any] = field(default_factory=dict) @@ -596,7 +626,16 @@ def from_dict( store_path=store_path, ) - async def setitem(self, key: str, value: Any) -> None: + async def set_array(self, key: str, value: Any) -> None: + """fastpath for creating a new array + + Parameters + ---------- + key : str + Array name + value : array-like + Array data + """ path = self.store_path / key await async_api.save_array( store=path, arr=value, zarr_format=self.metadata.zarr_format, exists_ok=True @@ -1374,9 +1413,14 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return self.nmembers() + @deprecated("Use Group.arrays setter instead.") def __setitem__(self, key: str, value: Any) -> None: - """Create a new array""" - self._sync(self._async_group.setitem(key, value)) + """Create a new array + + .. deprecated:: 3.0.0 + Use Group.arrays.setter instead. + """ + self._sync(self._async_group.set_array(key, value)) def __repr__(self) -> str: return f"" @@ -1473,9 +1517,9 @@ def group_values(self) -> Generator[Group, None]: for _, group in self.groups(): yield group - def arrays(self) -> Generator[tuple[str, Array], None]: - for name, async_array in self._sync_iter(self._async_group.arrays()): - yield name, Array(async_array) + @property + def arrays(self) -> ArraysProxy: + return ArraysProxy(self) def array_keys(self) -> Generator[str, None]: for name, _ in self.arrays(): diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index a1f0142124..19dff6abf0 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -392,7 +392,8 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: """ group = Group.from_store(store, zarr_format=zarr_format) arr = np.ones((2, 4)) - group["key"] = arr + with pytest.warns(DeprecationWarning): + group["key"] = arr assert group["key"].shape == (2, 4) np.testing.assert_array_equal(group["key"][:], arr) @@ -405,7 +406,8 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: # overwrite with another array arr = np.zeros((3, 5)) - group[key] = arr + with pytest.warns(DeprecationWarning): + group[key] = arr assert group[key].shape == (3, 5) np.testing.assert_array_equal(group[key], arr) @@ -416,6 +418,31 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: # assert group["key"][:] == 1 +def test_group_arrays_setter(store: Store, zarr_format: ZarrFormat) -> None: + """ + Test the `Group.__setitem__` method. + """ + group = Group.from_store(store, zarr_format=zarr_format) + arr = np.ones((2, 4)) + group.arrays["key"] = arr + assert group["key"].shape == (2, 4) + np.testing.assert_array_equal(group["key"][:], arr) + + if store.supports_deletes: + key = "key" + else: + # overwriting with another array requires deletes + # for stores that don't support this, we just use a new key + key = "key2" + + # overwrite with another array + arr = np.zeros((3, 5)) + with pytest.warns(DeprecationWarning): + group[key] = arr + assert group[key].shape == (3, 5) + np.testing.assert_array_equal(group[key], arr) + + def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None: """ Test the `Group.__contains__` method From 4ab26981524408fb4fba053ef32322d8d9400109 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 23 Oct 2024 15:02:25 -0700 Subject: [PATCH 3/6] rollback to simple version --- src/zarr/api/asynchronous.py | 2 +- tests/test_group.py | 31 ------------------------------- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 8e99727de1..cd8c3543ca 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -399,7 +399,7 @@ async def save_array( if np.isscalar(arr): arr = np.array(arr) shape = arr.shape - chunks = getattr(arr, "chunks", shape) # for array-likes with chunks attribute + chunks = getattr(arr, "chunks", None) # for array-likes with chunks attribute new = await AsyncArray.create( store_path, zarr_format=zarr_format, diff --git a/tests/test_group.py b/tests/test_group.py index d69f8540a5..f79adf5928 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -447,37 +447,6 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: assert group[key].shape == (3, 5) np.testing.assert_array_equal(group[key], arr) - # overwrite with a scalar - # separate bug! - # group["key"] = 1.5 - # assert group["key"].shape == () - # assert group["key"][:] == 1 - - -def test_group_arrays_setter(store: Store, zarr_format: ZarrFormat) -> None: - """ - Test the `Group.__setitem__` method. - """ - group = Group.from_store(store, zarr_format=zarr_format) - arr = np.ones((2, 4)) - group.arrays["key"] = arr - assert group["key"].shape == (2, 4) - np.testing.assert_array_equal(group["key"][:], arr) - - if store.supports_deletes: - key = "key" - else: - # overwriting with another array requires deletes - # for stores that don't support this, we just use a new key - key = "key2" - - # overwrite with another array - arr = np.zeros((3, 5)) - with pytest.warns(DeprecationWarning): - group[key] = arr - assert group[key].shape == (3, 5) - np.testing.assert_array_equal(group[key], arr) - def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None: """ From 056398ea90ee202b60ccd31982c21c198bb31bbf Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 23 Oct 2024 15:32:41 -0700 Subject: [PATCH 4/6] rollback deprecation --- src/zarr/api/asynchronous.py | 6 ++--- src/zarr/core/group.py | 50 ++++++++---------------------------- tests/test_group.py | 10 ++++---- 3 files changed, 18 insertions(+), 48 deletions(-) diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index cd8c3543ca..fc15c75c9e 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -244,7 +244,7 @@ async def load( obj = await open(store=store, path=path, zarr_format=zarr_format) if isinstance(obj, AsyncArray): - return await obj.getitem(slice(None)) + return await obj.getitem(...) else: raise NotImplementedError("loading groups not yet supported") @@ -408,7 +408,7 @@ async def save_array( chunks=chunks, **kwargs, ) - await new.setitem(slice(None), arr) + await new.setitem(..., arr) async def save_group( @@ -520,7 +520,7 @@ async def array( z = await create(**kwargs) # fill with data - await z.setitem(slice(None), data) + await z.setitem(..., data) return z diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index e93d358e55..46f37700eb 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -295,36 +295,6 @@ def flatten( return metadata -class ArraysProxy: - """ - Proxy for arrays in a group. - - Used to implement the `Group.arrays` property - """ - - def __init__(self, group: Group) -> None: - self._group = group - - def __getitem__(self, key: str) -> Array: - obj = self._group[key] - if isinstance(obj, Array): - return obj - raise KeyError(key) - - def __setitem__(self, key: str, value: npt.ArrayLike) -> None: - """ - Set an array in the group. - """ - self._group._sync(self._group._async_group.set_array(key, value)) - - def __iter__(self) -> Generator[tuple[str, Array], None]: - for name, async_array in self._group._sync_iter(self._group._async_group.arrays()): - yield name, Array(async_array) - - def __call__(self) -> Generator[tuple[str, Array], None]: - return iter(self) - - @dataclass(frozen=True) class GroupMetadata(Metadata): attributes: dict[str, Any] = field(default_factory=dict) @@ -630,8 +600,10 @@ def from_dict( store_path=store_path, ) - async def set_array(self, key: str, value: Any) -> None: - """fastpath for creating a new array + async def setitem(self, key: str, value: Any) -> None: + """Fastpath for creating a new array + + New arrays will be created with default array settings for the array type. Parameters ---------- @@ -1438,14 +1410,12 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return self.nmembers() - @deprecated("Use Group.arrays setter instead.") def __setitem__(self, key: str, value: Any) -> None: - """Create a new array + """Fastpath for creating a new array. - .. deprecated:: 3.0.0 - Use Group.arrays.setter instead. + New arrays will be created using default settings for the array type. """ - self._sync(self._async_group.set_array(key, value)) + self._sync(self._async_group.setitem(key, value)) def __repr__(self) -> str: return f"" @@ -1542,9 +1512,9 @@ def group_values(self) -> Generator[Group, None]: for _, group in self.groups(): yield group - @property - def arrays(self) -> ArraysProxy: - return ArraysProxy(self) + def arrays(self) -> Generator[tuple[str, Array], None]: + for name, async_array in self._sync_iter(self._async_group.arrays()): + yield name, Array(async_array) def array_keys(self) -> Generator[str, None]: for name, _ in self.arrays(): diff --git a/tests/test_group.py b/tests/test_group.py index f79adf5928..409eab838d 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -424,12 +424,12 @@ def test_group_len(store: Store, zarr_format: ZarrFormat) -> None: def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: """ - Test the `Group.__setitem__` method. + Test the `Group.__setitem__` method. (Deprecated) """ group = Group.from_store(store, zarr_format=zarr_format) arr = np.ones((2, 4)) - with pytest.warns(DeprecationWarning): - group["key"] = arr + group["key"] = arr + assert list(group.array_keys()) == ["key"] assert group["key"].shape == (2, 4) np.testing.assert_array_equal(group["key"][:], arr) @@ -442,8 +442,8 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: # overwrite with another array arr = np.zeros((3, 5)) - with pytest.warns(DeprecationWarning): - group[key] = arr + group[key] = arr + assert key in list(group.array_keys()) assert group[key].shape == (3, 5) np.testing.assert_array_equal(group[key], arr) From 804cf12a37373e88766623cd32da002b58f31a24 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 23 Oct 2024 15:34:29 -0700 Subject: [PATCH 5/6] rollback ... --- src/zarr/api/asynchronous.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index fc15c75c9e..cd8c3543ca 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -244,7 +244,7 @@ async def load( obj = await open(store=store, path=path, zarr_format=zarr_format) if isinstance(obj, AsyncArray): - return await obj.getitem(...) + return await obj.getitem(slice(None)) else: raise NotImplementedError("loading groups not yet supported") @@ -408,7 +408,7 @@ async def save_array( chunks=chunks, **kwargs, ) - await new.setitem(..., arr) + await new.setitem(slice(None), arr) async def save_group( @@ -520,7 +520,7 @@ async def array( z = await create(**kwargs) # fill with data - await z.setitem(..., data) + await z.setitem(slice(None), data) return z From 2dafa7efa73fce28c5b7f773e98a3e20c417e420 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Thu, 24 Oct 2024 15:43:05 -0700 Subject: [PATCH 6/6] Update tests/test_group.py --- tests/test_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_group.py b/tests/test_group.py index 409eab838d..bcdc6ff0da 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -424,7 +424,7 @@ def test_group_len(store: Store, zarr_format: ZarrFormat) -> None: def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: """ - Test the `Group.__setitem__` method. (Deprecated) + Test the `Group.__setitem__` method. """ group = Group.from_store(store, zarr_format=zarr_format) arr = np.ones((2, 4))