Skip to content

Commit d9375cf

Browse files
jgersakzigaLuksic
andauthored
Apply patch generator (#621)
* Apply_patch_generator to test_parsing * Add generate_eopatch to test eodata_io * Update test_eodata with generate_eopatch * Update test_core_tasks with generate_eopatch * Update test_train_split with generate_eopatch * Update test_sampling with generate_eopatch * Update mini_eopatch feature names * Split patch_fixture on two * Rename fixtuer patch_bands to eopatch_to_explode and update CLP_S2C in patch_fixture as suggested * Update CLP_S2C * Update test_get_spatial_dimension, update fixtuer in test_sampling * Update core/eolearn/tests/test_core_tasks.py Co-authored-by: Žiga Lukšič <[email protected]> * Update core/eolearn/tests/test_eodata.py Co-authored-by: Žiga Lukšič <[email protected]> --------- Co-authored-by: Žiga Lukšič <[email protected]>
1 parent a7e3215 commit d9375cf

File tree

6 files changed

+90
-96
lines changed

6 files changed

+90
-96
lines changed

core/eolearn/tests/test_core_tasks.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,34 +44,32 @@
4444
from eolearn.core.core_tasks import ExplodeBandsTask
4545
from eolearn.core.types import FeatureRenameSpec, FeatureSpec, FeaturesSpecification
4646
from eolearn.core.utils.parsing import parse_features
47-
from eolearn.core.utils.testing import assert_feature_data_equal
47+
from eolearn.core.utils.testing import PatchGeneratorConfig, assert_feature_data_equal, generate_eopatch
4848

4949
DUMMY_BBOX = BBox((0, 0, 1, 1), CRS(3857))
5050

5151

5252
@pytest.fixture(name="patch")
5353
def patch_fixture() -> EOPatch:
54-
patch = EOPatch(bbox=BBox((324.54, 546.45, 955.4, 63.43), CRS(3857)))
55-
patch.data["bands"] = np.arange(5 * 3 * 4 * 8).reshape(5, 3, 4, 8)
56-
patch.data["CLP"] = np.full((5, 3, 4, 1), 0.7)
57-
patch.data["CLP_S2C"] = np.zeros((5, 3, 4, 1), dtype=np.int64)
58-
patch.mask["CLM"] = np.full((5, 3, 4, 1), True)
59-
patch.mask_timeless["mask"] = np.arange(3 * 4 * 2).reshape(3, 4, 2)
60-
patch.mask_timeless["LULC"] = np.zeros((3, 4, 1), dtype=np.uint16)
61-
patch.mask_timeless["RANDOM_UINT8"] = np.random.randint(0, 100, size=(3, 4, 1), dtype=np.int8)
62-
patch.scalar["values"] = np.arange(10 * 5).reshape(5, 10)
63-
patch.scalar["CLOUD_COVERAGE"] = np.ones((5, 10))
64-
patch.timestamps = [
65-
datetime(2017, 1, 14, 10, 13, 46),
66-
datetime(2017, 2, 10, 10, 1, 32),
67-
datetime(2017, 2, 20, 10, 6, 35),
68-
datetime(2017, 3, 2, 10, 0, 20),
69-
datetime(2017, 3, 12, 10, 7, 6),
70-
]
71-
patch.meta_info["something"] = np.random.rand(10, 1)
54+
patch = generate_eopatch(
55+
{
56+
FeatureType.DATA: ["bands", "CLP"],
57+
FeatureType.MASK: ["CLM"],
58+
FeatureType.MASK_TIMELESS: ["mask", "LULC", "RANDOM_UINT8"],
59+
FeatureType.SCALAR: ["values", "CLOUD_COVERAGE"],
60+
}
61+
)
62+
patch.data["CLP_S2C"] = np.zeros_like(patch.data["CLP"])
63+
64+
patch.meta_info["something"] = "beep boop"
7265
return patch
7366

7467

68+
@pytest.fixture(name="eopatch_to_explode")
69+
def eopatch_to_explode_fixture() -> EOPatch:
70+
return generate_eopatch((FeatureType.DATA, "bands"), config=PatchGeneratorConfig(depth_range=(8, 9)))
71+
72+
7573
@pytest.mark.parametrize("task", [DeepCopyTask, CopyTask])
7674
def test_copy(task: Type[CopyTask], patch: EOPatch) -> None:
7775
patch_copy = task().execute(patch)
@@ -413,11 +411,11 @@ def kwargs_map(data, *, some=3, **kwargs) -> tuple:
413411
],
414412
)
415413
def test_explode_bands(
416-
patch: EOPatch,
414+
eopatch_to_explode: EOPatch,
417415
feature: Tuple[FeatureType, str],
418416
task_input: Dict[Tuple[FeatureType, str], Union[int, Iterable[int]]],
419417
) -> None:
420-
patch = ExplodeBandsTask(feature, task_input)(patch)
418+
patch = ExplodeBandsTask(feature, task_input)(eopatch_to_explode)
421419
assert all(new_feature in patch for new_feature in task_input)
422420

