Skip to content

add enter and exit methods to groups. #2691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from collections import defaultdict
from dataclasses import asdict, dataclass, field, fields, replace
from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload
from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -55,6 +55,7 @@

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Iterable, Iterator
from types import TracebackType
from typing import Any

from zarr.core.array_spec import ArrayConfig, ArrayConfigLike
Expand Down Expand Up @@ -1752,6 +1753,14 @@ def open(
obj = sync(AsyncGroup.open(store, zarr_format=zarr_format))
return cls(obj)

def __enter__(self) -> Self:
return self

def __exit__(
self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None
) -> None:
self.store.close()

def __getitem__(self, path: str) -> Array | Group:
"""Obtain a group member.

Expand Down
22 changes: 22 additions & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,28 @@ def test_group_len(store: Store, zarr_format: ZarrFormat) -> None:
assert len(group) == 0


def test_group_with_context_manager(store: Store, zarr_format: ZarrFormat, overwrite: bool) -> None:
spath = StorePath(store)

# attempt to open a group that does not exist.
with pytest.raises(FileNotFoundError):
with Group.open(store) as store:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with Group.open(store) as store:
with Group.open(store) as g:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated in 3aee7b0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, instead of using Group.open and Group.from_store, can we use zarr.open_group and zarr.create_group? These are the public facing API methods we want most folks to be using.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated in 3aee7b0

pass

attrs = {"path": "foo"}

with Group.from_store(
store, attributes=attrs, zarr_format=zarr_format, overwrite=overwrite
) as group:
assert store._is_open
assert group.attrs == attrs
assert group.metadata.zarr_format == zarr_format
assert group.store_path == spath

# Check if store was closed after exit.
assert not store._is_open


def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the `Group.__setitem__` method.
Expand Down
Loading