Skip to content

Commit 9d2ee7c

Browse files
authored
zonal stats: speed up dask case (#572)
* safely removed nodata_zones arg
* dask zonal stats
* dask case: support zone_ids
* refactor
* update docs
* clean code
1 parent 1906a4e commit 9d2ee7c

File tree

2 files changed

+239
-119
lines changed

2 files changed

+239
-119
lines changed

xrspatial/tests/test_zonal.py

+30 -17
Original file line number | Diff line number | Diff line change
@@ -56,18 +56,14 @@ def check_results(df_np, df_da, expected_results_dict):
5656
df_np[col], expected_results_dict[col], equal_nan=True
5757
).all()
5858

59-
# dask case
60-
assert isinstance(df_da, dd.DataFrame)
61-
df_da = df_da.compute()
62-
assert isinstance(df_da, pd.DataFrame)
63-
64-
# numpy results equal dask results
65-
# zone column
66-
assert (df_np['zone'] == df_da['zone']).all()
59+
if df_da is not None:
60+
# dask case
61+
assert isinstance(df_da, dd.DataFrame)
62+
df_da = df_da.compute()
63+
assert isinstance(df_da, pd.DataFrame)
6764

68-
assert (df_np.columns == df_da.columns).all()
69-
for col in df_np.columns[1:]:
70-
assert np.isclose(df_np[col], df_da[col], equal_nan=True).all()
65+
# numpy results equal dask results, ignoring their indexes
66+
assert np.array_equal(df_np.values, df_da.values, equal_nan=True)
7167

7268

7369
def test_stats():
@@ -93,7 +89,27 @@ def test_stats():
9389
df_da = stats(zones=zones_da, values=values_da)
9490
check_results(df_np, df_da, default_stats_results)
9591

96-
# ---- custom stats ----
92+
# expected results
93+
stats_results_zone_0_3 = {
94+
'zone': [0, 3],
95+
'mean': [0, 2.4],
96+
'max': [0, 3],
97+
'min': [0, 0],
98+
'sum': [0, 12],
99+
'std': [0, 1.2],
100+
'var': [0, 1.44],
101+
'count': [5, 5]
102+
}
103+
104+
# numpy case
105+
df_np_zone_0_3 = stats(zones=zones_np, values=values_np, zone_ids=[0, 3])
106+
107+
# dask case
108+
df_da_zone_0_3 = stats(zones=zones_da, values=values_da, zone_ids=[0, 3])
109+
110+
check_results(df_np_zone_0_3, df_da_zone_0_3, stats_results_zone_0_3)
111+
112+
# ---- custom stats (NumPy only) ----
97113
# expected results
98114
custom_stats_results = {
99115
'zone': [1, 2],
@@ -115,13 +131,10 @@ def _range(values):
115131
# numpy case
116132
df_np = stats(
117133
zones=zones_np, values=values_np, stats_funcs=custom_stats,
118-
zone_ids=[1, 2], nodata_zones=0, nodata_values=0
134+
zone_ids=[1, 2], nodata_values=0
119135
)
120136
# dask case
121-
df_da = stats(
122-
zones=zones_da, values=values_da, stats_funcs=custom_stats,
123-
zone_ids=[1, 2], nodata_zones=0, nodata_values=0
124-
)
137+
df_da = None
125138
check_results(df_np, df_da, custom_stats_results)
126139

127140

0 commit comments

Comments
 (0)