Skip to content

Commit 15a9747

Browse files
authored
Resolve Mypy erorrs in v3 branch (zarr-developers#1692)
* refactor(v3): Using appropriate types * fix(v3): Typing fixes + minor code fixes * fix(v3): _sync_iter works with coroutines * docs(v3/store/core.py): clearer comment * fix(metadata.py): Use Any outside TYPE_CHECKING for Pydantic * fix(zarr/v3): correct zarr format + remove unused method * fix(v3/store/core.py): Potential suggestion on handling str store_like * refactor(zarr/v3): Add more typing * ci(.pre-commit-config.yaml): zarr v3 mypy checks turned on in pre-commit
1 parent 76c3450 commit 15a9747

11 files changed

+47
-50
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ repos:
3131
hooks:
3232
- id: mypy
3333
files: src
34-
exclude: ^src/zarr/v3
3534
args: []
3635
additional_dependencies:
3736
- types-redis

src/zarr/v3/abc/metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from typing import Dict
66
from typing_extensions import Self
77

8-
from dataclasses import fields
8+
from dataclasses import fields, dataclass
99

1010
from zarr.v3.common import JSON
1111

1212

13+
@dataclass(frozen=True)
1314
class Metadata:
1415
def to_dict(self) -> JSON:
1516
"""

src/zarr/v3/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def shape(self) -> ChunkCoords:
182182

183183
@property
184184
def size(self) -> int:
185-
return np.prod(self.metadata.shape)
185+
return np.prod(self.metadata.shape).item()
186186

187187
@property
188188
def dtype(self) -> np.dtype:

src/zarr/v3/chunk_grids.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class ChunkGrid(Metadata):
2020
@classmethod
2121
def from_dict(cls, data: Dict[str, JSON]) -> ChunkGrid:
2222
if isinstance(data, ChunkGrid):
23-
return data # type: ignore
23+
return data
2424

2525
name_parsed, _ = parse_named_configuration(data)
2626
if name_parsed == "regular":

src/zarr/v3/chunk_key_encodings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
from abc import abstractmethod
3-
from typing import TYPE_CHECKING, Dict, Literal
3+
from typing import TYPE_CHECKING, Dict, Literal, cast
44
from dataclasses import dataclass
55
from zarr.v3.abc.metadata import Metadata
66

@@ -19,7 +19,7 @@
1919
def parse_separator(data: JSON) -> SeparatorLiteral:
2020
if data not in (".", "/"):
2121
raise ValueError(f"Expected an '.' or '/' separator. Got {data} instead.")
22-
return data # type: ignore
22+
return cast(SeparatorLiteral, data)
2323

2424

2525
@dataclass(frozen=True)
@@ -35,7 +35,7 @@ def __init__(self, *, separator: SeparatorLiteral) -> None:
3535
@classmethod
3636
def from_dict(cls, data: Dict[str, JSON]) -> ChunkKeyEncoding:
3737
if isinstance(data, ChunkKeyEncoding):
38-
return data # type: ignore
38+
return data
3939

4040
name_parsed, configuration_parsed = parse_named_configuration(data)
4141
if name_parsed == "default":

src/zarr/v3/codecs/transpose.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import TYPE_CHECKING, Dict, Iterable
2+
from typing import TYPE_CHECKING, Dict, Iterable, Union, cast
33

44
from dataclasses import dataclass, replace
55

@@ -16,12 +16,12 @@
1616
from zarr.v3.codecs.registry import register_codec
1717

1818

19-
def parse_transpose_order(data: JSON) -> Tuple[int]:
19+
def parse_transpose_order(data: Union[JSON, Iterable[int]]) -> Tuple[int, ...]:
2020
if not isinstance(data, Iterable):
2121
raise TypeError(f"Expected an iterable. Got {data} instead.")
2222
if not all(isinstance(a, int) for a in data):
2323
raise TypeError(f"Expected an iterable of integers. Got {data} instead.")
24-
return tuple(data) # type: ignore[return-value]
24+
return tuple(cast(Iterable[int], data))
2525

2626

2727
@dataclass(frozen=True)
@@ -31,7 +31,7 @@ class TransposeCodec(ArrayArrayCodec):
3131
order: Tuple[int, ...]
3232

3333
def __init__(self, *, order: ChunkCoordsLike) -> None:
34-
order_parsed = parse_transpose_order(order) # type: ignore[arg-type]
34+
order_parsed = parse_transpose_order(order)
3535

3636
object.__setattr__(self, "order", order_parsed)
3737

src/zarr/v3/group.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import asyncio
55
import json
66
import logging
7-
from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, Iterator, List
7+
from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, List
88
from zarr.v3.abc.metadata import Metadata
99

1010
from zarr.v3.array import AsyncArray, Array
@@ -46,11 +46,11 @@ def to_bytes(self) -> Dict[str, bytes]:
4646
return {ZARR_JSON: json.dumps(self.to_dict()).encode()}
4747
else:
4848
return {
49-
ZGROUP_JSON: self.zarr_format,
49+
ZGROUP_JSON: json.dumps({"zarr_format": 2}).encode(),
5050
ZATTRS_JSON: json.dumps(self.attributes).encode(),
5151
}
5252

53-
def __init__(self, attributes: Dict[str, Any] = None, zarr_format: Literal[2, 3] = 3):
53+
def __init__(self, attributes: Optional[Dict[str, Any]] = None, zarr_format: Literal[2, 3] = 3):
5454
attributes_parsed = parse_attributes(attributes)
5555
zarr_format_parsed = parse_zarr_format(zarr_format)
5656

@@ -104,7 +104,7 @@ async def open(
104104
zarr_format: Literal[2, 3] = 3,
105105
) -> AsyncGroup:
106106
store_path = make_store_path(store)
107-
zarr_json_bytes = await (store_path / ZARR_JSON).get_async()
107+
zarr_json_bytes = await (store_path / ZARR_JSON).get()
108108
assert zarr_json_bytes is not None
109109

110110
# TODO: consider trying to autodiscover the zarr-format here
@@ -139,7 +139,7 @@ def from_dict(
139139
store_path: StorePath,
140140
data: Dict[str, Any],
141141
runtime_configuration: RuntimeConfiguration,
142-
) -> Group:
142+
) -> AsyncGroup:
143143
group = cls(
144144
metadata=GroupMetadata.from_dict(data),
145145
store_path=store_path,
@@ -168,10 +168,12 @@ async def getitem(
168168
zarr_json = json.loads(zarr_json_bytes)
169169
if zarr_json["node_type"] == "group":
170170
return type(self).from_dict(store_path, zarr_json, self.runtime_configuration)
171-
if zarr_json["node_type"] == "array":
171+
elif zarr_json["node_type"] == "array":
172172
return AsyncArray.from_dict(
173173
store_path, zarr_json, runtime_configuration=self.runtime_configuration
174174
)
175+
else:
176+
raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
175177
elif self.metadata.zarr_format == 2:
176178
# Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
177179
# This guarantees that we will always make at least one extra request to the store
@@ -271,7 +273,7 @@ def __repr__(self):
271273
async def nchildren(self) -> int:
272274
raise NotImplementedError
273275

274-
async def children(self) -> AsyncIterator[AsyncArray, AsyncGroup]:
276+
async def children(self) -> AsyncIterator[Union[AsyncArray, AsyncGroup]]:
275277
raise NotImplementedError
276278

277279
async def contains(self, child: str) -> bool:
@@ -381,8 +383,12 @@ async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group
381383
new_metadata = replace(self.metadata, attributes=new_attributes)
382384

383385
# Write new metadata
384-
await (self.store_path / ZARR_JSON).set_async(new_metadata.to_bytes())
385-
return replace(self, metadata=new_metadata)
386+
to_save = new_metadata.to_bytes()
387+
awaitables = [(self.store_path / key).set(value) for key, value in to_save.items()]
388+
await asyncio.gather(*awaitables)
389+
390+
async_group = replace(self._async_group, metadata=new_metadata)
391+
return replace(self, _async_group=async_group)
386392

387393
@property
388394
def metadata(self) -> GroupMetadata:
@@ -396,34 +402,38 @@ def attrs(self) -> Attributes:
396402
def info(self):
397403
return self._async_group.info
398404

405+
@property
406+
def store_path(self) -> StorePath:
407+
return self._async_group.store_path
408+
399409
def update_attributes(self, new_attributes: Dict[str, Any]):
400410
self._sync(self._async_group.update_attributes(new_attributes))
401411
return self
402412

403413
@property
404414
def nchildren(self) -> int:
405-
return self._sync(self._async_group.nchildren)
415+
return self._sync(self._async_group.nchildren())
406416

407417
@property
408-
def children(self) -> List[Array, Group]:
409-
_children = self._sync_iter(self._async_group.children)
418+
def children(self) -> List[Union[Array, Group]]:
419+
_children = self._sync_iter(self._async_group.children())
410420
return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children]
411421

412422
def __contains__(self, child) -> bool:
413423
return self._sync(self._async_group.contains(child))
414424

415-
def group_keys(self) -> Iterator[str]:
416-
return self._sync_iter(self._async_group.group_keys)
425+
def group_keys(self) -> List[str]:
426+
return self._sync_iter(self._async_group.group_keys())
417427

418428
def groups(self) -> List[Group]:
419429
# TODO: in v2 this was a generator that return key: Group
420-
return [Group(obj) for obj in self._sync_iter(self._async_group.groups)]
430+
return [Group(obj) for obj in self._sync_iter(self._async_group.groups())]
421431

422432
def array_keys(self) -> List[str]:
423-
return self._sync_iter(self._async_group.array_keys)
433+
return self._sync_iter(self._async_group.array_keys())
424434

425435
def arrays(self) -> List[Array]:
426-
return [Array(obj) for obj in self._sync_iter(self._async_group.arrays)]
436+
return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())]
427437

428438
def tree(self, expand=False, level=None) -> Any:
429439
return self._sync(self._async_group.tree(expand=expand, level=level))

src/zarr/v3/metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
from enum import Enum
3-
from typing import TYPE_CHECKING, cast, Dict, Iterable
3+
from typing import TYPE_CHECKING, cast, Dict, Iterable, Any
44
from dataclasses import dataclass, field
55
import json
66
import numpy as np
@@ -10,7 +10,7 @@
1010

1111

1212
if TYPE_CHECKING:
13-
from typing import Any, Literal, Union, List, Optional, Tuple
13+
from typing import Literal, Union, List, Optional, Tuple
1414
from zarr.v3.codecs.pipeline import CodecPipeline
1515

1616

@@ -244,7 +244,7 @@ class ArrayV2Metadata(Metadata):
244244
filters: Optional[List[Dict[str, Any]]] = None
245245
dimension_separator: Literal[".", "/"] = "."
246246
compressor: Optional[Dict[str, Any]] = None
247-
attributes: Optional[Dict[str, Any]] = field(default_factory=dict)
247+
attributes: Optional[Dict[str, Any]] = cast(Dict[str, Any], field(default_factory=dict))
248248
zarr_format: Literal[2] = field(init=False, default=2)
249249

250250
def __init__(

src/zarr/v3/store/core.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from zarr.v3.common import BytesLike
77
from zarr.v3.abc.store import Store
8+
from zarr.v3.store.local import LocalStore
89

910

1011
def _dereference_path(root: str, path: str) -> str:
@@ -24,10 +25,6 @@ def __init__(self, store: Store, path: Optional[str] = None):
2425
self.store = store
2526
self.path = path or ""
2627

27-
@classmethod
28-
def from_path(cls, pth: Path) -> StorePath:
29-
return cls(Store.from_path(pth))
30-
3128
async def get(
3229
self, byte_range: Optional[Tuple[int, Optional[int]]] = None
3330
) -> Optional[BytesLike]:
@@ -70,14 +67,6 @@ def make_store_path(store_like: StoreLike) -> StorePath:
7067
return store_like
7168
elif isinstance(store_like, Store):
7269
return StorePath(store_like)
73-
# elif isinstance(store_like, Path):
74-
# return StorePath(Store.from_path(store_like))
7570
elif isinstance(store_like, str):
76-
try:
77-
from upath import UPath
78-
79-
return StorePath(Store.from_path(UPath(store_like)))
80-
except ImportError as e:
81-
raise e
82-
# return StorePath(LocalStore(Path(store_like)))
71+
return StorePath(LocalStore(Path(store_like)))
8372
raise TypeError

src/zarr/v3/store/local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ async def list_prefix(self, prefix: str) -> List[str]:
146146
"""
147147

148148
def _list_prefix(root: Path, prefix: str) -> List[str]:
149-
files = [p for p in (root / prefix).rglob("*") if p.is_file()]
149+
files = [str(p) for p in (root / prefix).rglob("*") if p.is_file()]
150150
return files
151151

152152
return await to_thread(_list_prefix, self.root, prefix)

src/zarr/v3/sync.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import (
66
Any,
77
AsyncIterator,
8-
Callable,
98
Coroutine,
109
List,
1110
Optional,
@@ -112,11 +111,10 @@ def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
112111
# this should allow us to better type the sync wrapper
113112
return sync(coroutine, loop=self._sync_configuration.asyncio_loop)
114113

115-
def _sync_iter(
116-
self, func: Callable[P, AsyncIterator[T]], *args: P.args, **kwargs: P.kwargs
117-
) -> List[T]:
114+
def _sync_iter(self, coroutine: Coroutine[Any, Any, AsyncIterator[T]]) -> List[T]:
118115
async def iter_to_list() -> List[T]:
119116
# TODO: replace with generators so we don't materialize the entire iterator at once
120-
return [item async for item in func(*args, **kwargs)]
117+
async_iterator = await coroutine
118+
return [item async for item in async_iterator]
121119

122120
return self._sync(iter_to_list())

0 commit comments

Comments
 (0)