44
44
from eolearn .core .core_tasks import ExplodeBandsTask
45
45
from eolearn .core .types import FeatureRenameSpec , FeatureSpec , FeaturesSpecification
46
46
from eolearn .core .utils .parsing import parse_features
47
- from eolearn .core .utils .testing import assert_feature_data_equal
47
+ from eolearn .core .utils .testing import PatchGeneratorConfig , assert_feature_data_equal , generate_eopatch
48
48
49
49
DUMMY_BBOX = BBox ((0 , 0 , 1 , 1 ), CRS (3857 ))
50
50
51
51
52
52
@pytest .fixture (name = "patch" )
53
53
def patch_fixture () -> EOPatch :
54
- patch = EOPatch (bbox = BBox ((324.54 , 546.45 , 955.4 , 63.43 ), CRS (3857 )))
55
- patch .data ["bands" ] = np .arange (5 * 3 * 4 * 8 ).reshape (5 , 3 , 4 , 8 )
56
- patch .data ["CLP" ] = np .full ((5 , 3 , 4 , 1 ), 0.7 )
57
- patch .data ["CLP_S2C" ] = np .zeros ((5 , 3 , 4 , 1 ), dtype = np .int64 )
58
- patch .mask ["CLM" ] = np .full ((5 , 3 , 4 , 1 ), True )
59
- patch .mask_timeless ["mask" ] = np .arange (3 * 4 * 2 ).reshape (3 , 4 , 2 )
60
- patch .mask_timeless ["LULC" ] = np .zeros ((3 , 4 , 1 ), dtype = np .uint16 )
61
- patch .mask_timeless ["RANDOM_UINT8" ] = np .random .randint (0 , 100 , size = (3 , 4 , 1 ), dtype = np .int8 )
62
- patch .scalar ["values" ] = np .arange (10 * 5 ).reshape (5 , 10 )
63
- patch .scalar ["CLOUD_COVERAGE" ] = np .ones ((5 , 10 ))
64
- patch .timestamps = [
65
- datetime (2017 , 1 , 14 , 10 , 13 , 46 ),
66
- datetime (2017 , 2 , 10 , 10 , 1 , 32 ),
67
- datetime (2017 , 2 , 20 , 10 , 6 , 35 ),
68
- datetime (2017 , 3 , 2 , 10 , 0 , 20 ),
69
- datetime (2017 , 3 , 12 , 10 , 7 , 6 ),
70
- ]
71
- patch .meta_info ["something" ] = np .random .rand (10 , 1 )
54
+ patch = generate_eopatch (
55
+ {
56
+ FeatureType .DATA : ["bands" , "CLP" ],
57
+ FeatureType .MASK : ["CLM" ],
58
+ FeatureType .MASK_TIMELESS : ["mask" , "LULC" , "RANDOM_UINT8" ],
59
+ FeatureType .SCALAR : ["values" , "CLOUD_COVERAGE" ],
60
+ }
61
+ )
62
+ patch .data ["CLP_S2C" ] = np .zeros_like (patch .data ["CLP" ])
63
+
64
+ patch .meta_info ["something" ] = "beep boop"
72
65
return patch
73
66
74
67
68
+ @pytest .fixture (name = "eopatch_to_explode" )
69
+ def eopatch_to_explode_fixture () -> EOPatch :
70
+ return generate_eopatch ((FeatureType .DATA , "bands" ), config = PatchGeneratorConfig (depth_range = (8 , 9 )))
71
+
72
+
75
73
@pytest .mark .parametrize ("task" , [DeepCopyTask , CopyTask ])
76
74
def test_copy (task : Type [CopyTask ], patch : EOPatch ) -> None :
77
75
patch_copy = task ().execute (patch )
@@ -413,11 +411,11 @@ def kwargs_map(data, *, some=3, **kwargs) -> tuple:
413
411
],
414
412
)
415
413
def test_explode_bands (
416
- patch : EOPatch ,
414
+ eopatch_to_explode : EOPatch ,
417
415
feature : Tuple [FeatureType , str ],
418
416
task_input : Dict [Tuple [FeatureType , str ], Union [int , Iterable [int ]]],
419
417
) -> None :
420
- patch = ExplodeBandsTask (feature , task_input )(patch )
418
+ patch = ExplodeBandsTask (feature , task_input )(eopatch_to_explode )
421
419
assert all (new_feature in patch for new_feature in task_input )
422
420
423
421
for new_feature , bands in task_input .items ():
@@ -426,19 +424,23 @@ def test_explode_bands(
426
424
assert_array_equal (patch [new_feature ], patch [feature ][..., bands ])
427
425
428
426
429
- def test_extract_bands (patch : EOPatch ) -> None :
427
+ def test_extract_bands (eopatch_to_explode : EOPatch ) -> None :
430
428
bands = [2 , 4 , 6 ]
431
- patch = ExtractBandsTask ((FeatureType .DATA , "bands" ), (FeatureType .DATA , "EXTRACTED_BANDS" ), bands )(patch )
429
+ patch = ExtractBandsTask ((FeatureType .DATA , "bands" ), (FeatureType .DATA , "EXTRACTED_BANDS" ), bands )(
430
+ eopatch_to_explode
431
+ )
432
432
assert_array_equal (patch .data ["EXTRACTED_BANDS" ], patch .data ["bands" ][..., bands ])
433
433
434
434
patch .data ["EXTRACTED_BANDS" ][0 , 0 , 0 , 0 ] += 1
435
435
assert patch .data ["EXTRACTED_BANDS" ][0 , 0 , 0 , 0 ] != patch .data ["bands" ][0 , 0 , 0 , bands [0 ]]
436
436
437
437
438
- def test_extract_bands_fails (patch : EOPatch ) -> None :
438
+ def test_extract_bands_fails (eopatch_to_explode : EOPatch ) -> None :
439
439
with pytest .raises (ValueError ):
440
440
# fails because band 16 does not exist
441
- ExtractBandsTask ((FeatureType .DATA , "bands" ), (FeatureType .DATA , "EXTRACTED_BANDS" ), [2 , 4 , 16 ])(patch )
441
+ ExtractBandsTask ((FeatureType .DATA , "bands" ), (FeatureType .DATA , "EXTRACTED_BANDS" ), [2 , 4 , 16 ])(
442
+ eopatch_to_explode
443
+ )
442
444
443
445
444
446
@pytest .mark .parametrize (
0 commit comments