-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathaccessors.py
170 lines (129 loc) · 4.85 KB
/
accessors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import cupy as cp
from xarray import Dataset, register_dataarray_accessor, register_dataset_accessor
from xarray.core.pycompat import DuckArrayModule
# Array types for the optional dask and pint backends, looked up through
# xarray's ``DuckArrayModule`` helper. They are used in the ``isinstance``
# checks below to detect wrapped (lazy / unit-carrying) arrays.
# NOTE(review): ``DuckArrayModule`` is imported from ``xarray.core`` (an
# internal module) -- confirm it exists in every supported xarray version,
# and what ``.type`` holds when dask or pint is not installed.
dask_array_type = DuckArrayModule("dask").type
pint_array_type = DuckArrayModule("pint").type
def _get_datatype(data):
    """
    Report whether ``data`` is backed by a cupy array.

    Pint quantities are unwrapped to their magnitude first. For a dask array
    the check is made against its ``_meta`` array; anything else is tested
    directly against ``cupy.ndarray``.

    Returns
    -------
    bool
        True when the (unwrapped) data is a cupy array, False otherwise.
    """
    # Strip any pint wrappers so that the checks below see the raw array.
    while isinstance(data, pint_array_type):
        data = data.magnitude

    # Dask arrays are inspected through ``_meta`` rather than the array itself.
    candidate = data._meta if isinstance(data, dask_array_type) else data
    return isinstance(candidate, cp.ndarray)
@register_dataarray_accessor("cupy")
class CupyDataArrayAccessor:
    """
    Cupy helpers for a DataArray.

    Registered with xarray under the name ``cupy``, so the methods and
    attributes below are reached through the ``.cupy`` attribute of a
    DataArray.
    """

    def __init__(self, da):
        # The DataArray this accessor instance operates on.
        self.da = da

    @property
    def is_cupy(self) -> bool:
        """Whether the DataArray's underlying data is a cupy array."""
        backing = self.da.data
        return _get_datatype(backing)

    def as_cupy(self):
        """
        Return a copy of the DataArray whose data has been converted to cupy.

        Data that is not a dask array is passed through ``cupy.asarray``
        immediately. For a dask array the conversion is mapped over the
        blocks instead, so each chunk is converted when the task graph is
        computed.

        Returns
        -------
        cupy_da: DataArray
            DataArray with underlying data cast to cupy.

        Examples
        --------
        >>> import xarray as xr
        >>> da = xr.tutorial.load_dataset("air_temperature").air
        >>> gda = da.cupy.as_cupy()
        >>> type(gda.data)
        <class 'cupy.core.core.ndarray'>
        """
        gpu_data = _as_cupy_data(self.da.data)
        return self.da.copy(data=gpu_data)

    def as_numpy(self):
        """
        Placeholder for the cupy-to-numpy conversion; not implemented here.

        Raises
        ------
        NotImplementedError
            Always. The message directs callers to the DataArray's own
            ``.as_numpy`` method.
        """
        raise NotImplementedError("Please use .as_numpy DataArray method directly.")

    def get(self):
        """Call ``.get()`` on the DataArray's underlying data and return the result."""
        backing = self.da.data
        return backing.get()
def _as_cupy_data(data):
    """
    Convert ``data`` to a cupy-backed equivalent.

    * pint quantities: the magnitude is converted recursively and re-wrapped
      in a ``Quantity`` with the original units.
    * dask arrays: ``cupy.asarray`` is mapped over the blocks.
    * anything else: passed straight to ``cupy.asarray``.
    """
    if isinstance(data, pint_array_type):
        from pint import Quantity  # pylint: disable=import-outside-toplevel

        gpu_magnitude = _as_cupy_data(data.magnitude)
        return Quantity(gpu_magnitude, units=data.units)

    if isinstance(data, dask_array_type):
        return data.map_blocks(cp.asarray)

    return cp.asarray(data)
