Skip to content

Commit af12604

Browse files
Support rechunking to a frequency. (#9109)
* Support rechunking to a frequency. Closes #7559 * Updates * Fix typing * More typing fixes. * Switch to TimeResampler objects * small fix * Add whats-new * More test * fix docs * fix * Update doc/user-guide/dask.rst Co-authored-by: Spencer Clark <[email protected]> --------- Co-authored-by: Spencer Clark <[email protected]>
1 parent 53c5634 commit af12604

File tree

9 files changed

+148
-31
lines changed

9 files changed

+148
-31
lines changed

doc/user-guide/dask.rst

+16
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,12 @@ loaded into Dask or not:
296296
Automatic parallelization with ``apply_ufunc`` and ``map_blocks``
297297
-----------------------------------------------------------------
298298

299+
.. tip::
300+
301+
Some problems can become embarassingly parallel and thus easy to parallelize
302+
automatically by rechunking to a frequency, e.g. ``ds.chunk(time=TimeResampler("YE"))``.
303+
See :py:meth:`Dataset.chunk` for more.
304+
299305
Almost all of xarray's built-in operations work on Dask arrays. If you want to
300306
use a function that isn't wrapped by xarray, and have it applied in parallel on
301307
each block of your xarray object, you have three options:
@@ -551,6 +557,16 @@ larger chunksizes.
551557

552558
Check out the `dask documentation on chunks <https://docs.dask.org/en/latest/array-chunks.html>`_.
553559

560+
.. tip::
561+
562+
Many time domain problems become amenable to an embarassingly parallel or blockwise solution
563+
(e.g. using :py:func:`xarray.map_blocks`, :py:func:`dask.array.map_blocks`, or
564+
:py:func:`dask.array.blockwise`) by rechunking to a frequency along the time dimension.
565+
Provide :py:class:`xarray.groupers.TimeResampler` objects to :py:meth:`Dataset.chunk` to do so.
566+
For example ``ds.chunk(time=TimeResampler("MS"))`` will set the chunks so that a month of
567+
data is contained in one chunk. The resulting chunk sizes need not be uniform, depending on
568+
the frequency of the data, and the calendar.
569+
554570

555571
Optimization Tips
556572
-----------------

doc/whats-new.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ New Features
3030
`grouper design doc <https://github.com/pydata/xarray/blob/main/design_notes/grouper_objects.md>`_ for more.
3131
(:issue:`6610`, :pull:`8840`).
3232
By `Deepak Cherian <https://github.com/dcherian>`_.
33-
- Allow per-variable specification of ``mask_and_scale``, ``decode_times``, ``decode_timedelta``
33+
- Allow rechunking to a frequency using ``Dataset.chunk(time=TimeResampler("YE"))`` syntax. (:issue:`7559`, :pull:`9109`)
34+
Such rechunking allows many time domain analyses to be executed in an embarassingly parallel fashion.
35+
By `Deepak Cherian <https://github.com/dcherian>`_.
36+
- Allow per-variable specification of ```mask_and_scale``, ``decode_times``, ``decode_timedelta``
3437
``use_cftime`` and ``concat_characters`` params in :py:func:`~xarray.open_dataset` (:pull:`9218`).
3538
By `Mathijs Verhaegh <https://github.com/Ostheer>`_.
3639
- Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`).

xarray/core/dataarray.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@
107107
ReindexMethodOptions,
108108
Self,
109109
SideOptions,
110-
T_Chunks,
110+
T_ChunkDimFreq,
111+
T_ChunksFreq,
111112
T_Xarray,
112113
)
113114
from xarray.core.weighted import DataArrayWeighted
@@ -1351,15 +1352,15 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]:
13511352
@_deprecate_positional_args("v2023.10.0")
13521353
def chunk(
13531354
self,
1354-
chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
1355+
chunks: T_ChunksFreq = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
13551356
*,
13561357
name_prefix: str = "xarray-",
13571358
token: str | None = None,
13581359
lock: bool = False,
13591360
inline_array: bool = False,
13601361
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
13611362
from_array_kwargs=None,
1362-
**chunks_kwargs: Any,
1363+
**chunks_kwargs: T_ChunkDimFreq,
13631364
) -> Self:
13641365
"""Coerce this array's data into a dask arrays with the given chunks.
13651366
@@ -1371,11 +1372,13 @@ def chunk(
13711372
sizes along that dimension will not be updated; non-dask arrays will be
13721373
converted into dask arrays with a single block.
13731374
1375+
Along datetime-like dimensions, a pandas frequency string is also accepted.
1376+
13741377
Parameters
13751378
----------
1376-
chunks : int, "auto", tuple of int or mapping of Hashable to int, optional
1379+
chunks : int, "auto", tuple of int or mapping of hashable to int or a pandas frequency string, optional
13771380
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or
1378-
``{"x": 5, "y": 5}``.
1381+
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": "YE"}``.
13791382
name_prefix : str, optional
13801383
Prefix for the name of the new dask array.
13811384
token : str, optional
@@ -1410,29 +1413,30 @@ def chunk(
14101413
xarray.unify_chunks
14111414
dask.array.from_array
14121415
"""
1416+
chunk_mapping: T_ChunksFreq
14131417
if chunks is None:
14141418
warnings.warn(
14151419
"None value for 'chunks' is deprecated. "
14161420
"It will raise an error in the future. Use instead '{}'",
14171421
category=FutureWarning,
14181422
)
1419-
chunks = {}
1423+
chunk_mapping = {}
14201424

14211425
if isinstance(chunks, (float, str, int)):
14221426
# ignoring type; unclear why it won't accept a Literal into the value.
1423-
chunks = dict.fromkeys(self.dims, chunks)
1427+
chunk_mapping = dict.fromkeys(self.dims, chunks)
14241428
elif isinstance(chunks, (tuple, list)):
14251429
utils.emit_user_level_warning(
14261430
"Supplying chunks as dimension-order tuples is deprecated. "
14271431
"It will raise an error in the future. Instead use a dict with dimension names as keys.",
14281432
category=DeprecationWarning,
14291433
)
1430-
chunks = dict(zip(self.dims, chunks))
1434+
chunk_mapping = dict(zip(self.dims, chunks))
14311435
else:
1432-
chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
1436+
chunk_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
14331437

14341438
ds = self._to_temp_dataset().chunk(
1435-
chunks,
1439+
chunk_mapping,
14361440
name_prefix=name_prefix,
14371441
token=token,
14381442
lock=lock,

xarray/core/dataset.py

+55-14
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
QuantileMethods,
9494
Self,
9595
T_ChunkDim,
96-
T_Chunks,
96+
T_ChunksFreq,
9797
T_DataArray,
9898
T_DataArrayOrSet,
9999
T_Dataset,
@@ -162,6 +162,7 @@
162162
QueryParserOptions,
163163
ReindexMethodOptions,
164164
SideOptions,
165+
T_ChunkDimFreq,
165166
T_Xarray,
166167
)
167168
from xarray.core.weighted import DatasetWeighted
@@ -283,18 +284,17 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
283284

284285

285286
def _maybe_chunk(
286-
name,
287-
var,
288-
chunks,
287+
name: Hashable,
288+
var: Variable,
289+
chunks: Mapping[Any, T_ChunkDim] | None,
289290
token=None,
290291
lock=None,
291-
name_prefix="xarray-",
292-
overwrite_encoded_chunks=False,
293-
inline_array=False,
292+
name_prefix: str = "xarray-",
293+
overwrite_encoded_chunks: bool = False,
294+
inline_array: bool = False,
294295
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
295296
from_array_kwargs=None,
296-
):
297-
297+
) -> Variable:
298298
from xarray.namedarray.daskmanager import DaskManager
299299

300300
if chunks is not None:
@@ -2648,14 +2648,14 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]:
26482648

26492649
def chunk(
26502650
self,
2651-
chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
2651+
chunks: T_ChunksFreq = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
26522652
name_prefix: str = "xarray-",
26532653
token: str | None = None,
26542654
lock: bool = False,
26552655
inline_array: bool = False,
26562656
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
26572657
from_array_kwargs=None,
2658-
**chunks_kwargs: T_ChunkDim,
2658+
**chunks_kwargs: T_ChunkDimFreq,
26592659
) -> Self:
26602660
"""Coerce all arrays in this dataset into dask arrays with the given
26612661
chunks.
@@ -2667,11 +2667,13 @@ def chunk(
26672667
sizes along that dimension will not be updated; non-dask arrays will be
26682668
converted into dask arrays with a single block.
26692669
2670+
Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.
2671+
26702672
Parameters
26712673
----------
2672-
chunks : int, tuple of int, "auto" or mapping of hashable to int, optional
2674+
chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional
26732675
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
2674-
``{"x": 5, "y": 5}``.
2676+
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``.
26752677
name_prefix : str, default: "xarray-"
26762678
Prefix for the name of any new dask arrays.
26772679
token : str, optional
@@ -2706,6 +2708,9 @@ def chunk(
27062708
xarray.unify_chunks
27072709
dask.array.from_array
27082710
"""
2711+
from xarray.core.dataarray import DataArray
2712+
from xarray.core.groupers import TimeResampler
2713+
27092714
if chunks is None and not chunks_kwargs:
27102715
warnings.warn(
27112716
"None value for 'chunks' is deprecated. "
@@ -2731,6 +2736,42 @@ def chunk(
27312736
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
27322737
)
27332738

2739+
def _resolve_frequency(
2740+
name: Hashable, resampler: TimeResampler
2741+
) -> tuple[int, ...]:
2742+
variable = self._variables.get(name, None)
2743+
if variable is None:
2744+
raise ValueError(
2745+
f"Cannot chunk by resampler {resampler!r} for virtual variables."
2746+
)
2747+
elif not _contains_datetime_like_objects(variable):
2748+
raise ValueError(
2749+
f"chunks={resampler!r} only supported for datetime variables. "
2750+
f"Received variable {name!r} with dtype {variable.dtype!r} instead."
2751+
)
2752+
2753+
assert variable.ndim == 1
2754+
chunks: tuple[int, ...] = tuple(
2755+
DataArray(
2756+
np.ones(variable.shape, dtype=int),
2757+
dims=(name,),
2758+
coords={name: variable},
2759+
)
2760+
.resample({name: resampler})
2761+
.sum()
2762+
.data.tolist()
2763+
)
2764+
return chunks
2765+
2766+
chunks_mapping_ints: Mapping[Any, T_ChunkDim] = {
2767+
name: (
2768+
_resolve_frequency(name, chunks)
2769+
if isinstance(chunks, TimeResampler)
2770+
else chunks
2771+
)
2772+
for name, chunks in chunks_mapping.items()
2773+
}
2774+
27342775
chunkmanager = guess_chunkmanager(chunked_array_type)
27352776
if from_array_kwargs is None:
27362777
from_array_kwargs = {}
@@ -2739,7 +2780,7 @@ def chunk(
27392780
k: _maybe_chunk(
27402781
k,
27412782
v,
2742-
chunks_mapping,
2783+
chunks_mapping_ints,
27432784
token,
27442785
lock,
27452786
name_prefix,

xarray/core/groupers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
264264
)
265265

266266

267-
@dataclass
267+
@dataclass(repr=False)
268268
class TimeResampler(Resampler):
269269
"""
270270
Grouper object specialized to resampling the time coordinate.

xarray/core/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from xarray.core.coordinates import Coordinates
4040
from xarray.core.dataarray import DataArray
4141
from xarray.core.dataset import Dataset
42+
from xarray.core.groupers import TimeResampler
4243
from xarray.core.indexes import Index, Indexes
4344
from xarray.core.utils import Frozen
4445
from xarray.core.variable import Variable
@@ -191,6 +192,8 @@ def copy(
191192
# FYI in some cases we don't allow `None`, which this doesn't take account of.
192193
# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
193194
T_ChunkDim: TypeAlias = Union[str, int, Literal["auto"], None, tuple[int, ...]]
195+
T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
196+
T_ChunksFreq: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDimFreq]]
194197
# We allow the tuple form of this (though arguably we could transition to named dims only)
195198
T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]]
196199
T_NormalizedChunks = tuple[tuple[int, ...], ...]

xarray/core/variable.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Hashable, Mapping, Sequence
99
from datetime import timedelta
1010
from functools import partial
11-
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast
11+
from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast
1212

1313
import numpy as np
1414
import pandas as pd
@@ -63,6 +63,7 @@
6363
PadReflectOptions,
6464
QuantileMethods,
6565
Self,
66+
T_Chunks,
6667
T_DuckArray,
6768
)
6869
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
@@ -2522,7 +2523,7 @@ def _to_dense(self) -> Variable:
25222523

25232524
def chunk( # type: ignore[override]
25242525
self,
2525-
chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {},
2526+
chunks: T_Chunks = {},
25262527
name: str | None = None,
25272528
lock: bool | None = None,
25282529
inline_array: bool | None = None,

xarray/namedarray/core.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
if TYPE_CHECKING:
5454
from numpy.typing import ArrayLike, NDArray
5555

56-
from xarray.core.types import Dims
56+
from xarray.core.types import Dims, T_Chunks
5757
from xarray.namedarray._typing import (
5858
Default,
5959
_AttrsLike,
@@ -748,7 +748,7 @@ def sizes(self) -> dict[_Dim, _IntOrUnknown]:
748748

749749
def chunk(
750750
self,
751-
chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {},
751+
chunks: T_Chunks = {},
752752
chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None,
753753
from_array_kwargs: Any = None,
754754
**chunks_kwargs: Any,
@@ -839,7 +839,7 @@ def chunk(
839839
ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment]
840840

841841
if is_dict_like(chunks):
842-
chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) # type: ignore[assignment]
842+
chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape))
843843

844844
data_chunked = chunkmanager.from_array(ndata, chunks, **from_array_kwargs) # type: ignore[arg-type]
845845

0 commit comments

Comments
 (0)