423421
for new_feature, bands in task_input.items():
@@ -426,19 +424,23 @@ def test_explode_bands(
426424
assert_array_equal(patch[new_feature], patch[feature][..., bands])
427425

428426

429-
def test_extract_bands(patch: EOPatch) -> None:
427+
def test_extract_bands(eopatch_to_explode: EOPatch) -> None:
430428
bands = [2, 4, 6]
431-
patch = ExtractBandsTask((FeatureType.DATA, "bands"), (FeatureType.DATA, "EXTRACTED_BANDS"), bands)(patch)
429+
patch = ExtractBandsTask((FeatureType.DATA, "bands"), (FeatureType.DATA, "EXTRACTED_BANDS"), bands)(
430+
eopatch_to_explode
431+
)
432432
assert_array_equal(patch.data["EXTRACTED_BANDS"], patch.data["bands"][..., bands])
433433

434434
patch.data["EXTRACTED_BANDS"][0, 0, 0, 0] += 1
435435
assert patch.data["EXTRACTED_BANDS"][0, 0, 0, 0] != patch.data["bands"][0, 0, 0, bands[0]]
436436

437437

438-
def test_extract_bands_fails(patch: EOPatch) -> None:
438+
def test_extract_bands_fails(eopatch_to_explode: EOPatch) -> None:
439439
with pytest.raises(ValueError):
440440
# fails because band 16 does not exist
441-
ExtractBandsTask((FeatureType.DATA, "bands"), (FeatureType.DATA, "EXTRACTED_BANDS"), [2, 4, 16])(patch)
441+
ExtractBandsTask((FeatureType.DATA, "bands"), (FeatureType.DATA, "EXTRACTED_BANDS"), [2, 4, 16])(
442+
eopatch_to_explode
443+
)
442444

443445

444446
@pytest.mark.parametrize(

core/eolearn/tests/test_eodata.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,20 @@
1818
from eolearn.core.eodata_io import FeatureIO
1919
from eolearn.core.exceptions import EODeprecationWarning
2020
from eolearn.core.types import FeatureSpec, FeaturesSpecification
21-
from eolearn.core.utils.testing import assert_feature_data_equal
21+
from eolearn.core.utils.testing import assert_feature_data_equal, generate_eopatch
2222

2323
DUMMY_BBOX = BBox((0, 0, 1, 1), CRS(3857))
2424

2525

2626
@pytest.fixture(name="mini_eopatch")
2727
def mini_eopatch_fixture() -> EOPatch:
28-
eop = EOPatch(bbox=DUMMY_BBOX)
29-
eop.data["bands"] = np.arange(2 * 3 * 3 * 2).reshape(2, 3, 3, 2)
30-
eop.data["zeros"] = np.zeros((2, 3, 3, 2), dtype=float)
31-
eop.mask["ones"] = np.ones((2, 6, 6, 1), dtype=int)
32-
eop.mask["twos"] = np.ones((2, 3, 3, 2), dtype=int) * 2
33-
eop.mask_timeless["threes"] = np.ones((3, 3, 1), dtype=np.uint8) * 3
28+
eop = generate_eopatch(
29+
{
30+
FeatureType.DATA: ["A", "B"],
31+
FeatureType.MASK: ["C", "D"],
32+
FeatureType.MASK_TIMELESS: ["E"],
33+
}
34+
)
3435
eop.meta_info["beep"] = "boop"
3536

3637
return eop
@@ -161,9 +162,9 @@ def test_simplified_feature_operations() -> None:
161162
@pytest.mark.parametrize(
162163
"feature_to_delete",
163164
[
164-
(FeatureType.DATA, "zeros"),
165-
(FeatureType.MASK, "ones"),
166-
(FeatureType.MASK_TIMELESS, "threes"),
165+
(FeatureType.DATA, "A"),
166+
(FeatureType.MASK, "C"),
167+
(FeatureType.MASK_TIMELESS, "E"),
167168
(FeatureType.META_INFO, "beep"),
168169
(FeatureType.TIMESTAMPS, None),
169170
],
@@ -317,18 +318,27 @@ def test_equals() -> None:
317318
assert eop1 != eop2
318319

319320

321+
@pytest.fixture(scope="function", name="eopatch_spatial_dim")
322+
def eopatch_spatial_dim_fixture() -> EOPatch:
323+
patch = EOPatch(bbox=DUMMY_BBOX)
324+
patch.data["A"] = np.zeros((1, 2, 3, 4))
325+
patch.mask["B"] = np.ones((4, 3, 2, 1), dtype=np.uint8)
326+
patch.mask_timeless["C"] = np.zeros((4, 5, 1), dtype=np.uint8)
327+
return patch
328+
329+
320330
@pytest.mark.parametrize(
321331
"feature, expected_dim",
322332
[
323-
[(FeatureType.DATA, "zeros"), (3, 3)],
324-
[(FeatureType.MASK, "ones"), (6, 6)],
325-
[(FeatureType.MASK_TIMELESS, "threes"), (3, 3)],
333+
[(FeatureType.DATA, "A"), (2, 3)],
334+
[(FeatureType.MASK, "B"), (3, 2)],
335+
[(FeatureType.MASK_TIMELESS, "C"), (4, 5)],
326336
],
327337
)
328338
def test_get_spatial_dimension(
329-
feature: Tuple[FeatureType, str], expected_dim: Tuple[int, int], mini_eopatch: EOPatch
339+
feature: Tuple[FeatureType, str], expected_dim: Tuple[int, int], eopatch_spatial_dim: EOPatch
330340
) -> None:
331-
assert mini_eopatch.get_spatial_dimension(*feature) == expected_dim
341+
assert eopatch_spatial_dim.get_spatial_dimension(*feature) == expected_dim
332342

333343

334344
@pytest.mark.parametrize(
@@ -337,13 +347,14 @@ def test_get_spatial_dimension(
337347
(
338348
pytest.lazy_fixture("mini_eopatch"),
339349
[
340-
(FeatureType.DATA, "bands"),
341-
(FeatureType.DATA, "zeros"),
342-
(FeatureType.MASK, "ones"),
343-
(FeatureType.MASK, "twos"),
344-
(FeatureType.MASK_TIMELESS, "threes"),
350+
(FeatureType.DATA, "A"),
351+
(FeatureType.DATA, "B"),
352+
(FeatureType.MASK, "C"),
353+
(FeatureType.MASK, "D"),
354+
(FeatureType.MASK_TIMELESS, "E"),
345355
(FeatureType.META_INFO, "beep"),
346356
(FeatureType.BBOX, None),
357+
(FeatureType.TIMESTAMPS, None),
347358
],
348359
),
349360
(EOPatch(bbox=DUMMY_BBOX), [(FeatureType.BBOX, None)]),

core/eolearn/tests/test_eodata_io.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from eolearn.core.exceptions import EODeprecationWarning
3636
from eolearn.core.types import FeaturesSpecification
3737
from eolearn.core.utils.parsing import FeatureParser
38-
from eolearn.core.utils.testing import assert_feature_data_equal
38+
from eolearn.core.utils.testing import assert_feature_data_equal, generate_eopatch
3939

4040
FS_LOADERS = [TempFS, pytest.lazy_fixture("create_mocked_s3fs")]
4141

@@ -44,16 +44,16 @@
4444

4545
@pytest.fixture(name="eopatch")
4646
def eopatch_fixture():
47-
eopatch = EOPatch(bbox=DUMMY_BBOX)
48-
mask = np.zeros((3, 3, 2), dtype=np.int16)
49-
data = np.zeros((2, 3, 3, 2), dtype=np.int16)
50-
eopatch.mask_timeless["mask"] = mask
51-
eopatch.data["data"] = data
52-
eopatch.timestamps = [datetime.datetime(2017, 1, 1, 10, 4, 7), datetime.datetime(2017, 1, 4, 10, 14, 5)]
47+
eopatch = generate_eopatch(
48+
{
49+
FeatureType.DATA: ["data"],
50+
FeatureType.MASK_TIMELESS: ["mask"],
51+
FeatureType.SCALAR: ["my scalar with spaces"],
52+
FeatureType.SCALAR_TIMELESS: ["my timeless scalar with spaces"],
53+
}
54+
)
5355
eopatch.meta_info["something"] = "nothing"
5456
eopatch.meta_info["something-else"] = "nothing"
55-
eopatch.scalar["my scalar with spaces"] = np.array([[1, 2, 3], [1, 2, 3]])
56-
eopatch.scalar_timeless["my timeless scalar with spaces"] = np.array([1, 2, 3])
5757
eopatch.vector["my-df"] = GeoDataFrame(
5858
{
5959
"values": [1, 2],
@@ -113,7 +113,7 @@ def test_overwriting_non_empty_folder(eopatch, fs_loader):
113113
eopatch.save("/", filesystem=temp_fs, overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)
114114
eopatch.save("/", filesystem=temp_fs, overwrite_permission=OverwritePermission.OVERWRITE_PATCH)
115115

116-
add_eopatch = EOPatch(bbox=DUMMY_BBOX)
116+
add_eopatch = EOPatch(bbox=eopatch.bbox)
117117
add_eopatch.data_timeless["some data"] = np.empty((3, 3, 2))
118118
add_eopatch.save("/", filesystem=temp_fs, overwrite_permission=OverwritePermission.ADD_ONLY)
119119
with pytest.raises(ValueError):

core/eolearn/tests/test_utils/test_parsing.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,11 @@
1-
import datetime as dt
21
from dataclasses import dataclass
32
from typing import Callable, Iterable, List, Optional, Tuple, Union
43

5-
import numpy as np
64
import pytest
75

8-
from sentinelhub import CRS, BBox
9-
106
from eolearn.core import EOPatch, FeatureParser, FeatureType
117
from eolearn.core.types import EllipsisType, FeatureRenameSpec, FeatureSpec, FeaturesSpecification
12-
13-
14-
@pytest.fixture(name="eopatch", scope="module")
15-
def eopatch_fixture():
16-
return EOPatch(
17-
data=dict(data=np.zeros((2, 2, 2, 2)), CLP=np.zeros((2, 2, 2, 2))), # name duplication intentional
18-
bbox=BBox((1, 2, 3, 4), CRS.WGS84),
19-
timestamps=[dt.datetime(2020, 5, 1), dt.datetime(2020, 5, 25)],
20-
mask=dict(data=np.zeros((2, 2, 2, 2), dtype=int), IS_VALID=np.zeros((2, 2, 2, 2), dtype=int)),
21-
mask_timeless=dict(LULC=np.zeros((2, 2, 2), dtype=int)),
22-
meta_info={"something": "else"},
23-
)
8+
from eolearn.core.utils.testing import generate_eopatch
249

2510

2611
@dataclass
@@ -181,6 +166,15 @@ def test_allowed_feature_types_iterable(test_input: FeaturesSpecification, allow
181166
FeatureParser(features=test_input, allowed_feature_types=allowed_types)
182167

183168

169+
@pytest.fixture(name="eopatch", scope="module")
170+
def eopatch_fixture():
171+
patch = generate_eopatch(
172+
{FeatureType.DATA: ["data", "CLP"], FeatureType.MASK: ["data", "IS_VALID"], FeatureType.MASK_TIMELESS: ["LULC"]}
173+
)
174+
patch.meta_info = {"something": "else"}
175+
return patch
176+
177+
184178
@pytest.mark.parametrize(
185179
"test_input, allowed_types",
186180
[

ml_tools/eolearn/tests/test_sampling.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
from pytest import approx
1515
from shapely.geometry import Point, Polygon
1616

17-
from sentinelhub import CRS, BBox
18-
1917
from eolearn.core import EOPatch, EOTask, FeatureType
18+
from eolearn.core.utils.testing import PatchGeneratorConfig, generate_eopatch
2019
from eolearn.ml_tools import BlockSamplingTask, FractionSamplingTask, GridSamplingTask, sample_by_values
2120
from eolearn.ml_tools.sampling import expand_to_grids, get_mask_of_samples, random_point_in_triangle
2221

@@ -134,11 +133,10 @@ def test_get_mask_of_samples(small_image: np.ndarray, n_samples: Dict[int, int])
134133

135134
@pytest.fixture(name="eopatch")
136135
def eopatch_fixture(small_image: np.ndarray) -> EOPatch:
137-
t, h, w, d = 10, *small_image.shape, 5
138-
eopatch = EOPatch(bbox=BBox((0, 0, 1, 1), CRS(3857)))
139-
eopatch.data["bands"] = np.arange(t * h * w * d).reshape(t, h, w, d)
140-
eopatch.mask_timeless["raster"] = small_image.reshape(small_image.shape + (1,))
141-
return eopatch
136+
config = PatchGeneratorConfig(raster_shape=small_image.shape, depth_range=(5, 6), num_timestamps=10)
137+
patch = generate_eopatch([(FeatureType.DATA, "bands")], config=config)
138+
patch.mask_timeless["raster"] = small_image.reshape(small_image.shape + (1,))
139+
return patch
142140

143141

144142
SAMPLING_MASK = FeatureType.MASK_TIMELESS, "sampling_mask"
@@ -159,14 +157,15 @@ def block_task_fixture(request) -> EOTask:
159157
def test_object_sampling_task_mask(
160158
eopatch: EOPatch, small_image: np.ndarray, seed: int, block_task: BlockSamplingTask
161159
) -> None:
162-
t, h, w, d = 10, *small_image.shape, 5
160+
t, h, w, d = eopatch.data["bands"].shape
161+
dr = eopatch.mask_timeless["raster"].shape[2]
163162
amount = block_task.amount
164163

165164
block_task.execute(eopatch, seed=seed)
166165
expected_amount = amount if isinstance(amount, int) else round(np.prod(small_image.shape) * amount)
167166

168167
assert eopatch.data["SAMPLED_DATA"].shape == (t, expected_amount, 1, d)
169-
assert eopatch.mask_timeless["SAMPLED_LABELS"].shape == (expected_amount, 1, 1)
168+
assert eopatch.mask_timeless["SAMPLED_LABELS"].shape == (expected_amount, 1, dr)
170169
assert eopatch.mask_timeless["sampling_mask"].shape == (h, w, 1)
171170

172171
sampled_uniques, sampled_counts = np.unique(eopatch.data["SAMPLED_DATA"], return_counts=True)

ml_tools/eolearn/tests/test_train_split.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
import pytest
1212
from numpy.testing import assert_array_equal
1313

14-
from sentinelhub import CRS, BBox
15-
1614
from eolearn.core import EOPatch, FeatureType
15+
from eolearn.core.utils.testing import PatchGeneratorConfig, generate_eopatch
1716
from eolearn.ml_tools.train_test_split import TrainTestSplitTask, TrainTestSplitType
1817

1918
INPUT_FEATURE = (FeatureType.MASK_TIMELESS, "TEST")
2019
OUTPUT_FEATURE = (FeatureType.MASK_TIMELESS, "TEST_TRAIN_MASK")
20+
INPUT_FEATURE_CONFIG = PatchGeneratorConfig(raster_shape=(1000, 1000), depth_range=(3, 4))
2121

2222

2323
@pytest.mark.parametrize(
@@ -35,26 +35,14 @@ def test_bad_args(bad_arg: Any, bad_kwargs: Any) -> None:
3535
TrainTestSplitTask(INPUT_FEATURE, OUTPUT_FEATURE, bad_arg, **bad_kwargs)
3636

3737

38-
SEED = 1
39-
40-
4138
@pytest.fixture(name="eopatch1", scope="function")
4239
def eopatch1_fixture() -> EOPatch:
43-
eopatch = EOPatch(bbox=BBox((0, 0, 1, 1), CRS(3857)))
44-
45-
rng = np.random.default_rng(SEED)
46-
eopatch[INPUT_FEATURE] = rng.integers(0, 10, size=(1000, 1000, 3))
47-
48-
return eopatch
40+
return generate_eopatch(INPUT_FEATURE, config=INPUT_FEATURE_CONFIG)
4941

5042

5143
@pytest.fixture(name="eopatch2")
5244
def eopatch2_fixture() -> EOPatch:
53-
eopatch = EOPatch(bbox=BBox((0, 0, 1, 1), CRS(3857)))
54-
rng = np.random.default_rng(SEED)
55-
eopatch[INPUT_FEATURE] = rng.integers(0, 10, size=(1000, 1000, 3), dtype=int)
56-
57-
return eopatch
45+
return generate_eopatch(INPUT_FEATURE, seed=69, config=INPUT_FEATURE_CONFIG)
5846

5947

6048
def test_train_split(eopatch1: EOPatch) -> None:

0 commit comments

Comments
 (0)