def _as_numpy_data(data):
    """
    Convert cupy-backed ``data`` back to a numpy-backed equivalent.

    * dask arrays whose ``_meta`` is a cupy array: ``.get()`` is mapped over
      the blocks, keeping the original dtype.
    * dask arrays that are not cupy-backed: returned unchanged.
    * pint quantities: the magnitude is converted recursively and re-wrapped
      in a ``Quantity`` with the original units.
    * cupy arrays: converted with ``.get()``.
    * anything else: returned unchanged.
    """
    if isinstance(data, dask_array_type):
        # Only cupy blocks have ``.get()``. Mapping it over other blocks would
        # fail when the graph is computed, so leave such arrays untouched,
        # exactly as the eager branch at the bottom does.
        if not isinstance(data._meta, cp.ndarray):
            return data
        return data.map_blocks(lambda block: block.get(), dtype=data._meta.dtype)

    if isinstance(data, pint_array_type):
        from pint import Quantity  # pylint: disable=import-outside-toplevel

        return Quantity(
            _as_numpy_data(data.magnitude),
            units=data.units,
        )

    return data.get() if isinstance(data, cp.ndarray) else data
@register_dataset_accessor("cupy")
class CupyDatasetAccessor:
    """
    Access methods for Datasets using Cupy.

    Methods and attributes can be accessed through the `.cupy` attribute.
    """

    def __init__(self, ds):
        # The Dataset this accessor instance operates on.
        self.ds = ds

    @property
    def has_cupy(self) -> bool:
        """True if any data variable contains a cupy array."""
        return any(da.cupy.is_cupy for da in self.ds.data_vars.values())

    @property
    def is_cupy(self) -> bool:
        """True if all data variables contain cupy arrays."""
        return all(da.cupy.is_cupy for da in self.ds.data_vars.values())

    def as_cupy(self):
        """
        Convert every data variable of the Dataset to cupy.

        Returns
        -------
        Dataset
            The original Dataset when all data variables are already cupy
            arrays; otherwise a new Dataset built from the converted data
            variables together with the original coords and attrs.
        """
        if self.is_cupy:
            return self.ds

        data_vars = {
            name: da.cupy.as_cupy() for name, da in self.ds.data_vars.items()
        }
        return Dataset(
            data_vars=data_vars,
            coords=self.ds.coords,
            attrs=self.ds.attrs,
        )

    def as_numpy(self):
        """
        Convert every cupy-backed data variable of the Dataset to numpy.

        The conversion is delegated to xarray's own ``DataArray.as_numpy``,
        because the ``.cupy`` DataArray accessor deliberately does not
        implement ``as_numpy``. Variables that are not cupy-backed are passed
        through untouched.
        NOTE(review): ``DataArray.as_numpy`` is expected to compute dask-backed
        data eagerly -- confirm against the xarray documentation.

        Returns
        -------
        Dataset
            The original Dataset when no data variable holds a cupy array;
            otherwise a new Dataset built from the converted data variables
            together with the original coords and attrs.
        """
        # Gate on ``has_cupy`` rather than ``is_cupy`` so that a Dataset mixing
        # cupy and non-cupy variables still has its cupy variables converted.
        if not self.has_cupy:
            return self.ds

        data_vars = {
            name: da.as_numpy() if da.cupy.is_cupy else da
            for name, da in self.ds.data_vars.items()
        }
        return Dataset(
            data_vars=data_vars,
            coords=self.ds.coords,
            attrs=self.ds.attrs,
        )
# Attach the `as_cupy` methods to the top-level `Dataset` and `DataArray` objects.
# It would be good to replace this with a less hacky upstream API at some stage, so
# that libraries like this one could register new ``as_`` methods for dispatch.
@register_dataarray_accessor("as_cupy")
def _(da):
    """
    Build the callable registered as ``as_cupy`` on DataArray objects.

    The returned function forwards its arguments to
    :meth:`cupy_xarray.CupyDataArrayAccessor.as_cupy`.
    """

    def as_cupy(*args, **kwargs):
        # Look up the ``.cupy`` accessor at call time and delegate to it.
        cupy_accessor = da.cupy
        return cupy_accessor.as_cupy(*args, **kwargs)

    return as_cupy
@register_dataset_accessor("as_cupy")
def _(ds):
    """
    Build the callable registered as ``as_cupy`` on Dataset objects.

    The returned function forwards its arguments to
    :meth:`cupy_xarray.CupyDatasetAccessor.as_cupy`.
    """

    def as_cupy(*args, **kwargs):
        # Look up the ``.cupy`` accessor at call time and delegate to it.
        cupy_accessor = ds.cupy
        return cupy_accessor.as_cupy(*args, **kwargs)

    return as_cupy