Skip to content

Commit 7706787

Browse files
committed
add future annotations to improve docs in tdigest
1 parent 085861d commit 7706787

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

ml_tools/eolearn/ml_tools/tdigest.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
88
This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree.
99
"""
10+
from __future__ import annotations
11+
1012
from functools import partial
1113
from itertools import product
12-
from typing import Any, Callable, Dict, Generator, Iterable, List, Literal, Tuple, Union
14+
from typing import Any, Callable, Generator, Iterable, Literal, Union
1315

1416
import numpy as np
1517
import tdigest as td
@@ -34,7 +36,7 @@ def __init__(
3436
self,
3537
in_feature: FeaturesSpecification,
3638
out_feature: FeaturesSpecification,
37-
mode: Union[Literal["standard", "timewise", "monthly", "total"], Callable] = "standard",
39+
mode: Literal["standard", "timewise", "monthly", "total"] | Callable = "standard",
3840
pixelwise: bool = False,
3941
filternan: bool = False,
4042
):
@@ -119,8 +121,8 @@ def _is_output_ftype(feature_type: FeatureType, mode: ModeTypes, pixelwise: bool
119121

120122

121123
def _looper(
122-
in_feature: List[FeatureSpec], out_feature: List[FeatureSpec], eopatch: EOPatch
123-
) -> Generator[Tuple[FeatureSpec, FeatureSpec, np.ndarray], None, None]:
124+
in_feature: list[FeatureSpec], out_feature: list[FeatureSpec], eopatch: EOPatch
125+
) -> Generator[tuple[FeatureSpec, FeatureSpec, np.ndarray], None, None]:
124126
for in_feature_, out_feature_ in zip(in_feature, out_feature):
125127
shape = np.array(eopatch[in_feature_].shape)
126128
yield in_feature_, out_feature_, shape
@@ -182,7 +184,7 @@ def _process_total(input_array: np.ndarray, filternan: bool, **_: Any) -> np.nda
182184
return _get_tdigest(input_array, filternan)
183185

184186

185-
_processing_function: Dict[str, Callable] = {
187+
_processing_function: dict[str, Callable] = {
186188
"standard": _process_standard,
187189
"timewise": _process_timewise,
188190
"monthly": _process_monthly,

0 commit comments

Comments
 (0)