Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add enter and exit methods to groups. #2691

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Iterator, Mapping
from dataclasses import asdict, dataclass, field, fields, replace
from itertools import accumulate
from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload
from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -63,7 +63,9 @@
Coroutine,
Generator,
Iterable,
Iterator,
)
from types import TracebackType
from typing import Any

from zarr.core.array_spec import ArrayConfig, ArrayConfigLike
Expand Down Expand Up @@ -1812,6 +1814,14 @@
obj = sync(AsyncGroup.open(store, zarr_format=zarr_format))
return cls(obj)

def __enter__(self) -> Self:
return self

Check warning on line 1818 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1818

Added line #L1818 was not covered by tests

def __exit__(
self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None
) -> None:
self.store.close()

Check warning on line 1823 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1823

Added line #L1823 was not covered by tests

def __getitem__(self, path: str) -> Array | Group:
"""Obtain a group member.

Expand Down
22 changes: 22 additions & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,28 @@ def test_group_len(store: Store, zarr_format: ZarrFormat) -> None:
assert len(group) == 0


def test_group_with_context_manager(store: Store, zarr_format: ZarrFormat, overwrite: bool) -> None:
spath = StorePath(store)

# attempt to open a group that does not exist.
with pytest.raises(FileNotFoundError):
with zarr.open_group(store, mode="r") as group:
pass

attrs = {"path": "foo"}

with zarr.create_group(
store, attributes=attrs, zarr_format=zarr_format, overwrite=overwrite
) as group:
assert store._is_open
assert group.attrs == attrs
assert group.metadata.zarr_format == zarr_format
assert group.store_path == spath

# Check if store was closed after exit.
assert not store._is_open


def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the `Group.__setitem__` method.
Expand Down