
Commit a3d29f8

shoyer and Xarray-Beam authors authored and committed

Allow using ... as a key in chunk specifications.

This change enables specifying a default chunk size for all dimensions not explicitly listed in the `chunks` mapping by using `...` as a key. For example, `{'x': 10, ...: 20}` will chunk dimension 'x' into sizes of 10 and all other dimensions into sizes of 20.

PiperOrigin-RevId: 814430585

1 parent 69f05a7 commit a3d29f8
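
To illustrate the new behavior, a minimal sketch (the toy dimension names and sizes are assumptions; `normalize_chunks` is defined in `xarray_beam/_src/dataset.py`, per the diff below):

```python
# A sketch of the new `...` key, not part of the commit itself.
import numpy as np
import xarray
from xarray_beam._src import dataset as xbeam_dataset

template = xarray.Dataset(
    {'foo': (('x', 'y', 'z'), np.zeros((100, 100, 100)))}
)

# 'x' is chunked into sizes of 10; all other dimensions into sizes of 20.
chunks = xbeam_dataset.normalize_chunks({'x': 10, ...: 20}, template)
print(chunks)  # expected: {'x': 10, 'y': 20, 'z': 20}
```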

File tree: 5 files changed, +222 -83 lines changed


examples/xbeam_rechunk.py

Lines changed: 25 additions & 7 deletions

```diff
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Rechunk a Zarr dataset."""
+import types
+
 from absl import app
 from absl import flags
 import apache_beam as beam
@@ -51,22 +53,38 @@
 # pylint: disable=expression-not-assigned


-def _parse_chunks_str(chunks_str: str) -> dict[str, int]:
+def _try_to_int(chunks_str: str) -> int | str:
+  try:
+    return int(chunks_str)
+  except ValueError:
+    return chunks_str
+
+
+def _parse_chunks_flag(
+    chunks_str: str,
+) -> dict[str | types.EllipsisType, int | str] | int | str:
+  """Parse a string representation of unnormalized chunks."""
+  if '=' not in chunks_str:
+    return _try_to_int(chunks_str)
+
   chunks = {}
   parts = chunks_str.split(',')
   for part in parts:
     k, v = part.split('=')
-    chunks[k] = int(v)
+    if k == '...':
+      k = ...
+    chunks[k] = _try_to_int(v)
   return chunks


 def main(argv):
-  target_chunks = _parse_chunks_str(TARGET_CHUNKS.value)
+  target_chunks = _parse_chunks_flag(TARGET_CHUNKS.value)

-  if TARGET_SHARDS.value is not None:
-    target_shards = _parse_chunks_str(TARGET_SHARDS.value)
-  else:
-    target_shards = None
+  target_shards = (
+      _parse_chunks_flag(TARGET_SHARDS.value)
+      if TARGET_SHARDS.value is not None
+      else None
+  )

   with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
     root |= (
```
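
For illustration, the updated parser round-trips flag values like these (outputs inferred from the code above; not part of the commit):

```python
_parse_chunks_flag('10')                 # -> 10
_parse_chunks_flag('100MB')              # -> '100MB'
_parse_chunks_flag('x=10,y=20')          # -> {'x': 10, 'y': 20}
_parse_chunks_flag('time=-1,...=100MB')  # -> {'time': -1, ...: '100MB'}
```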

examples/xbeam_rechunk_test.py

Lines changed: 10 additions & 5 deletions

```diff
@@ -16,9 +16,9 @@
 from absl.testing import absltest
 from absl.testing import flagsaver
 import xarray
+from xarray_beam._src import test_util

 from . import xbeam_rechunk
-from xarray_beam._src import test_util


 class Era5RechunkTest(test_util.TestCase):
@@ -27,20 +27,25 @@ def test_chunks_only(self):
     input_path = self.create_tempdir('source').full_path
     output_path = self.create_tempdir('destination').full_path

-    input_ds = test_util.dummy_era5_surface_dataset(times=365)
+    input_ds = test_util.dummy_era5_surface_dataset(
+        latitudes=100, longitudes=200, times=365
+    )
     input_ds.chunk({'time': 31}).to_zarr(input_path)

     with flagsaver.flagsaver(
         input_path=input_path,
         output_path=output_path,
-        target_chunks='latitude=5,longitude=5,time=-1',
+        target_chunks=f'time=-1,...={365*10*20*4}B',
     ):
       xbeam_rechunk.main([])

     output_ds = xarray.open_zarr(output_path)
+    # dask.array tries to preserve the aspect ratio of the original array when
+    # splitting across dimensions, hence the 2x ratio between latitude and
+    # longitude.
     self.assertEqual(
         {k: v[0] for k, v in output_ds.chunks.items()},
-        {'latitude': 5, 'longitude': 5, 'time': 365}
+        {'latitude': 10, 'longitude': 20, 'time': 365},
     )
     xarray.testing.assert_identical(input_ds, output_ds)
@@ -63,7 +68,7 @@ def test_chunks_and_shards(self):
     output_ds = xarray.open_zarr(output_path)
     self.assertEqual(
         {k: v[0] for k, v in output_ds.chunks.items()},
-        {'latitude': 5, 'longitude': 5, 'time': 365}
+        {'latitude': 5, 'longitude': 5, 'time': 365},
     )
     actual_shards = {k: v.encoding['shards'] for k, v in output_ds.items()}
     expected_shards = {k: (365, 10, 10) for k, v in output_ds.items()}
```
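
A back-of-the-envelope check of the updated test target (assuming 4-byte float variables, consistent with the `* 4` factor in the flag value):

```python
target_bytes = 365 * 10 * 20 * 4  # 292,000 B, as passed via target_chunks
elements = target_bytes // 4      # 73,000 elements per chunk
spatial = elements // 365         # 200 latitude/longitude points per chunk
# dask mirrors the 100x200 (1:2) grid shape, so 200 points split as 10 x 20.
```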

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -55,4 +55,4 @@
     DatasetToZarr as DatasetToZarr,
 )

-__version__ = '0.10.2'  # automatically synchronized to pyproject.toml
+__version__ = '0.10.3'  # automatically synchronized to pyproject.toml
```

xarray_beam/_src/dataset.py

Lines changed: 93 additions & 45 deletions

```diff
@@ -38,6 +38,7 @@
 import os.path
 import tempfile
 import textwrap
+import types
 from typing import Callable, Literal

 import apache_beam as beam
@@ -71,16 +72,19 @@ def _to_human_size(nbytes: int) -> str:
   return f'{_at_least_two_digits(nbytes)}EB'


+UnnormalizedChunks = Mapping[str | types.EllipsisType, int | str] | int | str
+
+
 def normalize_chunks(
-    chunks: Mapping[str, int | str] | str,
+    chunks: UnnormalizedChunks,
     template: xarray.Dataset,
     split_vars: bool = False,
     previous_chunks: Mapping[str, int] | None = None,
 ) -> dict[str, int]:
   """Normalize chunks for a xarray.Dataset.

-  This function interprets various chunk specifications (e.g., -1, 'auto',
-  byte-strings) and returns a dictionary mapping dimension names to
+  This function interprets various chunk specifications (e.g., integer sizes or
+  numbers of bytes) and returns a dictionary mapping dimension names to
   concrete integer chunk sizes. It uses ``dask.array.api.normalize_chunks``
   under the hood.
```
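
For reference, values matching the new `UnnormalizedChunks` alias (illustrative only; the ellipsis key is resolved later inside `normalize_chunks`):

```python
spec_a = 10                          # int: the same size for every dimension
spec_b = '100MB'                     # str: a byte target for all dimensions
spec_c = {'time': -1, ...: '100MB'}  # mapping with an ellipsis default
```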
```diff
@@ -89,19 +93,28 @@ def normalize_chunks(
     dimension.
   - An integer: the exact chunk size for this dimension.
   - A byte-string (e.g., "64MiB", "1GB"): indicates that dask should pick
-    chunk sizes to aim for chunks of approximately this size. If byte limits
-    are specified for multiple dimensions, they must be consistent (i.e.,
-    parse to the same number of bytes).
-  - ``'auto'``: chunks will be automatically determined for all 'auto'
-    dimensions to ensure chunks are approximately the target number of bytes
-    (defaulting to 128MiB, if no byte limits are specified).
+    chunk sizes to aim for chunks of approximately this size.
+
+  Only a single string value indicating a number of bytes can be specified. To
+  indicate that chunking applies to multiple dimensions, use a dict key of
+  ``...``.
+
+  Some examples:
+  - ``chunks={'time': 100}``: Each chunk will have exactly 100 elements along
+    the 'time' dimension.
+  - ``chunks="200MB"``: Create chunks that are approximately 200MB in size.
+  - ``chunks={'time': -1, ...: "100MB"}``: Chunks should include the full
+    'time' dimension, and be chunked along other dimensions such that
+    resulting chunks are approximately 100MB in size.

   Args:
     chunks: The desired chunking scheme. Can either be a dictionary mapping
-      dimension names to chunk sizes, or a single string chunk specification
-      (e.g., 'auto' or '100MiB') to be applied as the default for all
+      dimension names to chunk sizes, or a single string/integer chunk
+      specification (e.g., '100MB') to be applied as the default for all
       dimensions. Dimensions not included in the dictionary default to
-      previous_chunks (if available) or the full size of the dimension.
+      ``previous_chunks`` (if available) or the full size of the dimension. A
+      dict key of ellipsis (...) can also be used to indicate "all other
+      dimensions".
     template: An xarray.Dataset providing dimension sizes and dtype information,
       used for calculating chunk sizes in bytes.
     split_vars: If True, chunk size limits are applied per-variable, based on
```
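
A sketch of how the documented examples might resolve, assuming a hypothetical 365×100×200 float32 template (byte-based splits are chosen by dask, so exact sizes may vary):

```python
import numpy as np
import xarray
from xarray_beam._src.dataset import normalize_chunks

template = xarray.Dataset({
    't2m': (
        ('time', 'latitude', 'longitude'),
        np.zeros((365, 100, 200), dtype=np.float32),
    ),
})

normalize_chunks({'time': 100}, template)
# -> {'time': 100, 'latitude': 100, 'longitude': 200}
# (unlisted dimensions default to their full size)

normalize_chunks({'time': -1, ...: '10MB'}, template)
# -> the full 'time' dimension, with latitude/longitude sizes picked by dask
#    so that each chunk is approximately 10MB.
```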
```diff
@@ -113,15 +126,34 @@ def normalize_chunks(
   Returns:
     A dictionary mapping all dimension names to integer chunk sizes.
   """
-  if isinstance(chunks, str):
+  raw_chunks = chunks
+
+  if isinstance(chunks, str | int):
+    if chunks == 'auto':
+      raise ValueError(
+          'Unlike Dask, xarray_beam.normalize_chunks() does not support '
+          "chunks='auto'. Supply an explicit number of bytes instead, e.g., "
+          "chunks='100MB'."
+      )
     chunks = {k: chunks for k in template.dims}
+  elif isinstance(chunks, Mapping):
+    string_chunks = {v for v in chunks.values() if isinstance(v, str)}
+    if len(string_chunks) > 1:
+      raise ValueError(
+          f'cannot provide multiple distinct chunk sizes in bytes: {chunks}'
+      )
+    if any(v == 'auto' for v in chunks.values()):
+      raise ValueError(
+          'Unlike Dask, xarray_beam.normalize_chunks() does not support '
+          "'auto' chunk sizes. Supply an explicit number of bytes instead, "
+          f"e.g., '100MB'. Got {chunks=}"
+      )
+  else:
+    raise TypeError(f'chunks must be a string or a mapping, got {chunks=}')

-  string_chunks = {v for v in chunks.values() if isinstance(v, str)}
-  string_chunks.discard('auto')
-  if len(string_chunks) > 1:
-    raise ValueError(
-        f'cannot specify multiple distinct chunk sizes in bytes: {chunks}'
-    )
+  if ... in chunks:
+    default_chunks = chunks[...]
+    chunks = {k: chunks.get(k, default_chunks) for k in template.dims}

   defaults = previous_chunks if previous_chunks else template.sizes
   chunks: dict[str, int | str] = {**defaults, **chunks}  # pytype: disable=annotation-type-mismatch
```
142174
tuple(previous_chunks[k] for k in chunks) if previous_chunks else None
143175
)
144176

145-
# Note: This values are the same as the dask defaults. Set them explicitly
146-
# here to ensure that Xarray-Beam behavior does not depend on the user's
147-
# dask configuration.
148-
with dask.config.set({
149-
'array.chunk-size': '128MiB',
150-
'array.chunk-size-tolerance': 1.25,
151-
}):
152-
normalized_chunks_tuple = dask.array.api.normalize_chunks(
153-
chunks_tuple,
154-
shape,
155-
dtype=combined_dtype,
156-
previous_chunks=prev_chunks_tuple,
157-
)
177+
# Note: This is the same as the dask default. Set chunk-size-tolerance
178+
# explicitly here to ensure that Xarray-Beam behavior does not depend on the
179+
# user's dask configuration.
180+
with dask.config.set({'array.chunk-size-tolerance': 1.25}):
181+
try:
182+
normalized_chunks_tuple = dask.array.api.normalize_chunks(
183+
chunks_tuple,
184+
shape,
185+
dtype=combined_dtype,
186+
previous_chunks=prev_chunks_tuple,
187+
)
188+
except ValueError as e:
189+
raise ValueError(
190+
f'Invalid input for normalize_chunks: chunks={raw_chunks!r}, '
191+
f'{previous_chunks=}, {template=}'
192+
) from e
158193
return {k: v[0] for k, v in zip(chunks, normalized_chunks_tuple)}
159194

160195
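
For context, the wrapped dask function resolves byte-strings against shape and dtype; a sketch, assuming the public `dask.array.normalize_chunks` alias for `dask.array.api.normalize_chunks`:

```python
import dask.array

# -1 keeps the full axis; a byte-string asks dask to pick chunk sizes near
# that target. With float32 (4 bytes) and all 365 rows in every chunk,
# '1MiB' works out to roughly 2**20 / (365 * 4) ≈ 718 columns per chunk.
chunks = dask.array.normalize_chunks(
    (-1, '1MiB'), shape=(365, 100_000), dtype='float32'
)
```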

```diff
@@ -282,7 +317,9 @@ def __init__(
         this dataset's data.
     """
     self._template = template
-    self._chunks = chunks
+    self._chunks = {
+        k: min(template.sizes[k], v) for k, v in chunks.items()
+    }
     self._split_vars = split_vars
     self._ptransform = ptransform
```
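
The new clamping guarantees that stored chunk sizes never exceed the template's dimension sizes; in spirit:

```python
# Illustrative only: a requested chunk larger than the dimension is clamped.
template_sizes = {'time': 365}
requested = {'time': 1000}
clamped = {k: min(template_sizes[k], v) for k, v in requested.items()}
assert clamped == {'time': 365}
```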

```diff
@@ -357,7 +394,7 @@ def __repr__(self):
   def from_xarray(
       cls,
       source: xarray.Dataset,
-      chunks: Mapping[str, int | str] | str,
+      chunks: UnnormalizedChunks,
       *,
       split_vars: bool = False,
       previous_chunks: Mapping[str, int] | None = None,
@@ -384,7 +421,7 @@ def from_zarr(
       cls,
       path: str,
       *,
-      chunks: Mapping[str, int | str] | str | None = None,
+      chunks: UnnormalizedChunks | None = None,
       split_vars: bool = False,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from a Zarr store.
@@ -426,8 +463,8 @@ def to_zarr(
       path: str,
       *,
       zarr_chunks_per_shard: Mapping[str, int] | None = None,
-      zarr_chunks: Mapping[str, int] | None = None,
-      zarr_shards: Mapping[str, int] | None = None,
+      zarr_chunks: UnnormalizedChunks | None = None,
+      zarr_shards: UnnormalizedChunks | None = None,
       zarr_format: int | None = None,
   ) -> beam.PTransform:
     """Write this dataset to a Zarr file.
```
```diff
@@ -461,14 +498,21 @@ def to_zarr(
     Returns:
       Beam PTransform that writes the dataset to a Zarr file.
     """
+    if zarr_shards is not None:
+      zarr_shards = normalize_chunks(
+          zarr_shards,
+          self.template,
+          split_vars=self.split_vars,
+          previous_chunks=self.chunks,
+      )
+
     if zarr_chunks_per_shard is not None:
       if zarr_chunks is not None:
         raise ValueError(
             'cannot supply both zarr_chunks_per_shard and zarr_chunks'
         )
       if zarr_shards is None:
-        zarr_shards = {}
-      zarr_shards = {**self.chunks, **zarr_shards}
+        zarr_shards = self.chunks
       zarr_chunks = {}
       for dim, existing_chunk_size in zarr_shards.items():
         multiple = zarr_chunks_per_shard.get(dim)
@@ -490,9 +534,13 @@ def to_zarr(
         raise ValueError('cannot supply zarr_shards without zarr_chunks')
       zarr_chunks = {}

-    zarr_chunks = {**self.chunks, **zarr_chunks}
+    zarr_chunks = normalize_chunks(
+        zarr_chunks,
+        self.template,
+        split_vars=self.split_vars,
+        previous_chunks=self.chunks,
+    )
     if zarr_shards is not None:
-      zarr_shards = {**self.chunks, **zarr_shards}
       self._check_shards_or_chunks(zarr_shards, 'shards')
     else:
       self._check_shards_or_chunks(zarr_chunks, 'chunks')
```
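
With `zarr_chunks` and `zarr_shards` now accepting unnormalized specs, a call might look like this (a sketch; the paths are hypothetical, `Dataset` is reached via its `_src` module per the file paths in this commit, and the returned `beam.PTransform` still needs to be applied in a pipeline):

```python
from xarray_beam._src import dataset as xbeam_dataset

ds = xbeam_dataset.Dataset.from_zarr('/tmp/input.zarr')
# Shards span the full time axis and target ~100MB along other dimensions;
# zarr_chunks_per_shard then subdivides each shard 5x along time.
write = ds.to_zarr(
    '/tmp/output.zarr',
    zarr_shards={'time': -1, ...: '100MB'},
    zarr_chunks_per_shard={'time': 5},
)
```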
```diff
@@ -537,9 +585,9 @@ def map_blocks(
         attempt will be made to infer the template by applying ``func`` to the
         existing template, which requires that ``func`` is implemented using
         dask compatible operations.
-      chunks: new chunks sizes for the resulting dataset. If not provided, an
-        attempt will be made to infer the new chunks based on the existing
-        chunks, dimensions sizes and the new template.
+      chunks: explicit new chunk sizes created by applying ``func``. If not
+        provided, an attempt will be made to infer the new chunks based on the
+        existing chunks, dimension sizes and the new template.

     Returns:
       New Dataset with updated chunks.
@@ -587,7 +635,7 @@ def map_blocks(

   def rechunk(
       self,
-      chunks: dict[str, int | str] | str,
+      chunks: UnnormalizedChunks,
       min_mem: int | None = None,
       max_mem: int = 2**30,
   ) -> Dataset:
```
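
`rechunk` accepts the same spec; a brief sketch (the byte target is arbitrary):

```python
# Full 'time' axis, ~256MB chunks along the remaining dimensions, keeping
# the default 1 GiB ceiling on memory per rechunking stage.
rechunked = ds.rechunk({'time': -1, ...: '256MB'}, max_mem=2**30)
```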
