15
15
import torch
16
16
from rasterio .crs import CRS
17
17
18
- from torchgeo .datasets import BoundingBox , DependencyNotFoundError
18
+ from torchgeo .datasets import BoundingBox , DependencyNotFoundError , Sample
19
19
from torchgeo .datasets .utils import (
20
20
Executable ,
21
21
array_to_tensor ,
@@ -381,13 +381,13 @@ def test_disambiguate_timestamp(
381
381
382
382
class TestCollateFunctionsMatchingKeys :
383
383
@pytest .fixture (scope = 'class' )
384
- def samples (self ) -> list [dict [ str , Any ] ]:
384
+ def samples (self ) -> list [Sample ]:
385
385
return [
386
386
{'image' : torch .tensor ([1 , 2 , 0 ]), 'crs' : CRS .from_epsg (2000 )},
387
387
{'image' : torch .tensor ([0 , 0 , 3 ]), 'crs' : CRS .from_epsg (2001 )},
388
388
]
389
389
390
- def test_stack_unbind_samples (self , samples : list [dict [ str , Any ] ]) -> None :
390
+ def test_stack_unbind_samples (self , samples : list [Sample ]) -> None :
391
391
sample = stack_samples (samples )
392
392
assert sample ['image' ].size () == torch .Size ([2 , 3 ])
393
393
assert torch .allclose (sample ['image' ], torch .tensor ([[1 , 2 , 0 ], [0 , 0 , 3 ]]))
@@ -398,13 +398,13 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
398
398
assert torch .allclose (samples [i ]['image' ], new_samples [i ]['image' ])
399
399
assert samples [i ]['crs' ] == new_samples [i ]['crs' ]
400
400
401
- def test_concat_samples (self , samples : list [dict [ str , Any ] ]) -> None :
401
+ def test_concat_samples (self , samples : list [Sample ]) -> None :
402
402
sample = concat_samples (samples )
403
403
assert sample ['image' ].size () == torch .Size ([6 ])
404
404
assert torch .allclose (sample ['image' ], torch .tensor ([1 , 2 , 0 , 0 , 0 , 3 ]))
405
405
assert sample ['crs' ] == CRS .from_epsg (2000 )
406
406
407
- def test_merge_samples (self , samples : list [dict [ str , Any ] ]) -> None :
407
+ def test_merge_samples (self , samples : list [Sample ]) -> None :
408
408
sample = merge_samples (samples )
409
409
assert sample ['image' ].size () == torch .Size ([3 ])
410
410
assert torch .allclose (sample ['image' ], torch .tensor ([1 , 2 , 3 ]))
@@ -413,13 +413,13 @@ def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
413
413
414
414
class TestCollateFunctionsDifferingKeys :
415
415
@pytest .fixture (scope = 'class' )
416
- def samples (self ) -> list [dict [ str , Any ] ]:
416
+ def samples (self ) -> list [Sample ]:
417
417
return [
418
418
{'image' : torch .tensor ([1 , 2 , 0 ]), 'crs1' : CRS .from_epsg (2000 )},
419
419
{'mask' : torch .tensor ([0 , 0 , 3 ]), 'crs2' : CRS .from_epsg (2001 )},
420
420
]
421
421
422
- def test_stack_unbind_samples (self , samples : list [dict [ str , Any ] ]) -> None :
422
+ def test_stack_unbind_samples (self , samples : list [Sample ]) -> None :
423
423
sample = stack_samples (samples )
424
424
assert sample ['image' ].size () == torch .Size ([1 , 3 ])
425
425
assert sample ['mask' ].size () == torch .Size ([1 , 3 ])
@@ -434,7 +434,7 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
434
434
assert torch .allclose (samples [1 ]['mask' ], new_samples [0 ]['mask' ])
435
435
assert samples [1 ]['crs2' ] == new_samples [0 ]['crs2' ]
436
436
437
- def test_concat_samples (self , samples : list [dict [ str , Any ] ]) -> None :
437
+ def test_concat_samples (self , samples : list [Sample ]) -> None :
438
438
sample = concat_samples (samples )
439
439
assert sample ['image' ].size () == torch .Size ([3 ])
440
440
assert sample ['mask' ].size () == torch .Size ([3 ])
@@ -443,7 +443,7 @@ def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
443
443
assert sample ['crs1' ] == CRS .from_epsg (2000 )
444
444
assert sample ['crs2' ] == CRS .from_epsg (2001 )
445
445
446
- def test_merge_samples (self , samples : list [dict [ str , Any ] ]) -> None :
446
+ def test_merge_samples (self , samples : list [Sample ]) -> None :
447
447
sample = merge_samples (samples )
448
448
assert sample ['image' ].size () == torch .Size ([3 ])
449
449
assert sample ['mask' ].size () == torch .Size ([3 ])
0 commit comments