diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 8e7f7f3474..5f8d6e3186 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -9,7 +9,7 @@ from collections.abc import Iterator, Mapping from dataclasses import asdict, dataclass, field, fields, replace from itertools import accumulate -from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload +from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload import numpy as np import numpy.typing as npt @@ -63,7 +63,9 @@ Coroutine, Generator, Iterable, + Iterator, ) + from types import TracebackType from typing import Any from zarr.core.array_spec import ArrayConfig, ArrayConfigLike @@ -1812,6 +1814,14 @@ def open( obj = sync(AsyncGroup.open(store, zarr_format=zarr_format)) return cls(obj) + def __enter__(self) -> Self: + return self + + def __exit__( + self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None + ) -> None: + self.store.close() + def __getitem__(self, path: str) -> Array | Group: """Obtain a group member. diff --git a/tests/test_group.py b/tests/test_group.py index 1e4f31b5d6..c598c46d38 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -450,6 +450,28 @@ def test_group_len(store: Store, zarr_format: ZarrFormat) -> None: assert len(group) == 0 +def test_group_with_context_manager(store: Store, zarr_format: ZarrFormat, overwrite: bool) -> None: + spath = StorePath(store) + + # attempt to open a group that does not exist. + with pytest.raises(FileNotFoundError): + with zarr.open_group(store, mode="r") as group: + pass + + attrs = {"path": "foo"} + + with zarr.create_group( + store, attributes=attrs, zarr_format=zarr_format, overwrite=overwrite + ) as group: + assert store._is_open + assert group.attrs == attrs + assert group.metadata.zarr_format == zarr_format + assert group.store_path == spath + + # Check if store was closed after exit. + assert not store._is_open + + def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: """ Test the `Group.__setitem__` method.