Skip to content

Commit 4df552c

Browse files
authored
Cupy zonal (#639)
* adds profiling script and notebook for zonal stats * stab functions and todos * advances cupy implementation * progresses cupy zonal * progresses cupy zonal stats * implementing cupy version * First working cupy zonal, needs optimization * adds profiling module * detailed profiling, cupy is 4x faster, can be further improved * prepares optimized implementation * debugs optimized implementation * adds -profile option and improves performance of cupy * optimized cupy zonal stats * only one transfer for the stats results * need to use cupy.prof for accurate profiling * accurate cupy profiling * auto configuring total time * minor clean up * improved result reporting * improving readability * improving readability * adds script to test custom stat functions * adds script to test custom stat functions * fix minor typo * adds custom functions * fixes minor issues in custom kernel stat * saving to timings dir * checking if profiling directory exists * removes profiling-related commands * adds jupyter notebook benchmark for zonal stats * completes zonal stats benchmark * clean up * bug fixes * removes nodata_zones arg from cupy * adds new benchmark for zonal stats * tested asv benchmarks * untracks not needed files * fixed flake8 warnings * code readability improvements
1 parent 224e432 commit 4df552c

File tree

2 files changed

+206
-23
lines changed

2 files changed

+206
-23
lines changed

benchmarks/benchmarks/zonal.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import xarray as xr
2+
import numpy as np
3+
4+
from xrspatial import zonal
5+
from xrspatial.utils import has_cuda
6+
from .common import get_xr_dataarray
7+
8+
9+
def create_arr(data=None, H=10, W=10, backend='numpy'):
    """Wrap ``data`` in a 2D ``xr.DataArray`` backed by the requested backend.

    Parameters
    ----------
    data : array-like, optional
        Values for the raster. When omitted, an ``(H, W)`` float32 array of
        zeros is created.
    H, W : int
        Height and width of the zero-filled raster; only used when ``data``
        is None.
    backend : {'numpy', 'cupy', 'dask'}
        Array type that should back the returned DataArray.

    Returns
    -------
    xr.DataArray
        Raster with dims ``('y', 'x')``.

    Raises
    ------
    ValueError
        If ``backend`` is not one of the supported names.

    Notes
    -----
    When ``backend='cupy'`` but ``has_cuda()`` is false, the data is left
    untouched, i.e. the raster stays numpy-backed.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if backend not in ('numpy', 'cupy', 'dask'):
        raise ValueError(
            f"Unsupported backend {backend!r}; "
            "expected 'numpy', 'cupy' or 'dask'."
        )

    if data is None:
        data = np.zeros((H, W), dtype=np.float32)
    raster = xr.DataArray(data, dims=['y', 'x'])

    if backend == 'cupy' and has_cuda():
        # Imported lazily: cupy is an optional dependency.
        import cupy
        raster.data = cupy.asarray(raster.data)
    elif backend == 'dask':
        import dask.array as da
        raster.data = da.from_array(raster.data, chunks=(10, 10))

    return raster
24+
25+
26+
class Zonal:
    """asv benchmarks for ``zonal.stats`` with default and custom stats.

    Each parameter combination builds a square ``raster_dim`` x ``raster_dim``
    values raster and a zones raster split into ``zone_dim`` x ``zone_dim``
    equally sized rectangular zones, using the given array ``backend``.
    """
    # Parameter grid: raster edge length, number of zones per axis, backend.
    params = ([400, 1600, 3200], [2, 4, 8], ["numpy", "cupy"])
    param_names = ("raster_dim", "zone_dim", "backend")

    def setup(self, raster_dim, zone_dim, backend):
        """Create the values raster, the zones raster and the custom stats.

        Raises
        ------
        NotImplementedError
            For the cupy backend when CUDA is unavailable; asv treats this
            as "skip this parameter combination".
        ValueError
            If ``raster_dim`` is not a multiple of ``zone_dim``.
        """
        if backend == 'cupy' and not has_cuda():
            raise NotImplementedError('cupy backend requires CUDA')

        W = H = raster_dim
        zW = zH = zone_dim
        # Make sure that the raster dim is multiple of the zones dim
        if W % zW != 0 or H % zH != 0:
            raise ValueError(
                'raster_dim must be a multiple of zone_dim, got '
                f'{raster_dim} and {zone_dim}'
            )

        # initialize the values raster
        self.values = get_xr_dataarray((H, W), backend)

        # initialize the zones raster: zone (i, j) gets id i*zW + j and
        # covers an hstep x wstep rectangle of cells.
        hstep = H // zH
        wstep = W // zW
        zone_ids = np.arange(zH * zW, dtype=np.float64).reshape(zH, zW)
        zones = np.repeat(np.repeat(zone_ids, hstep, axis=0), wstep, axis=1)

        # For H = W = 10 and zH = zW = 2, zones now looks like this:
        #   [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        #   [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        #   [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        #   [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        #   [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        #   [2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        #   [2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        #   [2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        #   [2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
        #   [2, 2, 2, 2, 2, 3, 3, 3, 3, 3]

        self.zones = create_arr(zones, backend=backend)

        # Now setup the custom stat funcs
        if backend == 'cupy':
            import cupy

            # L2 norm computed in a single fused reduction on the device.
            l2normKernel = cupy.ReductionKernel(
                in_params='T x', out_params='float64 y',
                map_expr='x*x', reduce_expr='a+b',
                post_map_expr='y = sqrt(a)',
                identity='0', name='l2normKernel'
            )
            self.custom_stats = {
                'double_sum': lambda val: val.sum()*2,
                # cupy.sqrt keeps the whole expression in the cupy API.
                'l2norm': lambda val: cupy.sqrt(cupy.sum(val * val)),
                'l2normKernel': lambda val: l2normKernel(val)
            }
        else:
            from xrspatial.utils import ngjit

            @ngjit
            def l2normKernel(arr):
                acc = 0
                for x in arr:
                    acc += x * x
                return np.sqrt(acc)

            self.custom_stats = {
                'double_sum': lambda val: val.sum()*2,
                'l2norm': lambda val: np.sqrt(np.sum(val * val)),
                'l2normKernel': lambda val: l2normKernel(val)
            }

    def time_zonal_stats_default(self, raster_dim, zone_dim, backend):
        """Time ``zonal.stats`` with its default set of statistics."""
        zonal.stats(zones=self.zones, values=self.values)

    def time_zonal_stats_custom(self, raster_dim, zone_dim, backend):
        """Time ``zonal.stats`` with the custom statistics from ``setup``."""
        zonal.stats(zones=self.zones, values=self.values,
                    stats_funcs=self.custom_stats)

xrspatial/zonal.py

Lines changed: 106 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,39 @@
1+
# standard library
12
from math import sqrt
23
from typing import Optional, Callable, Union, Dict, List
34

5+
# 3rd-party
6+
import dask.array as da
7+
import dask.dataframe as dd
8+
from dask import delayed
9+
410
import numpy as np
511
import pandas as pd
612
import xarray as xr
713
from xarray import DataArray
814

9-
import dask.array as da
10-
import dask.dataframe as dd
11-
from dask import delayed
15+
try:
    import cupy
except ImportError:
    class cupy(object):
        # Placeholder so ``cupy.ndarray`` can be referenced when CuPy is not
        # installed. An empty tuple is a valid isinstance() target that never
        # matches (``False`` would raise TypeError there) and it is still
        # falsy, like the previous placeholder value.
        ndarray = ()
1220

13-
from xrspatial.utils import ngjit, validate_arrays
21+
# local modules
22+
from xrspatial.utils import ngjit
23+
from xrspatial.utils import validate_arrays
1424
from xrspatial.utils import ArrayTypeFunctionMapping
1525
from xrspatial.utils import not_implemented_func
1626

17-
1827
TOTAL_COUNT = '_total_count'
1928

2029

2130
def _stats_count(data):
2231
if isinstance(data, np.ndarray):
2332
# numpy case
2433
stats_count = np.ma.count(data)
34+
elif isinstance(data, cupy.ndarray):
35+
# cupy case
36+
stats_count = np.prod(data.shape)
2537
else:
2638
# dask case
2739
stats_count = data.size - da.ma.getmaskarray(data).sum()
@@ -56,9 +68,9 @@ def _stats_count(data):
5668
sum_squares=lambda block_sum_squares: np.nansum(block_sum_squares, axis=0),
5769
squared_sum=lambda block_sums: np.nansum(block_sums, axis=0)**2,
5870
)
59-
_dask_mean = lambda sums, counts: sums / counts # noqa
60-
_dask_std = lambda sum_squares, squared_sum, n: np.sqrt((sum_squares - squared_sum/n) / n) # noqa
61-
_dask_var = lambda sum_squares, squared_sum, n: (sum_squares - squared_sum/n) / n # noqa
71+
def _dask_mean(sums, counts): return sums / counts # noqa
72+
def _dask_std(sum_squares, squared_sum, n): return np.sqrt((sum_squares - squared_sum/n) / n) # noqa
73+
def _dask_var(sum_squares, squared_sum, n): return (sum_squares - squared_sum/n) / n # noqa
6274

6375

6476
@ngjit
@@ -282,6 +294,81 @@ def _stats_numpy(
282294
return stats_df
283295

284296

297+
def _stats_cupy(
    orig_zones: xr.DataArray,
    orig_values: xr.DataArray,
    zone_ids: List[Union[int, float]],
    stats_funcs: Dict,
    nodata_values: Union[int, float],
) -> pd.DataFrame:
    """Compute per-zone statistics on the GPU with CuPy.

    Zones and values are flattened and sorted by zone so that every zone
    occupies one contiguous slice of the values; each stat function is then
    applied to each slice.

    Parameters
    ----------
    orig_zones, orig_values
        Zone ids and the values to summarise; both are flattened with
        ``cupy.ravel``. ``orig_values`` must have at most 2 dimensions.
    zone_ids
        Zones to report, in the given order. ``None`` reports every zone
        found in the data. Requested zones that have no valid cells are
        skipped.
    stats_funcs
        Mapping of stat name to a callable that takes the 1D array of a
        zone's values and returns a scalar (0-dimensional) result.
    nodata_values
        Value to exclude from the statistics, or ``None`` to exclude only
        non-finite values.

    Returns
    -------
    pd.DataFrame
        One row per reported zone, with a ``zone`` column followed by one
        column per stat.

    Raises
    ------
    TypeError
        If ``orig_values`` has more than 2 dimensions.
    ValueError
        If an entry of ``stats_funcs`` is not callable, or a stat function
        returns a non-scalar result.
    """
    # TODO add support for 3D input
    if len(orig_values.shape) > 2:
        raise TypeError('3D inputs not supported for cupy backend')

    # Validate once up front instead of once per zone.
    for stats, stats_func in stats_funcs.items():
        if not callable(stats_func):
            raise ValueError(stats)

    zones = cupy.ravel(orig_zones)
    values = cupy.ravel(orig_values)

    # Sort by zone id so each zone becomes a contiguous run.
    sorted_indices = cupy.argsort(zones)
    sorted_zones = zones[sorted_indices]
    values_by_zone = values[sorted_indices]

    # filter out values that are non-finite or values equal to nodata_values
    # (`is not None` so that a nodata value of 0 is honoured as well)
    filter_values = cupy.isfinite(values_by_zone)
    if nodata_values is not None:
        filter_values = filter_values & (values_by_zone != nodata_values)
    values_by_zone = values_by_zone[filter_values]
    sorted_zones = sorted_zones[filter_values]

    # Find the unique zones and the offset at which each zone's run starts.
    unique_zones, unique_index = cupy.unique(sorted_zones, return_index=True)

    # Transfer to the host
    unique_zones = unique_zones.get()
    zone_starts = unique_index.get()
    # A run ends where the next one starts; the last one ends with the data.
    zone_ends = np.append(zone_starts[1:], values_by_zone.size)
    zone_bounds = {
        zone: (start, end)
        for zone, start, end in zip(
            unique_zones.tolist(), zone_starts.tolist(), zone_ends.tolist()
        )
    }

    # Look zones up by id rather than by position: offsets from
    # `cupy.unique` only line up with `zone_ids` when the two are identical.
    if zone_ids is None:
        requested_zones = unique_zones
    else:
        requested_zones = np.asarray(zone_ids)

    # stats columns
    stats_dict = {'zone': []}
    for stats in stats_funcs:
        stats_dict[stats] = []

    for zone_id in requested_zones:
        # skip non-finite zone ids
        if not np.isfinite(zone_id):
            continue

        bounds = zone_bounds.get(zone_id.item())
        if bounds is None:
            # requested zone has no valid cells in the data
            continue

        start, end = bounds
        zone_values = values_by_zone[start:end]
        stats_dict['zone'].append(zone_id)

        # apply stats on the zone data
        for stats, stats_func in stats_funcs.items():
            result = stats_func(zone_values)
            if getattr(result, 'ndim', 0) != 0:
                raise ValueError(
                    f"Stat function '{stats}' must return a scalar result."
                )
            # float() brings a 0-d device result back to the host.
            stats_dict[stats].append(float(result))

    # 'zone' stays a regular column: the previous `set_index("zone")` call
    # discarded its result, so dropping it leaves the output unchanged.
    return pd.DataFrame(stats_dict)
370+
371+
285372
def stats(
286373
zones: xr.DataArray,
287374
values: xr.DataArray,
@@ -461,13 +548,11 @@ def stats(
461548
if isinstance(stats_funcs, list):
462549
# create a dict of stats
463550
stats_funcs_dict = {}
464-
465551
for stats in stats_funcs:
466552
func = _DEFAULT_STATS.get(stats, None)
467553
if func is None:
468554
err_str = f"Invalid stat name. {stats} option not supported."
469555
raise ValueError(err_str)
470-
471556
stats_funcs_dict[stats] = func
472557

473558
elif isinstance(stats_funcs, dict):
@@ -476,9 +561,7 @@ def stats(
476561
mapper = ArrayTypeFunctionMapping(
477562
numpy_func=_stats_numpy,
478563
dask_func=_stats_dask_numpy,
479-
cupy_func=lambda *args: not_implemented_func(
480-
*args, messages='stats() does not support cupy backed DataArray'
481-
),
564+
cupy_func=_stats_cupy,
482565
dask_cupy_func=lambda *args: not_implemented_func(
483566
*args, messages='stats() does not support dask with cupy backed DataArray' # noqa
484567
),
@@ -841,13 +924,13 @@ def crosstab(
841924
>>> df = crosstab(zones=zones_dask, values=values_dask)
842925
>>> print(df)
843926
Dask DataFrame Structure:
844-
zone 0.0 10.0 20.0 30.0 40.0 50.0
927+
zone 0.0 10.0 20.0 30.0 40.0 50.0
845928
npartitions=5
846-
0 float64 int64 int64 int64 int64 int64 int64
847-
1 ... ... ... ... ... ... ...
848-
... ... ... ... ... ... ... ...
849-
4 ... ... ... ... ... ... ...
850-
5 ... ... ... ... ... ... ...
929+
0 float64 int64 int64 int64 int64 int64 int64
930+
1 ... ... ... ... ... ... ...
931+
... ... ... ... ... ... ... ...
932+
4 ... ... ... ... ... ... ...
933+
5 ... ... ... ... ... ... ...
851934
Dask Name: astype, 1186 tasks
852935
>>> print(dask_df.compute)
853936
zone 0.0 10.0 20.0 30.0 40.0 50.0
@@ -1214,7 +1297,7 @@ def _area_connectivity(data, n=4):
12141297
src_window[4] = data[min(y + 1, rows - 1), x]
12151298
src_window[5] = data[max(y - 1, 0), min(x + 1, cols - 1)]
12161299
src_window[6] = data[y, min(x + 1, cols - 1)]
1217-
src_window[7] = data[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
1300+
src_window[7] = data[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
12181301

12191302
area_window[0] = out[max(y - 1, 0), max(x - 1, 0)]
12201303
area_window[1] = out[y, max(x - 1, 0)]
@@ -1223,7 +1306,7 @@ def _area_connectivity(data, n=4):
12231306
area_window[4] = out[min(y + 1, rows - 1), x]
12241307
area_window[5] = out[max(y - 1, 0), min(x + 1, cols - 1)]
12251308
area_window[6] = out[y, min(x + 1, cols - 1)]
1226-
area_window[7] = out[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
1309+
area_window[7] = out[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
12271310

12281311
else:
12291312
src_window[0] = data[y, max(x - 1, 0)]
@@ -1272,7 +1355,7 @@ def _area_connectivity(data, n=4):
12721355
src_window[4] = data[min(y + 1, rows - 1), x]
12731356
src_window[5] = data[max(y - 1, 0), min(x + 1, cols - 1)]
12741357
src_window[6] = data[y, min(x + 1, cols - 1)]
1275-
src_window[7] = data[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
1358+
src_window[7] = data[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
12761359

12771360
area_window[0] = out[max(y - 1, 0), max(x - 1, 0)]
12781361
area_window[1] = out[y, max(x - 1, 0)]
@@ -1281,7 +1364,7 @@ def _area_connectivity(data, n=4):
12811364
area_window[4] = out[min(y + 1, rows - 1), x]
12821365
area_window[5] = out[max(y - 1, 0), min(x + 1, cols - 1)]
12831366
area_window[6] = out[y, min(x + 1, cols - 1)]
1284-
area_window[7] = out[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
1367+
area_window[7] = out[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
12851368

12861369
else:
12871370
src_window[0] = data[y, max(x - 1, 0)]

0 commit comments

Comments
 (0)