import os.path
import tempfile
import textwrap
+import types
from typing import Callable, Literal

import apache_beam as beam
@@ -71,16 +72,19 @@ def _to_human_size(nbytes: int) -> str:
  return f'{_at_least_two_digits(nbytes)} EB'


+UnnormalizedChunks = Mapping[str | types.EllipsisType, int | str] | int | str
+
+
74
78
def normalize_chunks (
75
- chunks : Mapping [ str , int | str ] | str ,
79
+ chunks : UnnormalizedChunks ,
76
80
template : xarray .Dataset ,
77
81
split_vars : bool = False ,
78
82
previous_chunks : Mapping [str , int ] | None = None ,
79
83
) -> dict [str , int ]:
80
84
"""Normalize chunks for a xarray.Dataset.
81
85
-  This function interprets various chunk specifications (e.g., -1, 'auto',
-  byte-strings) and returns a dictionary mapping dimension names to
+  This function interprets various chunk specifications (e.g., integer sizes or
+  numbers of bytes) and returns a dictionary mapping dimension names to
  concrete integer chunk sizes. It uses ``dask.array.api.normalize_chunks``
  under the hood.

@@ -89,19 +93,28 @@ def normalize_chunks(
    dimension.
  - An integer: the exact chunk size for this dimension.
  - A byte-string (e.g., "64MiB", "1GB"): indicates that dask should pick
-    chunk sizes to aim for chunks of approximately this size. If byte limits
-    are specified for multiple dimensions, they must be consistent (i.e.,
-    parse to the same number of bytes).
-  - ``'auto'``: chunks will be automatically determined for all 'auto'
-    dimensions to ensure chunks are approximately the target number of bytes
-    (defaulting to 128MiB, if no byte limits are specified).
+    chunk sizes to aim for chunks of approximately this size.
+
+  Only a single string value indicating a number of bytes can be specified. To
+  indicate that chunking applies to multiple dimensions, use a dict key of
+  ``...``.
+
+  Some examples:
+  - ``chunks={'time': 100}``: Each chunk will have exactly 100 elements along
+    the 'time' dimension.
+  - ``chunks="200MB"``: Create chunks that are approximately 200MB in size.
+  - ``chunks={'time': -1, ...: "100MB"}``: Chunks should include the full
+    'time' dimension, and be chunked along other dimensions such that
+    resulting chunks are approximately 100MB in size.

  Args:
    chunks: The desired chunking scheme. Can either be a dictionary mapping
-      dimension names to chunk sizes, or a single string chunk specification
-      (e.g., 'auto' or '100MiB') to be applied as the default for all
+      dimension names to chunk sizes, or a single string/integer chunk
+      specification (e.g., '100MB') to be applied as the default for all
      dimensions. Dimensions not included in the dictionary default to
-      previous_chunks (if available) or the full size of the dimension.
+      ``previous_chunks`` (if available) or the full size of the dimension. A
+      dict key of ellipsis (...) can also be used to indicate "all other
+      dimensions".
    template: An xarray.Dataset providing dimension sizes and dtype information,
      used for calculating chunk sizes in bytes.
    split_vars: If True, chunk size limits are applied per-variable, based on
@@ -113,15 +126,34 @@ def normalize_chunks(
  Returns:
    A dictionary mapping all dimension names to integer chunk sizes.
  """
-  if isinstance(chunks, str):
+  raw_chunks = chunks
+
+  if isinstance(chunks, str | int):
+    if chunks == 'auto':
+      raise ValueError(
+          'Unlike Dask, xarray_beam.normalize_chunks() does not support '
+          "chunks='auto'. Supply an explicit number of bytes instead, e.g., "
+          "chunks='100MB'."
+      )
    chunks = {k: chunks for k in template.dims}
+  elif isinstance(chunks, Mapping):
+    string_chunks = {v for v in chunks.values() if isinstance(v, str)}
+    if len(string_chunks) > 1:
+      raise ValueError(
+          f'cannot provide multiple distinct chunk sizes in bytes: {chunks}'
+      )
+    if any(v == 'auto' for v in chunks.values()):
+      raise ValueError(
+          'Unlike Dask, xarray_beam.normalize_chunks() does not support '
+          "'auto' chunk sizes. Supply an explicit number of bytes instead, "
+          f"e.g., '100MB'. Got {chunks=}"
+      )
+  else:
+    raise TypeError(
+        f'chunks must be a string, integer, or mapping, got {chunks=}'
+    )

-  string_chunks = {v for v in chunks.values() if isinstance(v, str)}
-  string_chunks.discard('auto')
-  if len(string_chunks) > 1:
-    raise ValueError(
-        f'cannot specify multiple distinct chunk sizes in bytes: {chunks}'
-    )
+  if ... in chunks:
+    default_chunks = chunks[...]
+    chunks = {k: chunks.get(k, default_chunks) for k in template.dims}

  defaults = previous_chunks if previous_chunks else template.sizes
  chunks: dict[str, int | str] = {**defaults, **chunks}  # pytype: disable=annotation-type-mismatch
@@ -142,19 +174,22 @@ def normalize_chunks(
      tuple(previous_chunks[k] for k in chunks) if previous_chunks else None
  )

-  # Note: This values are the same as the dask defaults. Set them explicitly
-  # here to ensure that Xarray-Beam behavior does not depend on the user's
-  # dask configuration.
-  with dask.config.set({
-      'array.chunk-size': '128MiB',
-      'array.chunk-size-tolerance': 1.25,
-  }):
-    normalized_chunks_tuple = dask.array.api.normalize_chunks(
-        chunks_tuple,
-        shape,
-        dtype=combined_dtype,
-        previous_chunks=prev_chunks_tuple,
-    )
+  # Note: This is the same as the dask default. Set chunk-size-tolerance
+  # explicitly here to ensure that Xarray-Beam behavior does not depend on the
+  # user's dask configuration.
+  with dask.config.set({'array.chunk-size-tolerance': 1.25}):
+    try:
+      normalized_chunks_tuple = dask.array.api.normalize_chunks(
+          chunks_tuple,
+          shape,
+          dtype=combined_dtype,
+          previous_chunks=prev_chunks_tuple,
+      )
+    except ValueError as e:
+      raise ValueError(
+          f'Invalid input for normalize_chunks: chunks={raw_chunks!r}, '
+          f'{previous_chunks=}, {template=}'
+      ) from e
  return {k: v[0] for k, v in zip(chunks, normalized_chunks_tuple)}

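To make the accepted specifications concrete, here is a hedged usage sketch. The template dataset is invented, and it assumes ``normalize_chunks`` is importable from the ``xarray_beam`` package top level; neither is confirmed by this diff.

import numpy as np
import xarray
import xarray_beam

template = xarray.Dataset(
    {'foo': (('time', 'x'), np.zeros((365, 5000), dtype=np.float32))}
)

# Exact integer size along 'time'; 'x' defaults to its full size:
xarray_beam.normalize_chunks({'time': 100}, template)

# A single byte limit applied to all dimensions:
xarray_beam.normalize_chunks('200MB', template)

# Full 'time' dimension, all other dimensions chunked to roughly 100MB:
xarray_beam.normalize_chunks({'time': -1, ...: '100MB'}, template)

# Unlike Dask, 'auto' is rejected:
# xarray_beam.normalize_chunks('auto', template)  # would raise ValueError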
@@ -282,7 +317,9 @@ def __init__(
        this dataset's data.
    """
    self._template = template
-    self._chunks = chunks
+    self._chunks = {
+        k: min(template.sizes[k], v) for k, v in chunks.items()
+    }
    self._split_vars = split_vars
    self._ptransform = ptransform

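The new clipping logic ensures stored chunk sizes never exceed the corresponding dimension size. A minimal standalone sketch of that rule, with hypothetical sizes:

# Hypothetical sizes illustrating the clipping rule from __init__ above:
# a requested chunk larger than the dimension shrinks to the full size.
template_sizes = {'time': 365, 'x': 100}
requested = {'time': 1000, 'x': 50}
clipped = {k: min(template_sizes[k], v) for k, v in requested.items()}
assert clipped == {'time': 365, 'x': 50}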
@@ -357,7 +394,7 @@ def __repr__(self):
  def from_xarray(
      cls,
      source: xarray.Dataset,
-      chunks: Mapping[str, int | str] | str,
+      chunks: UnnormalizedChunks,
      *,
      split_vars: bool = False,
      previous_chunks: Mapping[str, int] | None = None,
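With the widened annotation, ``from_xarray`` should accept any ``UnnormalizedChunks`` value. A speculative sketch; the dataset contents are invented:

import numpy as np
import xarray
import xarray_beam

source = xarray.Dataset(
    {'temperature': (('time', 'x'), np.zeros((365, 1000), dtype=np.float32))}
)

# Each of these forms is normalized via normalize_chunks():
ds = xarray_beam.Dataset.from_xarray(source, chunks='200MB')
ds = xarray_beam.Dataset.from_xarray(source, chunks={'time': -1, ...: '100MB'})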
@@ -384,7 +421,7 @@ def from_zarr(
      cls,
      path: str,
      *,
-      chunks: Mapping[str, int | str] | str | None = None,
+      chunks: UnnormalizedChunks | None = None,
      split_vars: bool = False,
  ) -> Dataset:
    """Create an xarray_beam.Dataset from a Zarr store.
@@ -426,8 +463,8 @@ def to_zarr(
      path: str,
      *,
      zarr_chunks_per_shard: Mapping[str, int] | None = None,
-      zarr_chunks: Mapping[str, int] | None = None,
-      zarr_shards: Mapping[str, int] | None = None,
+      zarr_chunks: UnnormalizedChunks | None = None,
+      zarr_shards: UnnormalizedChunks | None = None,
      zarr_format: int | None = None,
  ) -> beam.PTransform:
    """Write this dataset to a Zarr file.
@@ -461,14 +498,21 @@ def to_zarr(
    Returns:
      Beam PTransform that writes the dataset to a Zarr file.
    """
+    if zarr_shards is not None:
+      zarr_shards = normalize_chunks(
+          zarr_shards,
+          self.template,
+          split_vars=self.split_vars,
+          previous_chunks=self.chunks,
+      )
+
    if zarr_chunks_per_shard is not None:
      if zarr_chunks is not None:
        raise ValueError(
            'cannot supply both zarr_chunks_per_shard and zarr_chunks'
        )
      if zarr_shards is None:
-        zarr_shards = {}
-      zarr_shards = {**self.chunks, **zarr_shards}
+        zarr_shards = self.chunks
      zarr_chunks = {}
      for dim, existing_chunk_size in zarr_shards.items():
        multiple = zarr_chunks_per_shard.get(dim)
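The loop body is truncated in this diff, so the sketch below assumes each shard is divided evenly by the per-dimension multiple; only the first two loop lines are confirmed above:

# Assumed arithmetic for zarr_chunks_per_shard (loop body truncated above):
# each shard of size S along `dim` becomes chunks of size S // multiple.
zarr_shards = {'time': 100, 'x': 400}
zarr_chunks_per_shard = {'time': 4}

zarr_chunks = {}
for dim, shard_size in zarr_shards.items():
  multiple = zarr_chunks_per_shard.get(dim)
  zarr_chunks[dim] = shard_size if multiple is None else shard_size // multiple

assert zarr_chunks == {'time': 25, 'x': 400}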
@@ -490,9 +534,13 @@ def to_zarr(
        raise ValueError('cannot supply zarr_shards without zarr_chunks')
      zarr_chunks = {}

-    zarr_chunks = {**self.chunks, **zarr_chunks}
+    zarr_chunks = normalize_chunks(
+        zarr_chunks,
+        self.template,
+        split_vars=self.split_vars,
+        previous_chunks=self.chunks,
+    )
    if zarr_shards is not None:
-      zarr_shards = {**self.chunks, **zarr_shards}
      self._check_shards_or_chunks(zarr_shards, 'shards')
    else:
      self._check_shards_or_chunks(zarr_chunks, 'chunks')
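Putting this together, a speculative call pattern; ``ds`` is an ``xarray_beam.Dataset`` as in the earlier sketches and the output path is a placeholder:

# Sketch: ~1GB shards spanning the full 'time' dimension, each split into
# 4 chunks along 'time'. The returned value is a beam.PTransform.
write = ds.to_zarr(
    '/tmp/output.zarr',  # placeholder path
    zarr_shards={'time': -1, ...: '1GB'},
    zarr_chunks_per_shard={'time': 4},
)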
@@ -537,9 +585,9 @@ def map_blocks(
        attempt will be made to infer the template by applying ``func`` to the
        existing template, which requires that ``func`` is implemented using
        dask compatible operations.
-      chunks: new chunks sizes for the resulting dataset. If not provided, an
-        attempt will be made to infer the new chunks based on the existing
-        chunks, dimensions sizes and the new template.
+      chunks: explicit new chunk sizes created by applying ``func``. If not
+        provided, an attempt will be made to infer the new chunks based on the
+        existing chunks, dimension sizes and the new template.

    Returns:
      New Dataset with updated chunks.
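The full ``map_blocks`` signature is not shown in this hunk; the sketch below assumes it takes the callable first, with an optional ``chunks`` keyword as documented above:

# Sketch, assuming a signature like map_blocks(func, template=None, chunks=None).
def anomaly(block: xarray.Dataset) -> xarray.Dataset:
  # Subtracting the mean keeps all dimensions, so chunks can be inferred.
  return block - block.mean()

result = ds.map_blocks(anomaly)  # template and chunks inferred, if possible
explicit = ds.map_blocks(anomaly, chunks={'time': 10, 'x': 1000})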
@@ -587,7 +635,7 @@ def map_blocks(

  def rechunk(
      self,
-      chunks: dict[str, int | str] | str,
+      chunks: UnnormalizedChunks,
      min_mem: int | None = None,
      max_mem: int = 2**30,
  ) -> Dataset:
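Finally, a speculative ``rechunk`` call using the new annotation; ``max_mem`` presumably bounds the size of intermediate rechunking stages (2**30 bytes = 1GiB by default), though that is not confirmed by this diff:

# Sketch: full 'time' dimension, other dimensions at ~100MB targets, allowing
# up to 4GiB for intermediate chunks.
rechunked = ds.rechunk({'time': -1, ...: '100MB'}, max_mem=4 * 2**30)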