Skip to content

Commit 1ef8d73

Browse files
authored
Nanfiltering in TDigestTask (#667)
1 parent 96b4e7a commit 1ef8d73

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

ml_tools/eolearn/ml_tools/tdigest.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""
22
The module provides an EOTask for the computation of a T-Digest representation of an EOPatch.
3-
Requires installation of `eolearn.ml_tools[TDIGEST]`.
43
54
Copyright (c) 2017- Sinergise and contributors
65
For the full list of contributors, see the CREDITS file in the root directory of this source tree.
@@ -9,7 +8,7 @@
98
"""
109
from functools import partial
1110
from itertools import product
12-
from typing import Any, Callable, Dict, Generator, Iterable, List, Literal, Tuple
11+
from typing import Any, Callable, Dict, Generator, Iterable, List, Literal, Tuple, Union
1312

1413
import numpy as np
1514
import tdigest as td
@@ -34,8 +33,9 @@ def __init__(
3433
self,
3534
in_feature: FeaturesSpecification,
3635
out_feature: FeaturesSpecification,
37-
mode: Literal["standard", "timewise", "monthly", "total"] = "standard",
36+
mode: Union[Literal["standard", "timewise", "monthly", "total"], Callable] = "standard",
3837
pixelwise: bool = False,
38+
filternan: bool = False,
3939
):
4040
"""
4141
:param in_feature: The input feature to compute the T-Digest representation for.
@@ -46,14 +46,19 @@ def __init__(
4646
* `'monthly'` computes the T-Digest representation for each band accumulating the timestamps per month.
4747
* | `'total'` computes the total T-Digest representation of the whole feature accumulating all timestamps,
4848
| bands and pixels. Cannot be used with `pixelwise=True`.
49+
* | Callable computes the T-Digest representation defined by the processing function given as mode. Receives
50+
| the feature's `input_array`, the `timestamps`, the `shape`, and the `pixelwise` and `filternan` flags as keyword arguments.
4951
:param pixelwise: Whether to compute the T-Digest representation per pixel (`True`) or accumulated over all pixels (`False`).
5052
Cannot be used with `mode='total'`.
53+
:param filternan: Whether to filter out NaN values before computing the T-Digest.
5154
"""
5255

5356
self.mode = mode
5457

5558
self.pixelwise = pixelwise
5659

60+
self.filternan = filternan
61+
5762
if self.pixelwise and self.mode == "total":
5863
raise ValueError("Total mode does not support pixelwise=True.")
5964

@@ -78,8 +83,8 @@ def execute(self, eopatch: EOPatch) -> EOPatch:
7883
for in_feature_, out_feature_, shape in _looper(
7984
in_feature=self.in_feature, out_feature=self.out_feature, eopatch=eopatch
8085
):
81-
eopatch[out_feature_] = _processing_function[self.mode](
82-
input_array=eopatch[in_feature_], timestamps=eopatch.timestamps, shape=shape, pixelwise=self.pixelwise
86+
eopatch[out_feature_] = _processing_function.get(self.mode, self.mode)(
87+
input_array=eopatch[in_feature_], timestamps=eopatch.timestamps, shape=shape, pixelwise=self.pixelwise, filternan=self.filternan
8388
)
8489

8590
return eopatch
@@ -95,6 +100,9 @@ def _is_input_ftype(feature_type: FeatureType, mode: ModeTypes) -> bool:
95100

96101

97102
def _is_output_ftype(feature_type: FeatureType, mode: ModeTypes, pixelwise: bool) -> bool:
103+
if callable(mode):
104+
return True
105+
98106
if mode == "standard":
99107
return feature_type == (FeatureType.DATA_TIMELESS if pixelwise else FeatureType.SCALAR_TIMELESS)
100108

@@ -112,36 +120,36 @@ def _looper(
112120
yield in_feature_, out_feature_, shape
113121

114122

115-
def _process_standard(input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, **_: Any) -> np.ndarray:
123+
def _process_standard(input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any) -> np.ndarray:
116124
if pixelwise:
117125
array = np.empty(shape[-3:], dtype=object)
118126
for i, j, k in product(range(shape[-3]), range(shape[-2]), range(shape[-1])):
119-
array[i, j, k] = _get_tdigest(input_array[..., i, j, k])
127+
array[i, j, k] = _get_tdigest(input_array[..., i, j, k], filternan)
120128

121129
else:
122130
array = np.empty(shape[-1], dtype=object)
123131
for k in range(shape[-1]):
124-
array[k] = _get_tdigest(input_array[..., k])
132+
array[k] = _get_tdigest(input_array[..., k], filternan)
125133

126134
return array
127135

128136

129-
def _process_timewise(input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, **_: Any) -> np.ndarray:
137+
def _process_timewise(input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any) -> np.ndarray:
130138
if pixelwise:
131139
array = np.empty(shape, dtype=object)
132140
for time_, i, j, k in product(range(shape[0]), range(shape[1]), range(shape[2]), range(shape[3])):
133-
array[time_, i, j, k] = _get_tdigest(input_array[time_, i, j, k])
141+
array[time_, i, j, k] = _get_tdigest(input_array[time_, i, j, k], filternan)
134142

135143
else:
136144
array = np.empty(shape[[0, -1]], dtype=object)
137145
for time_, k in product(range(shape[0]), range(shape[-1])):
138-
array[time_, k] = _get_tdigest(input_array[time_, ..., k])
146+
array[time_, k] = _get_tdigest(input_array[time_, ..., k], filternan)
139147

140148
return array
141149

142150

143151
def _process_monthly(
144-
input_array: np.ndarray, timestamps: Iterable, shape: np.ndarray, pixelwise: bool, **_: Any
152+
input_array: np.ndarray, timestamps: Iterable, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any
145153
) -> np.ndarray:
146154
midx = []
147155
for month_ in range(12):
@@ -150,18 +158,18 @@ def _process_monthly(
150158
if pixelwise:
151159
array = np.empty([12, *shape[1:]], dtype=object)
152160
for month_, i, j, k in product(range(12), range(shape[1]), range(shape[2]), range(shape[3])):
153-
array[month_, i, j, k] = _get_tdigest(input_array[midx[month_], i, j, k])
161+
array[month_, i, j, k] = _get_tdigest(input_array[midx[month_], i, j, k], filternan)
154162

155163
else:
156164
array = np.empty([12, shape[-1]], dtype=object)
157165
for month_, k in product(range(12), range(shape[-1])):
158-
array[month_, k] = _get_tdigest(input_array[midx[month_], ..., k])
166+
array[month_, k] = _get_tdigest(input_array[midx[month_], ..., k], filternan)
159167

160168
return array
161169

162170

163-
def _process_total(input_array: np.ndarray, **_: Any) -> np.ndarray:
164-
return _get_tdigest(input_array)
171+
def _process_total(input_array: np.ndarray, filternan: bool, **_: Any) -> np.ndarray:
172+
return _get_tdigest(input_array, filternan)
165173

166174

167175
_processing_function: Dict[str, Callable] = {
@@ -172,7 +180,8 @@ def _process_total(input_array: np.ndarray, **_: Any) -> np.ndarray:
172180
}
173181

174182

175-
def _get_tdigest(values: np.ndarray) -> td.TDigest:
183+
def _get_tdigest(values: np.ndarray, filternan: bool) -> td.TDigest:
176184
result = td.TDigest()
177-
result.batch_update(values.flatten())
185+
values_ = values.flatten()
186+
result.batch_update(values_[~np.isnan(values_)] if filternan else values_)
178187
return result

0 commit comments

Comments
 (0)