1
1
"""
2
2
The module provides an EOTask for the computation of a T-Digest representation of an EOPatch.
3
- Requires installation of `eolearn.ml_tools[TDIGEST]`.
4
3
5
4
Copyright (c) 2017- Sinergise and contributors
6
5
For the full list of contributors, see the CREDITS file in the root directory of this source tree.
9
8
"""
10
9
from functools import partial
11
10
from itertools import product
12
- from typing import Any , Callable , Dict , Generator , Iterable , List , Literal , Tuple
11
+ from typing import Any , Callable , Dict , Generator , Iterable , List , Literal , Tuple , Union
13
12
14
13
import numpy as np
15
14
import tdigest as td
@@ -34,8 +33,9 @@ def __init__(
34
33
self ,
35
34
in_feature : FeaturesSpecification ,
36
35
out_feature : FeaturesSpecification ,
37
- mode : Literal ["standard" , "timewise" , "monthly" , "total" ] = "standard" ,
36
+ mode : Union [ Literal ["standard" , "timewise" , "monthly" , "total" ], Callable ] = "standard" ,
38
37
pixelwise : bool = False ,
38
+ filternan : bool = False ,
39
39
):
40
40
"""
41
41
:param in_feature: The input feature to compute the T-Digest representation for.
@@ -46,14 +46,19 @@ def __init__(
46
46
* `'monthly'` computes the T-Digest representation for each band accumulating the timestamps per month.
47
47
* | `'total'` computes the total T-Digest representation of the whole feature accumulating all timestamps,
48
48
| bands and pixels. Cannot be used with `pixelwise=True`.
49
+ * | Callable computes the T-Digest representation defined by the processing function given as mode. Receives
50
+ | the input_array of the feature, the timestamps, the shape and the pixelwise and filternan keywords as an input.
49
51
:param pixelwise: Whether to compute the T-Digest representation per pixel (`True`) or accumulated over all pixels (`False`).
50
52
Cannot be used with `mode='total'`.
53
+ :param filternan: Whether to filter out NaN values before computing the T-Digest.
51
54
"""
52
55
53
56
self .mode = mode
54
57
55
58
self .pixelwise = pixelwise
56
59
60
+ self .filternan = filternan
61
+
57
62
if self .pixelwise and self .mode == "total" :
58
63
raise ValueError ("Total mode does not support pixelwise=True." )
59
64
@@ -78,8 +83,8 @@ def execute(self, eopatch: EOPatch) -> EOPatch:
78
83
for in_feature_ , out_feature_ , shape in _looper (
79
84
in_feature = self .in_feature , out_feature = self .out_feature , eopatch = eopatch
80
85
):
81
- eopatch [out_feature_ ] = _processing_function [ self .mode ] (
82
- input_array = eopatch [in_feature_ ], timestamps = eopatch .timestamps , shape = shape , pixelwise = self .pixelwise
86
+ eopatch [out_feature_ ] = _processing_function . get ( self .mode , self . mode ) (
87
+ input_array = eopatch [in_feature_ ], timestamps = eopatch .timestamps , shape = shape , pixelwise = self .pixelwise , filternan = self . filternan
83
88
)
84
89
85
90
return eopatch
@@ -95,6 +100,9 @@ def _is_input_ftype(feature_type: FeatureType, mode: ModeTypes) -> bool:
95
100
96
101
97
102
def _is_output_ftype (feature_type : FeatureType , mode : ModeTypes , pixelwise : bool ) -> bool :
103
+ if callable (mode ):
104
+ return True
105
+
98
106
if mode == "standard" :
99
107
return feature_type == (FeatureType .DATA_TIMELESS if pixelwise else FeatureType .SCALAR_TIMELESS )
100
108
@@ -112,36 +120,36 @@ def _looper(
112
120
yield in_feature_ , out_feature_ , shape
113
121
114
122
115
def _process_standard(
    input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, filternan: bool = False, **_: Any
) -> np.ndarray:
    """Build one T-Digest per band, or per pixel and band, accumulating over all leading axes.

    :param input_array: Array of the input feature. Its trailing axes are indexed with the ranges taken from the
        trailing entries of `shape`.
    :param shape: Shape descriptor of the feature. Only the last three entries are read (presumably height, width
        and bands -- TODO confirm against the caller).
    :param pixelwise: If `True`, a separate digest is built for every index combination of the last three axes.
        Otherwise one digest is built per index of the last axis.
    :param filternan: Forwarded to `_get_tdigest`. Defaults to `False` so that callers which predate this keyword
        keep working.
    :param _: Extra keyword arguments are accepted and ignored, so that every processing function can be called
        with the same set of keywords.
    :return: Object array holding the digests, with shape `shape[-3:]` if `pixelwise` else `(shape[-1],)`.
    """
    if pixelwise:
        array = np.empty(shape[-3:], dtype=object)
        for i, j, k in product(range(shape[-3]), range(shape[-2]), range(shape[-1])):
            # The ellipsis collects every leading axis into a single digest.
            array[i, j, k] = _get_tdigest(input_array[..., i, j, k], filternan)
        return array

    array = np.empty(shape[-1], dtype=object)
    for k in range(shape[-1]):
        array[k] = _get_tdigest(input_array[..., k], filternan)
    return array
127
135
128
136
129
def _process_timewise(
    input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, filternan: bool = False, **_: Any
) -> np.ndarray:
    """Build one T-Digest per index of the first axis and band, optionally also per pixel.

    :param input_array: Array of the input feature, indexed as `[time, i, j, band]`. It is therefore expected to
        be four-dimensional.
    :param shape: Shape descriptor of the feature. Entries 0 to 3 are read in pixelwise mode, otherwise only the
        first and the last entry. Any integer sequence works, not only a numpy array.
    :param pixelwise: If `True`, a digest is built for every single element of the array. Otherwise one digest
        is built per (first axis, last axis) index pair, accumulating the axes in between.
    :param filternan: Forwarded to `_get_tdigest`. Defaults to `False` so that callers which predate this keyword
        keep working.
    :param _: Extra keyword arguments are accepted and ignored, so that every processing function can be called
        with the same set of keywords.
    :return: Object array holding the digests, with shape `shape` if `pixelwise` else `(shape[0], shape[-1])`.
    """
    if pixelwise:
        array = np.empty(shape, dtype=object)
        for time_, i, j, k in product(range(shape[0]), range(shape[1]), range(shape[2]), range(shape[3])):
            array[time_, i, j, k] = _get_tdigest(input_array[time_, i, j, k], filternan)
        return array

    # Plain indexing instead of the fancy index `shape[[0, -1]]`, which only works when `shape` is an ndarray.
    array = np.empty((shape[0], shape[-1]), dtype=object)
    for time_, k in product(range(shape[0]), range(shape[-1])):
        array[time_, k] = _get_tdigest(input_array[time_, ..., k], filternan)
    return array
141
149
142
150
143
151
def _process_monthly (
144
- input_array : np .ndarray , timestamps : Iterable , shape : np .ndarray , pixelwise : bool , ** _ : Any
152
+ input_array : np .ndarray , timestamps : Iterable , shape : np .ndarray , pixelwise : bool , filternan : bool , ** _ : Any
145
153
) -> np .ndarray :
146
154
midx = []
147
155
for month_ in range (12 ):
@@ -150,18 +158,18 @@ def _process_monthly(
150
158
if pixelwise :
151
159
array = np .empty ([12 , * shape [1 :]], dtype = object )
152
160
for month_ , i , j , k in product (range (12 ), range (shape [1 ]), range (shape [2 ]), range (shape [3 ])):
153
- array [month_ , i , j , k ] = _get_tdigest (input_array [midx [month_ ], i , j , k ])
161
+ array [month_ , i , j , k ] = _get_tdigest (input_array [midx [month_ ], i , j , k ], filternan )
154
162
155
163
else :
156
164
array = np .empty ([12 , shape [- 1 ]], dtype = object )
157
165
for month_ , k in product (range (12 ), range (shape [- 1 ])):
158
- array [month_ , k ] = _get_tdigest (input_array [midx [month_ ], ..., k ])
166
+ array [month_ , k ] = _get_tdigest (input_array [midx [month_ ], ..., k ], filternan )
159
167
160
168
return array
161
169
162
170
163
def _process_total(input_array: np.ndarray, filternan: bool = False, **_: Any) -> "td.TDigest":
    """Build a single T-Digest from every value of the input array.

    :param input_array: Array of the input feature; all of its values go into one digest.
    :param filternan: Forwarded to `_get_tdigest`. Defaults to `False` so that callers which predate this keyword
        keep working.
    :param _: Extra keyword arguments are accepted and ignored, so that every processing function can be called
        with the same set of keywords.
    :return: The digest returned by `_get_tdigest` (a single object, not an array of digests).
    """
    return _get_tdigest(input_array, filternan)
165
173
166
174
167
175
_processing_function : Dict [str , Callable ] = {
@@ -172,7 +180,8 @@ def _process_total(input_array: np.ndarray, **_: Any) -> np.ndarray:
172
180
}
173
181
174
182
175
def _get_tdigest(values: np.ndarray, filternan: bool = False) -> td.TDigest:
    """Create a new T-Digest and feed it all given values.

    :param values: Values to add to the digest. Any array-like input is accepted; it is flattened before use.
    :param filternan: If `True`, NaN entries are dropped before the digest is updated. If `False` (default), the
        values are passed on unchanged, NaN included.
    :return: A freshly created `td.TDigest` updated with the (optionally filtered) values.
    """
    result = td.TDigest()
    # `np.asarray(...).ravel()` also accepts scalars and sequences and avoids the copy `flatten()` always makes.
    flat_values = np.asarray(values).ravel()
    if filternan:
        flat_values = flat_values[~np.isnan(flat_values)]
    result.batch_update(flat_values)
    return result
0 commit comments