1
+ # standard library
1
2
from math import sqrt
2
3
from typing import Optional , Callable , Union , Dict , List
3
4
5
+ # 3rd-party
6
+ import dask .array as da
7
+ import dask .dataframe as dd
8
+ from dask import delayed
9
+
4
10
import numpy as np
5
11
import pandas as pd
6
12
import xarray as xr
7
13
from xarray import DataArray
8
14
9
- import dask .array as da
10
- import dask .dataframe as dd
11
- from dask import delayed
15
+ try :
16
+ import cupy
17
+ except ImportError :
18
+ class cupy (object ):
19
+ ndarray = False
12
20
13
- from xrspatial .utils import ngjit , validate_arrays
21
+ # local modules
22
+ from xrspatial .utils import ngjit
23
+ from xrspatial .utils import validate_arrays
14
24
from xrspatial .utils import ArrayTypeFunctionMapping
15
25
from xrspatial .utils import not_implemented_func
16
26
17
-
18
27
TOTAL_COUNT = '_total_count'
19
28
20
29
21
30
def _stats_count (data ):
22
31
if isinstance (data , np .ndarray ):
23
32
# numpy case
24
33
stats_count = np .ma .count (data )
34
+ elif isinstance (data , cupy .ndarray ):
35
+ # cupy case
36
+ stats_count = np .prod (data .shape )
25
37
else :
26
38
# dask case
27
39
stats_count = data .size - da .ma .getmaskarray (data ).sum ()
@@ -56,9 +68,9 @@ def _stats_count(data):
56
68
sum_squares = lambda block_sum_squares : np .nansum (block_sum_squares , axis = 0 ),
57
69
squared_sum = lambda block_sums : np .nansum (block_sums , axis = 0 )** 2 ,
58
70
)
59
- _dask_mean = lambda sums , counts : sums / counts # noqa
60
- _dask_std = lambda sum_squares , squared_sum , n : np .sqrt ((sum_squares - squared_sum / n ) / n ) # noqa
61
- _dask_var = lambda sum_squares , squared_sum , n : (sum_squares - squared_sum / n ) / n # noqa
71
+ def _dask_mean ( sums , counts ): return sums / counts # noqa
72
+ def _dask_std ( sum_squares , squared_sum , n ): return np .sqrt ((sum_squares - squared_sum / n ) / n ) # noqa
73
+ def _dask_var ( sum_squares , squared_sum , n ): return (sum_squares - squared_sum / n ) / n # noqa
62
74
63
75
64
76
@ngjit
@@ -282,6 +294,81 @@ def _stats_numpy(
282
294
return stats_df
283
295
284
296
297
def _stats_cupy(
    orig_zones: xr.DataArray,
    orig_values: xr.DataArray,
    zone_ids: List[Union[int, float]],
    stats_funcs: Dict,
    nodata_values: Union[int, float],
) -> pd.DataFrame:
    """
    Calculate zonal statistics for cupy-backed zones/values arrays.

    Parameters
    ----------
    orig_zones : xr.DataArray
        Zone labels; flattened before grouping.
    orig_values : xr.DataArray
        Values to aggregate per zone; at most 2D.
    zone_ids : list of int or float, or None
        Explicit zone ids to report; if None, the unique zone ids found
        in ``orig_zones`` are used.
    stats_funcs : dict
        Mapping of stat name -> callable applied to each zone's values.
        Each callable must reduce its input to a scalar.
    nodata_values : int or float, or None
        Value in ``orig_values`` treated as missing and excluded.

    Returns
    -------
    pd.DataFrame
        One row per finite zone id, one column per stat, plus a ``zone``
        column.

    Raises
    ------
    TypeError
        If ``orig_values`` has more than 2 dimensions.
    ValueError
        If an entry of ``stats_funcs`` is not callable.
    """
    # TODO: add support for 3D input
    if len(orig_values.shape) > 2:
        raise TypeError('3D inputs not supported for cupy backend')

    zones = cupy.ravel(orig_zones)
    values = cupy.ravel(orig_values)

    # group values by zone: sort the values along with their zone labels
    sorted_indices = cupy.argsort(zones)
    sorted_zones = zones[sorted_indices]
    values_by_zone = values[sorted_indices]

    # filter out values that are non-finite or equal to nodata_values;
    # test `is not None` (not truthiness) so a nodata value of 0 is honored
    if nodata_values is not None:
        filter_values = cupy.isfinite(values_by_zone) & (
            values_by_zone != nodata_values)
    else:
        filter_values = cupy.isfinite(values_by_zone)
    values_by_zone = values_by_zone[filter_values]
    sorted_zones = sorted_zones[filter_values]

    # find the unique zones and the index where each zone's run begins
    unique_zones, unique_index = cupy.unique(sorted_zones, return_index=True)

    # Transfer zone bookkeeping to the host
    unique_index = unique_index.get()
    if zone_ids is None:
        unique_zones = unique_zones.get()
    else:
        unique_zones = zone_ids
    unique_zones = np.asarray(unique_zones)

    # stats columns
    stats_dict = {'zone': []}
    for stats in stats_funcs:
        stats_dict[stats] = []

    for i in range(len(unique_zones)):
        zone_id = unique_zones[i]
        # skip non-finite zone ids (covers nodata zones)
        if not np.isfinite(zone_id):
            continue

        stats_dict['zone'].append(zone_id)
        # extract this zone's contiguous slice of the sorted values
        if i < len(unique_zones) - 1:
            zone_values = values_by_zone[unique_index[i]:unique_index[i + 1]]
        else:
            zone_values = values_by_zone[unique_index[i]:]

        # apply each requested stat to the zone data
        for stats in stats_funcs:
            stats_func = stats_funcs.get(stats)
            if not callable(stats_func):
                raise ValueError(stats)
            result = stats_func(zone_values)

            # each stat must reduce to a scalar (0-d array)
            assert len(result.shape) == 0

            # builtin float() instead of the removed `cupy.float` alias;
            # it also transfers the 0-d device scalar to the host
            stats_dict[stats].append(float(result))

    stats_df = pd.DataFrame(stats_dict)
    # NOTE(review): set_index() is not in-place, so this call is a no-op and
    # `zone` remains a regular column — confirm against _stats_numpy whether
    # the index was meant to be set (would need inplace=True or reassignment).
    stats_df.set_index("zone")
    return stats_df
370
+
371
+
285
372
def stats (
286
373
zones : xr .DataArray ,
287
374
values : xr .DataArray ,
@@ -461,13 +548,11 @@ def stats(
461
548
if isinstance (stats_funcs , list ):
462
549
# create a dict of stats
463
550
stats_funcs_dict = {}
464
-
465
551
for stats in stats_funcs :
466
552
func = _DEFAULT_STATS .get (stats , None )
467
553
if func is None :
468
554
err_str = f"Invalid stat name. { stats } option not supported."
469
555
raise ValueError (err_str )
470
-
471
556
stats_funcs_dict [stats ] = func
472
557
473
558
elif isinstance (stats_funcs , dict ):
@@ -476,9 +561,7 @@ def stats(
476
561
mapper = ArrayTypeFunctionMapping (
477
562
numpy_func = _stats_numpy ,
478
563
dask_func = _stats_dask_numpy ,
479
- cupy_func = lambda * args : not_implemented_func (
480
- * args , messages = 'stats() does not support cupy backed DataArray'
481
- ),
564
+ cupy_func = _stats_cupy ,
482
565
dask_cupy_func = lambda * args : not_implemented_func (
483
566
* args , messages = 'stats() does not support dask with cupy backed DataArray' # noqa
484
567
),
@@ -841,13 +924,13 @@ def crosstab(
841
924
>>> df = crosstab(zones=zones_dask, values=values_dask)
842
925
>>> print(df)
843
926
Dask DataFrame Structure:
844
- zone 0.0 10.0 20.0 30.0 40.0 50.0
927
+ zone 0.0 10.0 20.0 30.0 40.0 50.0
845
928
npartitions=5
846
- 0 float64 int64 int64 int64 int64 int64 int64
847
- 1 ... ... ... ... ... ... ...
848
- ... ... ... ... ... ... ... ...
849
- 4 ... ... ... ... ... ... ...
850
- 5 ... ... ... ... ... ... ...
929
+ 0 float64 int64 int64 int64 int64 int64 int64
930
+ 1 ... ... ... ... ... ... ...
931
+ ... ... ... ... ... ... ... ...
932
+ 4 ... ... ... ... ... ... ...
933
+ 5 ... ... ... ... ... ... ...
851
934
Dask Name: astype, 1186 tasks
852
935
>>> print(dask_df.compute)
853
936
zone 0.0 10.0 20.0 30.0 40.0 50.0
@@ -1214,7 +1297,7 @@ def _area_connectivity(data, n=4):
1214
1297
src_window [4 ] = data [min (y + 1 , rows - 1 ), x ]
1215
1298
src_window [5 ] = data [max (y - 1 , 0 ), min (x + 1 , cols - 1 )]
1216
1299
src_window [6 ] = data [y , min (x + 1 , cols - 1 )]
1217
- src_window [7 ] = data [min (y + 1 , rows - 1 ), min (x + 1 , cols - 1 )] # noqa
1300
+ src_window [7 ] = data [min (y + 1 , rows - 1 ), min (x + 1 , cols - 1 )] # noqa
1218
1301
1219
1302
area_window [0 ] = out [max (y - 1 , 0 ), max (x - 1 , 0 )]
1220
1303
area_window [1 ] = out [y , max (x - 1 , 0 )]
@@ -1223,7 +1306,7 @@ def _area_connectivity(data, n=4):
1223
1306
area_window [4 ] = out [min (y + 1 , rows - 1 ), x ]
1224
1307
area_window [5 ] = out [max (y - 1 , 0 ), min (x + 1 , cols - 1 )]
1225
1308
area_window [6 ] = out [y , min (x + 1 , cols - 1 )]
1226
- area_window [7 ] = out [min (y + 1 , rows - 1 ), min (x + 1 , cols - 1 )] # noqa
1309
+ area_window [7 ] = out [min (y + 1 , rows - 1 ), min (x + 1 , cols - 1 )] # noqa
1227
1310
1228
1311
else :
1229
1312
src_window [0 ] = data [y , max (x - 1 , 0 )]
@@ -1272,7 +1355,7 @@ def _area_connectivity(data, n=4):
1272
1355
src_window [4 ] = data [min (y + 1 , rows - 1 ), x ]
1273
1356
src_window [5 ] = data [max (y - 1 , 0 ), min (x + 1 , cols - 1 )]
1274
1357
src_window [6 ] = data [y , min (x + 1 , cols - 1 )]
1275
- src_window [7 ] = data [min (y + 1 , rows - 1 ), min (x + 1 , cols - 1 )] # noqa
1358
+ src_window [7 ] = data [min (y + 1 , rows - 1 ), min (x + 1 , cols - 1 )] # noqa
1276
1359
1277
1360
area_window [0 ] = out [max (y - 1 , 0 ), max (x - 1 , 0 )]
1278
1361
area_window [1 ] = out [y , max (x - 1 , 0 )]
@@ -1281,7 +1364,7 @@ def _area_connectivity(data, n=4):
1281
1364
area_window [4 ] = out [min (y + 1 , rows - 1 ), x ]
1282
1365
area_window [5 ] = out [max (y - 1 , 0 ), min (x + 1 , cols - 1 )]
1283
1366
area_window [6 ] = out [y , min (x + 1 , cols - 1 )]
1284
- area_window [7 ] = out [min (y + 1 , rows - 1 ), min (x + 1 , cols - 1 )] # noqa
1367
+ area_window [7 ] = out [min (y + 1 , rows - 1 ), min (x + 1 , cols - 1 )] # noqa
1285
1368
1286
1369
else :
1287
1370
src_window [0 ] = data [y , max (x - 1 , 0 )]
0 commit comments