Skip to content

Commit 4d7bbee

Browse files
Implement deterministic GeoDataset
1 parent 0e2c76d commit 4d7bbee

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

tests/datasets/test_geo.py

+24
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ class TestGeoDataset:
7676
def dataset(self) -> GeoDataset:
7777
return CustomGeoDataset()
7878

79+
@pytest.fixture(scope="class")
def files(self) -> list[str]:
    """Return the file URIs in the order the ``files`` property should yield.

    Shared by the tests that exercise the ``files`` property.
    """
    return [
        "file://file1.tif",
        "file://file2.tif",
        "file://file3.tif",
    ]
87+
7988
def test_getitem(self, dataset: GeoDataset) -> None:
8089
query = BoundingBox(0, 1, 2, 3, 4, 5)
8190
assert dataset[query] == {"index": query}
@@ -177,6 +186,21 @@ def test_files_property_for_virtual_files(self) -> None:
177186
]
178187
assert len(CustomGeoDataset(paths=paths).files) == len(paths)
179188

189+
def test_files_property_ordered(self, files: list[str]) -> None:
    """Check that ``files`` comes back in the expected order."""
    # Deliberately pass the paths out of order.
    shuffled = ["file://file3.tif", "file://file1.tif", "file://file2.tif"]
    dataset = CustomGeoDataset(paths=shuffled)
    assert dataset.files == files
193+
194+
def test_files_property_deterministic(self, files: list[str]) -> None:
    """Check that ``files`` does not depend on the order of the input paths."""
    # Two different orderings of the same three paths.
    first_order = ["file://file3.tif", "file://file1.tif", "file://file2.tif"]
    second_order = ["file://file2.tif", "file://file3.tif", "file://file1.tif"]

    first_files = CustomGeoDataset(paths=first_order).files
    second_files = CustomGeoDataset(paths=second_order).files

    assert first_files == second_files
203+
180204

181205
class TestRasterDataset:
182206
@pytest.fixture(params=zip([["R", "G", "B"], None], [True, False]))

torchgeo/datasets/geo.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def res(self, new_res: float) -> None:
287287
self._res = new_res
288288

289289
@property
290-
def files(self) -> set[str]:
290+
def files(self) -> list[str]:
291291
"""A list of all files in the dataset.
292292
293293
Returns:
@@ -316,7 +316,8 @@ def files(self) -> set[str]:
316316
UserWarning,
317317
)
318318

319-
return files
319+
# Sort the output to enforce deterministic behavior.
320+
return sorted(files)
320321

321322

322323
class RasterDataset(GeoDataset):

0 commit comments

Comments
 (0)