Skip to content

Commit cb768bd

Browse files
authored
validate_arrays: ensure chunksizes of arrays are matching (#577)
1 parent 376e537 commit cb768bd

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

xrspatial/classify.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def _run_cupy_jenks_matrices(data, n_classes):
585585
nl = data.shape[0] + 1
586586
variance = 0.0
587587

588-
for l in range(2, nl): # noqa
588+
for l in range(2, nl): # noqa
589589
sum = 0.0
590590
sum_squares = 0.0
591591
w = 0.0

xrspatial/utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@ def validate_arrays(*arrays):
129129
if not isinstance(first_array.data, type(arrays[i].data)):
130130
raise ValueError("input arrays must have same type")
131131

132+
# ensure dask chunksizes of all arrays are the same
133+
if isinstance(first_array.data, da.Array):
134+
for i in range(1, len(arrays)):
135+
if first_array.chunks != arrays[i].chunks:
136+
arrays[i].data = arrays[i].data.rechunk(first_array.chunks)
137+
132138

133139
def get_xy_range(raster, xdim=None, ydim=None):
134140
"""

xrspatial/zonal.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import dask.dataframe as dd
1111
from dask import delayed
1212

13-
from xrspatial.utils import ngjit
13+
from xrspatial.utils import ngjit, validate_arrays
1414

1515

1616
def _stats_count(data):
@@ -472,14 +472,13 @@ def stats(
472472
3 30 3850
473473
"""
474474

475-
if zones.shape != values.shape:
476-
raise ValueError("`zones` and `values` must have same shape.")
475+
validate_arrays(zones, values)
477476

478477
if not (
479478
issubclass(zones.data.dtype.type, np.integer)
480479
or issubclass(zones.data.dtype.type, np.floating)
481480
):
482-
raise ValueError("`zones` must be an array of integers.")
481+
raise ValueError("`zones` must be an array of integers or floats.")
483482

484483
if not (
485484
issubclass(values.data.dtype.type, np.integer)
@@ -521,11 +520,6 @@ def stats(
521520
)
522521
else:
523522
# dask case
524-
525-
# make sure chunksizes of `zones` and `values` are matching
526-
if zones.chunks != values.chunks:
527-
values.data = values.data.rechunk(zones.chunks)
528-
529523
stats_df = _stats_dask_numpy(
530524
zones.data,
531525
values.data,

0 commit comments

Comments
 (0)