From 44ce0073b2e24a27b8b9575327cce3808f419303 Mon Sep 17 00:00:00 2001
From: Hitesh Tolani
Date: Thu, 18 Jul 2024 15:46:41 +0530
Subject: [PATCH] Datasets: add support for pathlib.Path (#2173)

* Added pathlib support for ``geo.py``
* Fixed failing ruff checks.
* Fixed additional ruff formatting errors
* Added complete ``pathlib`` support
* Additional changes
* Fixed ``cyclone.py`` issues
* Additional fixes
* Fixed ``isinstance`` and ``Path`` inconsistency
* Fixed mypy errors in ``cdl.py``
* geo/utils: all paths are Paths
* datasets: all paths are Paths
* Test Paths
* Type checks only work for latest torchvision
* Fix tests

---------

Co-authored-by: Hitesh Tolani
Co-authored-by: Adam J. Stewart
---
 .pre-commit-config.yaml | 1 +
 tests/datasets/test_advance.py | 6 +--
 tests/datasets/test_agb_live_woody_density.py | 6 +--
 tests/datasets/test_agrifieldnet.py | 2 +-
 tests/datasets/test_airphen.py | 2 +-
 tests/datasets/test_astergdem.py | 4 +-
 tests/datasets/test_benin_cashews.py | 4 +-
 tests/datasets/test_bigearthnet.py | 18 ++++-----
 tests/datasets/test_biomassters.py | 2 +-
 tests/datasets/test_cbf.py | 6 +--
 tests/datasets/test_cdl.py | 10 ++---
 tests/datasets/test_chabud.py | 8 ++--
 tests/datasets/test_chesapeake.py | 14 +++----
 tests/datasets/test_cloud_cover.py | 4 +-
 tests/datasets/test_cms_mangrove_canopy.py | 8 ++--
 tests/datasets/test_cowc.py | 10 ++---
 tests/datasets/test_cropharvest.py | 10 ++---
 tests/datasets/test_cv4a_kenya_crop_type.py | 4 +-
 tests/datasets/test_cyclone.py | 4 +-
 tests/datasets/test_deepglobelandcover.py | 10 ++---
 tests/datasets/test_dfc2022.py | 6 +--
 tests/datasets/test_eddmaps.py | 2 +-
 tests/datasets/test_enviroatlas.py | 8 ++--
 tests/datasets/test_esri2020.py | 8 ++--
 tests/datasets/test_etci2021.py | 6 +--
 tests/datasets/test_eudem.py | 8 ++--
 tests/datasets/test_eurocrops.py | 6 +--
 tests/datasets/test_eurosat.py | 14 +++----
 tests/datasets/test_fair1m.py | 20 +++++-----
 tests/datasets/test_fire_risk.py | 12 +++---
 tests/datasets/test_forestdamage.py | 10 ++---
 tests/datasets/test_gbif.py | 2 +-
 tests/datasets/test_geo.py | 8 ++--
 tests/datasets/test_gid15.py | 6 +--
 tests/datasets/test_globbiomass.py | 6 +--
 tests/datasets/test_idtrees.py | 8 ++--
 tests/datasets/test_inaturalist.py | 2 +-
 tests/datasets/test_inria.py | 2 +-
 tests/datasets/test_iobench.py | 8 ++--
 tests/datasets/test_l7irish.py | 8 ++--
 tests/datasets/test_l8biome.py | 8 ++--
 tests/datasets/test_landcoverai.py | 14 +++----
 tests/datasets/test_landsat.py | 2 +-
 tests/datasets/test_levircd.py | 10 ++---
 tests/datasets/test_loveda.py | 6 +--
 tests/datasets/test_mapinwild.py | 14 +++----
 tests/datasets/test_millionaid.py | 6 +--
 tests/datasets/test_naip.py | 2 +-
 tests/datasets/test_nasa_marine_debris.py | 14 +++----
 tests/datasets/test_nccm.py | 6 +--
 tests/datasets/test_nlcd.py | 10 ++---
 tests/datasets/test_openbuildings.py | 4 +-
 tests/datasets/test_oscd.py | 8 ++--
 tests/datasets/test_pastis.py | 10 ++---
 tests/datasets/test_patternnet.py | 12 +++---
 tests/datasets/test_potsdam.py | 8 ++--
 tests/datasets/test_prisma.py | 2 +-
 tests/datasets/test_quakeset.py | 8 ++--
 tests/datasets/test_reforestree.py | 10 ++---
 tests/datasets/test_resisc45.py | 12 +++---
 tests/datasets/test_rwanda_field_boundary.py | 4 +-
 tests/datasets/test_seasonet.py | 10 +++--
 tests/datasets/test_seco.py | 8 ++--
 tests/datasets/test_sen12ms.py | 4 +-
 tests/datasets/test_sentinel.py | 4 +-
 tests/datasets/test_skippd.py | 8 ++--
 tests/datasets/test_so2sat.py | 2 +-
 tests/datasets/test_south_africa_crop_type.py | 2 +-
 tests/datasets/test_south_america_soybean.py | 8 ++--
 tests/datasets/test_spacenet.py | 26 ++++++-------
 tests/datasets/test_ssl4eo.py | 12 +++---
 tests/datasets/test_ssl4eo_benchmark.py | 8 ++--
 .../datasets/test_sustainbench_crop_yield.py | 8 ++--
 tests/datasets/test_ucmerced.py | 12 +++---
 tests/datasets/test_usavars.py | 8 ++--
 tests/datasets/test_utils.py | 11 +++---
 tests/datasets/test_vaihingen.py | 8 ++--
 tests/datasets/test_vhr10.py | 6 +--
 .../test_western_usa_live_fuel_moisture.py | 6 +--
 tests/datasets/test_xview2.py | 6 +--
 tests/datasets/test_zuericrop.py | 6 +--
 torchgeo/datasets/advance.py | 10 ++---
 torchgeo/datasets/agb_live_woody_density.py | 7 ++--
 torchgeo/datasets/agrifieldnet.py | 9 +++--
 torchgeo/datasets/astergdem.py | 3 +-
 torchgeo/datasets/benin_cashews.py | 4 +-
 torchgeo/datasets/bigearthnet.py | 10 ++---
 torchgeo/datasets/biomassters.py | 8 ++--
 torchgeo/datasets/cbf.py | 9 +++--
 torchgeo/datasets/cdl.py | 9 +++--
 torchgeo/datasets/chabud.py | 4 +-
 torchgeo/datasets/chesapeake.py | 13 ++++---
 torchgeo/datasets/cloud_cover.py | 4 +-
 torchgeo/datasets/cms_mangrove_canopy.py | 9 +++--
 torchgeo/datasets/cowc.py | 4 +-
 torchgeo/datasets/cropharvest.py | 10 ++---
 torchgeo/datasets/cv4a_kenya_crop_type.py | 4 +-
 torchgeo/datasets/cyclone.py | 4 +-
 torchgeo/datasets/deepglobelandcover.py | 3 +-
 torchgeo/datasets/dfc2022.py | 8 ++--
 torchgeo/datasets/eddmaps.py | 4 +-
 torchgeo/datasets/enviroatlas.py | 6 +--
 torchgeo/datasets/esri2020.py | 9 +++--
 torchgeo/datasets/etci2021.py | 10 ++---
 torchgeo/datasets/eudem.py | 7 ++--
 torchgeo/datasets/eurocrops.py | 11 +++---
 torchgeo/datasets/eurosat.py | 6 +--
 torchgeo/datasets/fair1m.py | 10 ++---
 torchgeo/datasets/fire_risk.py | 4 +-
 torchgeo/datasets/forestdamage.py | 10 ++---
 torchgeo/datasets/gbif.py | 4 +-
 torchgeo/datasets/geo.py | 32 ++++++++--------
 torchgeo/datasets/gid15.py | 10 ++---
 torchgeo/datasets/globbiomass.py | 17 ++++++---
 torchgeo/datasets/idtrees.py | 18 ++++-----
 torchgeo/datasets/inaturalist.py | 4 +-
 torchgeo/datasets/inria.py | 10 ++---
 torchgeo/datasets/iobench.py | 4 +-
 torchgeo/datasets/l7irish.py | 19 +++++++---
 torchgeo/datasets/l8biome.py | 9 +++--
 torchgeo/datasets/landcoverai.py | 14 ++++---
 torchgeo/datasets/landsat.py | 3 +-
 torchgeo/datasets/levircd.py | 14 +++----
 torchgeo/datasets/loveda.py | 8 ++--
 torchgeo/datasets/mapinwild.py | 7 ++--
 torchgeo/datasets/millionaid.py | 8 ++--
 torchgeo/datasets/nasa_marine_debris.py | 13 +++++--
 torchgeo/datasets/nccm.py | 4 +-
 torchgeo/datasets/nlcd.py | 9 +++--
 torchgeo/datasets/openbuildings.py | 13 ++++---
 torchgeo/datasets/oscd.py | 7 ++--
 torchgeo/datasets/pastis.py | 4 +-
 torchgeo/datasets/patternnet.py | 4 +-
 torchgeo/datasets/potsdam.py | 3 +-
 torchgeo/datasets/quakeset.py | 4 +-
 torchgeo/datasets/reforestree.py | 10 ++---
 torchgeo/datasets/resisc45.py | 6 +--
 torchgeo/datasets/rwanda_field_boundary.py | 4 +-
 torchgeo/datasets/seasonet.py | 4 +-
 torchgeo/datasets/seco.py | 6 +--
 torchgeo/datasets/sen12ms.py | 4 +-
 torchgeo/datasets/sentinel.py | 5 ++-
 torchgeo/datasets/skippd.py | 4 +-
 torchgeo/datasets/so2sat.py | 4 +-
 torchgeo/datasets/south_africa_crop_type.py | 9 +++--
 torchgeo/datasets/south_america_soybean.py | 7 ++--
 torchgeo/datasets/spacenet.py | 29 +++++++-------
 torchgeo/datasets/ssl4eo.py | 6 +--
 torchgeo/datasets/ssl4eo_benchmark.py | 8 ++--
 torchgeo/datasets/sustainbench_crop_yield.py | 4 +-
 torchgeo/datasets/ucmerced.py | 6 +--
 torchgeo/datasets/usavars.py | 6 +--
 torchgeo/datasets/utils.py | 38 ++++++++---------
torchgeo/datasets/vaihingen.py | 3 +- torchgeo/datasets/vhr10.py | 3 +- .../western_usa_live_fuel_moisture.py | 4 +- torchgeo/datasets/xview.py | 15 +++++--- torchgeo/datasets/zuericrop.py | 4 +- 158 files changed, 658 insertions(+), 598 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b9c3cb56eca..b6cbc81a5e7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,6 +32,7 @@ repos: - scikit-image>=0.22.0 - torch>=2.3 - torchmetrics>=0.10 + - torchvision>=0.18 exclude: (build|data|dist|logo|logs|output)/ - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.1.0 diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index f2a34b89f4c..077837f8938 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -17,7 +17,7 @@ pytest.importorskip('scipy', minversion='1.7.2') -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -33,7 +33,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE: md5s = ['43acacecebecd17a82bc2c1e719fd7e4', '039b7baa47879a8a4e32b9dd8287f6ad'] monkeypatch.setattr(ADVANCE, 'urls', urls) monkeypatch.setattr(ADVANCE, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ADVANCE(root, transforms, download=True, checksum=True) @@ -57,7 +57,7 @@ def test_already_downloaded(self, dataset: ADVANCE) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ADVANCE(str(tmp_path)) + ADVANCE(tmp_path) def test_plot(self, dataset: ADVANCE) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_agb_live_woody_density.py b/tests/datasets/test_agb_live_woody_density.py index 3b4c0636c1d..cfa0ecd112f 100644 --- a/tests/datasets/test_agb_live_woody_density.py +++ b/tests/datasets/test_agb_live_woody_density.py @@ -21,7 +21,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -42,7 +42,7 @@ def dataset( ) monkeypatch.setattr(AbovegroundLiveWoodyBiomassDensity, 'url', url) - root = str(tmp_path) + root = tmp_path return AbovegroundLiveWoodyBiomassDensity( root, transforms=transforms, download=True ) @@ -58,7 +58,7 @@ def test_len(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: def test_no_dataset(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - AbovegroundLiveWoodyBiomassDensity(str(tmp_path)) + AbovegroundLiveWoodyBiomassDensity(tmp_path) def test_already_downloaded( self, dataset: AbovegroundLiveWoodyBiomassDensity diff --git a/tests/datasets/test_agrifieldnet.py b/tests/datasets/test_agrifieldnet.py index 6608dc7a1bb..d59776d0cdd 100644 --- a/tests/datasets/test_agrifieldnet.py +++ b/tests/datasets/test_agrifieldnet.py @@ -50,7 +50,7 @@ def test_already_downloaded(self, dataset: AgriFieldNet) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - AgriFieldNet(str(tmp_path)) + AgriFieldNet(tmp_path) def test_plot(self, dataset: AgriFieldNet) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_airphen.py b/tests/datasets/test_airphen.py index 3c60fb090f7..9b3618f9e8b 100644 --- a/tests/datasets/test_airphen.py +++ b/tests/datasets/test_airphen.py @@ 
-52,7 +52,7 @@ def test_plot(self, dataset: Airphen) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Airphen(str(tmp_path)) + Airphen(tmp_path) def test_invalid_query(self, dataset: Airphen) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_astergdem.py b/tests/datasets/test_astergdem.py index 7f1aeaa4cd6..abcf822eb4e 100644 --- a/tests/datasets/test_astergdem.py +++ b/tests/datasets/test_astergdem.py @@ -25,7 +25,7 @@ class TestAsterGDEM: def dataset(self, tmp_path: Path) -> AsterGDEM: zipfile = os.path.join('tests', 'data', 'astergdem', 'astergdem.zip') shutil.unpack_archive(zipfile, tmp_path, 'zip') - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return AsterGDEM(root, transforms=transforms) @@ -33,7 +33,7 @@ def test_datasetmissing(self, tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - AsterGDEM(str(tmp_path)) + AsterGDEM(tmp_path) def test_getitem(self, dataset: AsterGDEM) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py index cc06c0060a0..795e5c6d521 100644 --- a/tests/datasets/test_benin_cashews.py +++ b/tests/datasets/test_benin_cashews.py @@ -29,7 +29,7 @@ def dataset( monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('20191105',)) monkeypatch.setattr(BeninSmallHolderCashews, 'tile_height', 2) monkeypatch.setattr(BeninSmallHolderCashews, 'tile_width', 2) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return BeninSmallHolderCashews(root, transforms=transforms, download=True) @@ -54,7 +54,7 @@ def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - BeninSmallHolderCashews(str(tmp_path)) + BeninSmallHolderCashews(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 82a3655626f..dcb4c87d887 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -16,7 +16,7 @@ from torchgeo.datasets import BigEarthNet, DatasetNotFoundError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -63,7 +63,7 @@ def dataset( monkeypatch.setattr(BigEarthNet, 'metadata', metadata) monkeypatch.setattr(BigEarthNet, 'splits_metadata', splits_metadata) bands, num_classes, split = request.param - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return BigEarthNet( root, split, bands, num_classes, transforms, download=True, checksum=True @@ -95,7 +95,7 @@ def test_len(self, dataset: BigEarthNet) -> None: def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None: BigEarthNet( - root=str(tmp_path), + root=tmp_path, bands=dataset.bands, split=dataset.split, num_classes=dataset.num_classes, @@ -112,21 +112,21 @@ def test_already_downloaded_not_extracted( shutil.rmtree( os.path.join(dataset.root, dataset.metadata['s2']['directory']) ) - download_url(dataset.metadata['s1']['url'], root=str(tmp_path)) - download_url(dataset.metadata['s2']['url'], root=str(tmp_path)) + download_url(dataset.metadata['s1']['url'], root=tmp_path) + 
download_url(dataset.metadata['s2']['url'], root=tmp_path) elif dataset.bands == 's1': shutil.rmtree( os.path.join(dataset.root, dataset.metadata['s1']['directory']) ) - download_url(dataset.metadata['s1']['url'], root=str(tmp_path)) + download_url(dataset.metadata['s1']['url'], root=tmp_path) else: shutil.rmtree( os.path.join(dataset.root, dataset.metadata['s2']['directory']) ) - download_url(dataset.metadata['s2']['url'], root=str(tmp_path)) + download_url(dataset.metadata['s2']['url'], root=tmp_path) BigEarthNet( - root=str(tmp_path), + root=tmp_path, bands=dataset.bands, split=dataset.split, num_classes=dataset.num_classes, @@ -135,7 +135,7 @@ def test_already_downloaded_not_extracted( def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - BigEarthNet(str(tmp_path)) + BigEarthNet(tmp_path) def test_plot(self, dataset: BigEarthNet) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_biomassters.py b/tests/datasets/test_biomassters.py index 8a853145da8..f9ea246ae73 100644 --- a/tests/datasets/test_biomassters.py +++ b/tests/datasets/test_biomassters.py @@ -37,7 +37,7 @@ def test_invalid_bands(self, dataset: BioMassters) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - BioMassters(str(tmp_path)) + BioMassters(tmp_path) def test_plot(self, dataset: BioMassters) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_cbf.py b/tests/datasets/test_cbf.py index 4287cd9673d..821e41d08a8 100644 --- a/tests/datasets/test_cbf.py +++ b/tests/datasets/test_cbf.py @@ -22,7 +22,7 @@ ) -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -41,7 +41,7 @@ def dataset( url = os.path.join('tests', 'data', 'cbf') + os.sep monkeypatch.setattr(CanadianBuildingFootprints, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return CanadianBuildingFootprints( root, res=0.1, transforms=transforms, download=True, checksum=True @@ -80,7 +80,7 @@ def test_plot_prediction(self, dataset: CanadianBuildingFootprints) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CanadianBuildingFootprints(str(tmp_path)) + CanadianBuildingFootprints(tmp_path) def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None: query = BoundingBox(2, 2, 2, 2, 2, 2) diff --git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py index 19ae64514c2..0c72d97c47d 100644 --- a/tests/datasets/test_cdl.py +++ b/tests/datasets/test_cdl.py @@ -24,7 +24,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -41,7 +41,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL: url = os.path.join('tests', 'data', 'cdl', '{}_30m_cdls.zip') monkeypatch.setattr(CDL, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return CDL( root, @@ -87,7 +87,7 @@ def test_already_extracted(self, dataset: CDL) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'cdl', '*_30m_cdls.zip') - root = str(tmp_path) + root = 
tmp_path for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) CDL(root, years=[2023, 2022]) @@ -97,7 +97,7 @@ def test_invalid_year(self, tmp_path: Path) -> None: AssertionError, match='CDL data product only exists for the following years:', ): - CDL(str(tmp_path), years=[1996]) + CDL(tmp_path, years=[1996]) def test_invalid_classes(self) -> None: with pytest.raises(AssertionError): @@ -121,7 +121,7 @@ def test_plot_prediction(self, dataset: CDL) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CDL(str(tmp_path)) + CDL(tmp_path) def test_invalid_query(self, dataset: CDL) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_chabud.py b/tests/datasets/test_chabud.py index 074674a1733..2db60e057b0 100644 --- a/tests/datasets/test_chabud.py +++ b/tests/datasets/test_chabud.py @@ -18,7 +18,9 @@ pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None: +def download_url( + url: str, root: str | Path, filename: str, *args: str, **kwargs: str +) -> None: shutil.copy(url, os.path.join(root, filename)) @@ -34,7 +36,7 @@ def dataset( monkeypatch.setattr(ChaBuD, 'url', url) monkeypatch.setattr(ChaBuD, 'md5', md5) bands, split = request.param - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ChaBuD( root=root, @@ -70,7 +72,7 @@ def test_already_downloaded(self, dataset: ChaBuD) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ChaBuD(str(tmp_path)) + ChaBuD(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 814c6997d32..33dbfd27978 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -26,7 +26,7 @@ pytest.importorskip('zipfile_deflate64') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -41,7 +41,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13: ) monkeypatch.setattr(Chesapeake13, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return Chesapeake13(root, transforms=transforms, download=True, checksum=True) @@ -69,13 +69,13 @@ def test_already_downloaded(self, tmp_path: Path) -> None: url = os.path.join( 'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip' ) - root = str(tmp_path) + root = tmp_path shutil.copy(url, root) Chesapeake13(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Chesapeake13(str(tmp_path), checksum=True) + Chesapeake13(tmp_path, checksum=True) def test_plot(self, dataset: Chesapeake13) -> None: query = dataset.bounds @@ -148,7 +148,7 @@ def dataset( '_files', ['de_1m_2013_extended-debuffered-test_tiles', 'spatial_index.geojson'], ) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ChesapeakeCVPR( root, @@ -180,7 +180,7 @@ def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None: ChesapeakeCVPR(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - root = str(tmp_path) + root = tmp_path shutil.copy( 
os.path.join( 'tests', 'data', 'chesapeake', 'cvpr', 'cvpr_chesapeake_landcover.zip' @@ -201,7 +201,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ChesapeakeCVPR(str(tmp_path), checksum=True) + ChesapeakeCVPR(tmp_path, checksum=True) def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_cloud_cover.py b/tests/datasets/test_cloud_cover.py index dae87cf3633..c2ed31bf108 100644 --- a/tests/datasets/test_cloud_cover.py +++ b/tests/datasets/test_cloud_cover.py @@ -30,7 +30,7 @@ def dataset( ) -> CloudCoverDetection: url = os.path.join('tests', 'data', 'ref_cloud_cover_detection_challenge_v1') monkeypatch.setattr(CloudCoverDetection, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return CloudCoverDetection( @@ -55,7 +55,7 @@ def test_already_downloaded(self, dataset: CloudCoverDetection) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CloudCoverDetection(str(tmp_path)) + CloudCoverDetection(tmp_path) def test_plot(self, dataset: CloudCoverDetection) -> None: sample = dataset[0] diff --git a/tests/datasets/test_cms_mangrove_canopy.py b/tests/datasets/test_cms_mangrove_canopy.py index ce12795a07e..dfa84d0f9fc 100644 --- a/tests/datasets/test_cms_mangrove_canopy.py +++ b/tests/datasets/test_cms_mangrove_canopy.py @@ -20,7 +20,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -54,7 +54,7 @@ def test_len(self, dataset: CMSGlobalMangroveCanopy) -> None: def test_no_dataset(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CMSGlobalMangroveCanopy(str(tmp_path)) + CMSGlobalMangroveCanopy(tmp_path) def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( @@ -63,7 +63,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: 'cms_mangrove_canopy', 'CMS_Global_Map_Mangrove_Canopy_1665.zip', ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) CMSGlobalMangroveCanopy(root, country='Angola') @@ -73,7 +73,7 @@ def test_corrupted(self, tmp_path: Path) -> None: ) as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - CMSGlobalMangroveCanopy(str(tmp_path), country='Angola', checksum=True) + CMSGlobalMangroveCanopy(tmp_path, country='Angola', checksum=True) def test_invalid_country(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_cowc.py b/tests/datasets/test_cowc.py index f454569d5b7..00c4aecf50b 100644 --- a/tests/datasets/test_cowc.py +++ b/tests/datasets/test_cowc.py @@ -17,7 +17,7 @@ from torchgeo.datasets import COWC, COWCCounting, COWCDetection, DatasetNotFoundError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -46,7 +46,7 @@ def dataset( '0a4daed8c5f6c4e20faa6e38636e4346', ] monkeypatch.setattr(COWCCounting, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return COWCCounting(root, split, transforms, download=True, checksum=True) @@ -78,7 
+78,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - COWCCounting(str(tmp_path)) + COWCCounting(tmp_path) def test_plot(self, dataset: COWCCounting) -> None: x = dataset[0].copy() @@ -110,7 +110,7 @@ def dataset( 'dccc2257e9c4a9dde2b4f84769804046', ] monkeypatch.setattr(COWCDetection, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return COWCDetection(root, split, transforms, download=True, checksum=True) @@ -142,7 +142,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - COWCDetection(str(tmp_path)) + COWCDetection(tmp_path) def test_plot(self, dataset: COWCDetection) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_cropharvest.py b/tests/datasets/test_cropharvest.py index 2ad82fca137..82ab9f59f12 100644 --- a/tests/datasets/test_cropharvest.py +++ b/tests/datasets/test_cropharvest.py @@ -17,7 +17,7 @@ pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, filename: str, md5: str) -> None: +def download_url(url: str, root: str | Path, filename: str, md5: str) -> None: shutil.copy(url, os.path.join(root, filename)) @@ -42,7 +42,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: os.path.join('tests', 'data', 'cropharvest', 'labels.geojson'), ) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() dataset = CropHarvest(root, transforms, download=True, checksum=True) @@ -61,16 +61,16 @@ def test_len(self, dataset: CropHarvest) -> None: assert len(dataset) == 5 def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None: - CropHarvest(root=str(tmp_path), download=False) + CropHarvest(root=tmp_path, download=False) def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None: feature_path = os.path.join(tmp_path, 'features') shutil.rmtree(feature_path) - CropHarvest(root=str(tmp_path), download=True) + CropHarvest(root=tmp_path, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CropHarvest(str(tmp_path)) + CropHarvest(tmp_path) def test_plot(self, dataset: CropHarvest) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index 34f67036d2a..e6309844054 100644 --- a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -30,7 +30,7 @@ def dataset( monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606']) monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2) monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return CV4AKenyaCropType(root, transforms=transforms, download=True) @@ -55,7 +55,7 @@ def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CV4AKenyaCropType(str(tmp_path)) + CV4AKenyaCropType(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index bb18bed06ca..8dfd39b3c9f 100644 --- a/tests/datasets/test_cyclone.py +++ 
b/tests/datasets/test_cyclone.py @@ -28,7 +28,7 @@ def dataset( url = os.path.join('tests', 'data', 'cyclone') monkeypatch.setattr(TropicalCyclone, 'url', url) monkeypatch.setattr(TropicalCyclone, 'size', 2) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return TropicalCyclone(root, split, transforms, download=True) @@ -60,7 +60,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - TropicalCyclone(str(tmp_path)) + TropicalCyclone(tmp_path) def test_plot(self, dataset: TropicalCyclone) -> None: sample = dataset[0] diff --git a/tests/datasets/test_deepglobelandcover.py b/tests/datasets/test_deepglobelandcover.py index 5e845958668..2ea779a98fa 100644 --- a/tests/datasets/test_deepglobelandcover.py +++ b/tests/datasets/test_deepglobelandcover.py @@ -39,16 +39,14 @@ def test_len(self, dataset: DeepGlobeLandCover) -> None: def test_extract(self, tmp_path: Path) -> None: root = os.path.join('tests', 'data', 'deepglobelandcover') filename = 'data.zip' - shutil.copyfile( - os.path.join(root, filename), os.path.join(str(tmp_path), filename) - ) - DeepGlobeLandCover(root=str(tmp_path)) + shutil.copyfile(os.path.join(root, filename), os.path.join(tmp_path, filename)) + DeepGlobeLandCover(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'data.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - DeepGlobeLandCover(root=str(tmp_path), checksum=True) + DeepGlobeLandCover(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -56,7 +54,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - DeepGlobeLandCover(str(tmp_path)) + DeepGlobeLandCover(tmp_path) def test_plot(self, dataset: DeepGlobeLandCover) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_dfc2022.py b/tests/datasets/test_dfc2022.py index d353da5e274..4d40aa2e442 100644 --- a/tests/datasets/test_dfc2022.py +++ b/tests/datasets/test_dfc2022.py @@ -61,13 +61,13 @@ def test_extract(self, tmp_path: Path) -> None: os.path.join('tests', 'data', 'dfc2022', 'val.zip'), os.path.join(tmp_path, 'val.zip'), ) - DFC2022(root=str(tmp_path)) + DFC2022(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'labeled_train.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - DFC2022(root=str(tmp_path), checksum=True) + DFC2022(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -75,7 +75,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - DFC2022(str(tmp_path)) + DFC2022(tmp_path) def test_plot(self, dataset: DFC2022) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_eddmaps.py b/tests/datasets/test_eddmaps.py index 364e988aba3..1a1e805e13f 100644 --- a/tests/datasets/test_eddmaps.py +++ b/tests/datasets/test_eddmaps.py @@ -38,7 +38,7 @@ def test_or(self, dataset: EDDMapS) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EDDMapS(str(tmp_path)) + EDDMapS(tmp_path) def 
test_invalid_query(self, dataset: EDDMapS) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_enviroatlas.py b/tests/datasets/test_enviroatlas.py index 11ac3b93436..457698f4f85 100644 --- a/tests/datasets/test_enviroatlas.py +++ b/tests/datasets/test_enviroatlas.py @@ -24,7 +24,7 @@ from torchgeo.samplers import RandomGeoSampler -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -51,7 +51,7 @@ def dataset( '_files', ['pittsburgh_pa-2010_1m-train_tiles-debuffered', 'spatial_index.geojson'], ) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return EnviroAtlas( root, @@ -85,7 +85,7 @@ def test_already_extracted(self, dataset: EnviroAtlas) -> None: EnviroAtlas(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - root = str(tmp_path) + root = tmp_path shutil.copy( os.path.join('tests', 'data', 'enviroatlas', 'enviroatlas_lotp.zip'), root ) @@ -93,7 +93,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EnviroAtlas(str(tmp_path), checksum=True) + EnviroAtlas(tmp_path, checksum=True) def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_esri2020.py b/tests/datasets/test_esri2020.py index 3fe1207b6b4..917caea86ed 100644 --- a/tests/datasets/test_esri2020.py +++ b/tests/datasets/test_esri2020.py @@ -22,7 +22,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -42,7 +42,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Esri2020: 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip', ) monkeypatch.setattr(Esri2020, 'url', url) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return Esri2020(root, transforms=transforms, download=True, checksum=True) @@ -66,11 +66,11 @@ def test_not_extracted(self, tmp_path: Path) -> None: 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip', ) shutil.copy(url, tmp_path) - Esri2020(str(tmp_path)) + Esri2020(tmp_path) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Esri2020(str(tmp_path), checksum=True) + Esri2020(tmp_path, checksum=True) def test_and(self, dataset: Esri2020) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py index 0cf4029921d..69a417d2300 100644 --- a/tests/datasets/test_etci2021.py +++ b/tests/datasets/test_etci2021.py @@ -16,7 +16,7 @@ from torchgeo.datasets import ETCI2021, DatasetNotFoundError -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -48,7 +48,7 @@ def dataset( }, } monkeypatch.setattr(ETCI2021, 'metadata', metadata) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return ETCI2021(root, split, transforms, download=True, checksum=True) @@ -78,7 +78,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - 
ETCI2021(str(tmp_path)) + ETCI2021(tmp_path) def test_plot(self, dataset: ETCI2021) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_eudem.py b/tests/datasets/test_eudem.py index e984dd9e079..c41d36c8301 100644 --- a/tests/datasets/test_eudem.py +++ b/tests/datasets/test_eudem.py @@ -28,7 +28,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> EUDEM: monkeypatch.setattr(EUDEM, 'md5s', md5s) zipfile = os.path.join('tests', 'data', 'eudem', 'eu_dem_v11_E30N10.zip') shutil.copy(zipfile, tmp_path) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return EUDEM(root, transforms=transforms) @@ -42,7 +42,7 @@ def test_len(self, dataset: EUDEM) -> None: assert len(dataset) == 1 def test_extracted_already(self, dataset: EUDEM) -> None: - assert isinstance(dataset.paths, str) + assert isinstance(dataset.paths, Path) zipfile = os.path.join(dataset.paths, 'eu_dem_v11_E30N10.zip') shutil.unpack_archive(zipfile, dataset.paths, 'zip') EUDEM(dataset.paths) @@ -51,13 +51,13 @@ def test_no_dataset(self, tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EUDEM(str(tmp_path)) + EUDEM(tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'eu_dem_v11_E30N10.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - EUDEM(str(tmp_path), checksum=True) + EUDEM(tmp_path, checksum=True) def test_and(self, dataset: EUDEM) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_eurocrops.py b/tests/datasets/test_eurocrops.py index e716bbb783a..8374b823441 100644 --- a/tests/datasets/test_eurocrops.py +++ b/tests/datasets/test_eurocrops.py @@ -23,7 +23,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -42,7 +42,7 @@ def dataset( base_url = os.path.join('tests', 'data', 'eurocrops') + os.sep monkeypatch.setattr(EuroCrops, 'base_url', base_url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return EuroCrops( root, classes=classes, transforms=transforms, download=True, checksum=True @@ -81,7 +81,7 @@ def test_plot_prediction(self, dataset: EuroCrops) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EuroCrops(str(tmp_path)) + EuroCrops(tmp_path) def test_invalid_query(self, dataset: EuroCrops) -> None: query = BoundingBox(200, 200, 200, 200, 2, 2) diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index 99841c26031..5dd2bb849a2 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -24,7 +24,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -61,7 +61,7 @@ def dataset( 'test': '4af60a00fdfdf8500572ae5360694b71', }, ) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return base_class( root=root, split=split, transforms=transforms, download=True, checksum=True @@ -90,18 +90,18 @@ def test_add(self, dataset: EuroSAT) -> None: assert len(ds) == 4 def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None: - EuroSAT(root=str(tmp_path), download=True) + EuroSAT(root=tmp_path, 
download=True) def test_already_downloaded_not_extracted( self, dataset: EuroSAT, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - download_url(dataset.url, root=str(tmp_path)) - EuroSAT(root=str(tmp_path), download=False) + download_url(dataset.url, root=tmp_path) + EuroSAT(root=tmp_path, download=False) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EuroSAT(str(tmp_path)) + EuroSAT(tmp_path) def test_plot(self, dataset: EuroSAT) -> None: x = dataset[0].copy() @@ -114,7 +114,7 @@ def test_plot(self, dataset: EuroSAT) -> None: plt.close() def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None: - dataset = EuroSAT(root=str(tmp_path), bands=('B03',)) + dataset = EuroSAT(root=tmp_path, bands=('B03',)) with pytest.raises( RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py index 38db23974d3..5110a7933ae 100644 --- a/tests/datasets/test_fair1m.py +++ b/tests/datasets/test_fair1m.py @@ -16,7 +16,9 @@ from torchgeo.datasets import FAIR1M, DatasetNotFoundError -def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None: +def download_url( + url: str, root: str | Path, filename: str, *args: str, **kwargs: str +) -> None: os.makedirs(root, exist_ok=True) shutil.copy(url, os.path.join(root, filename)) @@ -65,7 +67,7 @@ def dataset( } monkeypatch.setattr(FAIR1M, 'urls', urls) monkeypatch.setattr(FAIR1M, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return FAIR1M(root, split, transforms, download=True, checksum=True) @@ -89,7 +91,7 @@ def test_len(self, dataset: FAIR1M) -> None: assert len(dataset) == 4 def test_already_downloaded(self, dataset: FAIR1M, tmp_path: Path) -> None: - FAIR1M(root=str(tmp_path), split=dataset.split, download=True) + FAIR1M(root=tmp_path, split=dataset.split, download=True) def test_already_downloaded_not_extracted( self, dataset: FAIR1M, tmp_path: Path @@ -98,11 +100,11 @@ def test_already_downloaded_not_extracted( for filepath, url in zip( dataset.paths[dataset.split], dataset.urls[dataset.split] ): - output = os.path.join(str(tmp_path), filepath) + output = os.path.join(tmp_path, filepath) os.makedirs(os.path.dirname(output), exist_ok=True) download_url(url, root=os.path.dirname(output), filename=output) - FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True) + FAIR1M(root=tmp_path, split=dataset.split, checksum=True) def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None: md5s = tuple(['randomhash'] * len(FAIR1M.md5s[dataset.split])) @@ -111,17 +113,17 @@ def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None: for filepath, url in zip( dataset.paths[dataset.split], dataset.urls[dataset.split] ): - output = os.path.join(str(tmp_path), filepath) + output = os.path.join(tmp_path, filepath) os.makedirs(os.path.dirname(output), exist_ok=True) download_url(url, root=os.path.dirname(output), filename=output) with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True) + FAIR1M(root=tmp_path, split=dataset.split, checksum=True) def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None: - shutil.rmtree(str(tmp_path)) + shutil.rmtree(tmp_path) with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - FAIR1M(root=str(tmp_path), split=dataset.split) + 
FAIR1M(root=tmp_path, split=dataset.split) def test_plot(self, dataset: FAIR1M) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_fire_risk.py b/tests/datasets/test_fire_risk.py index e3f235c464d..301bbe5cc08 100644 --- a/tests/datasets/test_fire_risk.py +++ b/tests/datasets/test_fire_risk.py @@ -16,7 +16,7 @@ from torchgeo.datasets import DatasetNotFoundError, FireRisk -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -30,7 +30,7 @@ def dataset( md5 = 'db22106d61b10d855234b4a74db921ac' monkeypatch.setattr(FireRisk, 'md5', md5) monkeypatch.setattr(FireRisk, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return FireRisk(root, split, transforms, download=True, checksum=True) @@ -46,18 +46,18 @@ def test_len(self, dataset: FireRisk) -> None: assert len(dataset) == 5 def test_already_downloaded(self, dataset: FireRisk, tmp_path: Path) -> None: - FireRisk(root=str(tmp_path), download=True) + FireRisk(root=tmp_path, download=True) def test_already_downloaded_not_extracted( self, dataset: FireRisk, tmp_path: Path ) -> None: shutil.rmtree(os.path.dirname(dataset.root)) - download_url(dataset.url, root=str(tmp_path)) - FireRisk(root=str(tmp_path), download=False) + download_url(dataset.url, root=tmp_path) + FireRisk(root=tmp_path, download=False) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - FireRisk(str(tmp_path)) + FireRisk(tmp_path) def test_plot(self, dataset: FireRisk) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py index 39aae73026a..e760b8ad3f1 100644 --- a/tests/datasets/test_forestdamage.py +++ b/tests/datasets/test_forestdamage.py @@ -15,7 +15,7 @@ from torchgeo.datasets import DatasetNotFoundError, ForestDamage -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -31,7 +31,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ForestDamage: monkeypatch.setattr(ForestDamage, 'url', url) monkeypatch.setattr(ForestDamage, 'md5', md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ForestDamage( root=root, transforms=transforms, download=True, checksum=True @@ -57,17 +57,17 @@ def test_not_extracted(self, tmp_path: Path) -> None: 'tests', 'data', 'forestdamage', 'Data_Set_Larch_Casebearer.zip' ) shutil.copy(url, tmp_path) - ForestDamage(root=str(tmp_path)) + ForestDamage(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'Data_Set_Larch_Casebearer.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - ForestDamage(root=str(tmp_path), checksum=True) + ForestDamage(root=tmp_path, checksum=True) def test_not_found(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ForestDamage(str(tmp_path)) + ForestDamage(tmp_path) def test_plot(self, dataset: ForestDamage) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_gbif.py b/tests/datasets/test_gbif.py index 35426d18b03..8c64d614c30 100644 --- a/tests/datasets/test_gbif.py +++ b/tests/datasets/test_gbif.py @@ -38,7 +38,7 @@ def test_or(self, dataset: GBIF) -> None: def test_no_data(self, tmp_path: Path) -> 
None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - GBIF(str(tmp_path)) + GBIF(tmp_path) def test_invalid_query(self, dataset: GBIF) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index e3b11e7fc2a..25aee445359 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -36,7 +36,7 @@ def __init__( bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5), crs: CRS = CRS.from_epsg(4087), res: float = 1, - paths: str | Iterable[str] | None = None, + paths: str | Path | Iterable[str | Path] | None = None, ) -> None: super().__init__() self.index.insert(0, tuple(bounds)) @@ -172,7 +172,7 @@ def test_and_nongeo(self, dataset: GeoDataset) -> None: dataset & ds2 # type: ignore[operator] def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None: - paths = [str(tmp_path), str(tmp_path / 'non_existing_file.tif')] + paths = [tmp_path, tmp_path / 'non_existing_file.tif'] with pytest.warns(UserWarning, match='Path was ignored.'): assert len(CustomGeoDataset(paths=paths).files) == 0 @@ -311,7 +311,7 @@ def test_invalid_query(self, sentinel: Sentinel2) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - RasterDataset(str(tmp_path)) + RasterDataset(tmp_path) def test_no_all_bands(self) -> None: root = os.path.join('tests', 'data', 'sentinel2') @@ -380,7 +380,7 @@ def test_invalid_query(self, dataset: CustomVectorDataset) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - VectorDataset(str(tmp_path)) + VectorDataset(tmp_path) class TestNonGeoDataset: diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py index 9c0358fb08b..b9cce571884 100644 --- a/tests/datasets/test_gid15.py +++ b/tests/datasets/test_gid15.py @@ -16,7 +16,7 @@ from torchgeo.datasets import GID15, DatasetNotFoundError -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -30,7 +30,7 @@ def dataset( monkeypatch.setattr(GID15, 'md5', md5) url = os.path.join('tests', 'data', 'gid15', 'gid-15.zip') monkeypatch.setattr(GID15, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return GID15(root, split, transforms, download=True, checksum=True) @@ -59,7 +59,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - GID15(str(tmp_path)) + GID15(tmp_path) def test_plot(self, dataset: GID15) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_globbiomass.py b/tests/datasets/test_globbiomass.py index 2e31b7b2222..5940b7113fd 100644 --- a/tests/datasets/test_globbiomass.py +++ b/tests/datasets/test_globbiomass.py @@ -37,7 +37,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> GlobBiomass: } monkeypatch.setattr(GlobBiomass, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return GlobBiomass(root, transforms=transforms, checksum=True) @@ -55,13 +55,13 @@ def test_already_extracted(self, dataset: GlobBiomass) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - GlobBiomass(str(tmp_path), checksum=True) + GlobBiomass(tmp_path, checksum=True) 
def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'N00E020_agb.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - GlobBiomass(str(tmp_path), checksum=True) + GlobBiomass(tmp_path, checksum=True) def test_and(self, dataset: GlobBiomass) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index a4c05580b58..0335c50bdfd 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -19,7 +19,7 @@ pytest.importorskip('laspy', minversion='2') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -44,7 +44,7 @@ def dataset( } split, task = request.param monkeypatch.setattr(IDTReeS, 'metadata', metadata) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return IDTReeS(root, split, task, transforms, download=True, checksum=True) @@ -77,11 +77,11 @@ def test_already_downloaded(self, dataset: IDTReeS) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - IDTReeS(str(tmp_path)) + IDTReeS(tmp_path) def test_not_extracted(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'idtrees', '*.zip') - root = str(tmp_path) + root = tmp_path for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) IDTReeS(root) diff --git a/tests/datasets/test_inaturalist.py b/tests/datasets/test_inaturalist.py index 0f9a5424875..a1e255d7745 100644 --- a/tests/datasets/test_inaturalist.py +++ b/tests/datasets/test_inaturalist.py @@ -38,7 +38,7 @@ def test_or(self, dataset: INaturalist) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - INaturalist(str(tmp_path)) + INaturalist(tmp_path) def test_invalid_query(self, dataset: INaturalist) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_inria.py b/tests/datasets/test_inria.py index 21bcb1a900d..41ba41dee69 100644 --- a/tests/datasets/test_inria.py +++ b/tests/datasets/test_inria.py @@ -50,7 +50,7 @@ def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None: def test_not_downloaded(self, tmp_path: str) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - InriaAerialImageLabeling(str(tmp_path)) + InriaAerialImageLabeling(tmp_path) def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None: InriaAerialImageLabeling.md5 = 'randommd5hash123' diff --git a/tests/datasets/test_iobench.py b/tests/datasets/test_iobench.py index 747d9ed1464..f0aa546ed6f 100644 --- a/tests/datasets/test_iobench.py +++ b/tests/datasets/test_iobench.py @@ -24,7 +24,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -36,7 +36,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> IOBench: url = os.path.join('tests', 'data', 'iobench', '{}.tar.gz') monkeypatch.setattr(IOBench, 'url', url) monkeypatch.setitem(IOBench.md5s, 'preprocessed', md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return IOBench(root, transforms=transforms, download=True, checksum=True) @@ -68,14 +68,14 @@ def test_already_extracted(self, dataset: IOBench) -> None: def 
test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'iobench', '*.tar.gz') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) IOBench(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - IOBench(str(tmp_path)) + IOBench(tmp_path) def test_invalid_query(self, dataset: IOBench) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_l7irish.py b/tests/datasets/test_l7irish.py index f760ae89058..8e6b892f361 100644 --- a/tests/datasets/test_l7irish.py +++ b/tests/datasets/test_l7irish.py @@ -25,7 +25,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -41,7 +41,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L7Irish: url = os.path.join('tests', 'data', 'l7irish', '{}.tar.gz') monkeypatch.setattr(L7Irish, 'url', url) monkeypatch.setattr(L7Irish, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return L7Irish(root, transforms=transforms, download=True, checksum=True) @@ -75,14 +75,14 @@ def test_already_extracted(self, dataset: L7Irish) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'l7irish', '*.tar.gz') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) L7Irish(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - L7Irish(str(tmp_path)) + L7Irish(tmp_path) def test_plot_prediction(self, dataset: L7Irish) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_l8biome.py b/tests/datasets/test_l8biome.py index d00cebb131a..f9a74bf3339 100644 --- a/tests/datasets/test_l8biome.py +++ b/tests/datasets/test_l8biome.py @@ -25,7 +25,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -41,7 +41,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L8Biome: url = os.path.join('tests', 'data', 'l8biome', '{}.tar.gz') monkeypatch.setattr(L8Biome, 'url', url) monkeypatch.setattr(L8Biome, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return L8Biome(root, transforms=transforms, download=True, checksum=True) @@ -75,14 +75,14 @@ def test_already_extracted(self, dataset: L8Biome) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) L8Biome(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - L8Biome(str(tmp_path)) + L8Biome(tmp_path) def test_plot_prediction(self, dataset: L8Biome) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 7c81f257250..5ee43fd795d 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -22,7 +22,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: 
str) -> None: shutil.copy(url, root) @@ -34,7 +34,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LandCoverAIGeo: monkeypatch.setattr(LandCoverAIGeo, 'md5', md5) url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') monkeypatch.setattr(LandCoverAIGeo, 'url', url) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return LandCoverAIGeo(root, transforms=transforms, download=True, checksum=True) @@ -49,13 +49,13 @@ def test_already_extracted(self, dataset: LandCoverAIGeo) -> None: def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') - root = str(tmp_path) + root = tmp_path shutil.copy(url, root) LandCoverAIGeo(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LandCoverAIGeo(str(tmp_path)) + LandCoverAIGeo(tmp_path) def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) @@ -89,7 +89,7 @@ def dataset( monkeypatch.setattr(LandCoverAI, 'url', url) sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' monkeypatch.setattr(LandCoverAI, 'sha256', sha256) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return LandCoverAI(root, split, transforms, download=True, checksum=True) @@ -115,13 +115,13 @@ def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> N sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' monkeypatch.setattr(LandCoverAI, 'sha256', sha256) url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') - root = str(tmp_path) + root = tmp_path shutil.copy(url, root) LandCoverAI(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LandCoverAI(str(tmp_path)) + LandCoverAI(tmp_path) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index 51e6a6fef24..621910ccd00 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -71,7 +71,7 @@ def test_plot_wrong_bands(self, dataset: Landsat8) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Landsat8(str(tmp_path)) + Landsat8(tmp_path) def test_invalid_query(self, dataset: Landsat8) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py index cbec555746c..030f819b78b 100644 --- a/tests/datasets/test_levircd.py +++ b/tests/datasets/test_levircd.py @@ -16,7 +16,7 @@ from torchgeo.datasets import LEVIRCD, DatasetNotFoundError, LEVIRCDPlus -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -45,7 +45,7 @@ def dataset( } monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) monkeypatch.setattr(LEVIRCD, 'splits', splits) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return LEVIRCD(root, split, transforms, download=True, checksum=True) @@ -71,7 +71,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LEVIRCD(str(tmp_path)) 
+ LEVIRCD(tmp_path) def test_plot(self, dataset: LEVIRCD) -> None: dataset.plot(dataset[0], suptitle='Test') @@ -93,7 +93,7 @@ def dataset( monkeypatch.setattr(LEVIRCDPlus, 'md5', md5) url = os.path.join('tests', 'data', 'levircd', 'levircdplus', 'LEVIR-CD+.zip') monkeypatch.setattr(LEVIRCDPlus, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return LEVIRCDPlus(root, split, transforms, download=True, checksum=True) @@ -119,7 +119,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LEVIRCDPlus(str(tmp_path)) + LEVIRCDPlus(tmp_path) def test_plot(self, dataset: LEVIRCDPlus) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_loveda.py b/tests/datasets/test_loveda.py index be36bec2f1e..abd8d9bd26e 100644 --- a/tests/datasets/test_loveda.py +++ b/tests/datasets/test_loveda.py @@ -16,7 +16,7 @@ from torchgeo.datasets import DatasetNotFoundError, LoveDA -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -48,7 +48,7 @@ def dataset( monkeypatch.setattr(LoveDA, 'info_dict', info_dict) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return LoveDA( @@ -84,7 +84,7 @@ def test_invalid_scene(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LoveDA(str(tmp_path)) + LoveDA(tmp_path) def test_plot(self, dataset: LoveDA) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_mapinwild.py b/tests/datasets/test_mapinwild.py index aff7d200099..c60a426d700 100644 --- a/tests/datasets/test_mapinwild.py +++ b/tests/datasets/test_mapinwild.py @@ -18,7 +18,7 @@ from torchgeo.datasets import DatasetNotFoundError, MapInWild -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -53,7 +53,7 @@ def dataset( urls = os.path.join('tests', 'data', 'mapinwild') monkeypatch.setattr(MapInWild, 'url', urls) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() @@ -98,12 +98,12 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - MapInWild(root=str(tmp_path)) + MapInWild(root=tmp_path) def test_downloaded_not_extracted(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'mapinwild', '*', '*') pathname_glob = glob.glob(pathname) - root = str(tmp_path) + root = tmp_path for zipfile in pathname_glob: shutil.copy(zipfile, root) MapInWild(root, download=False) @@ -111,7 +111,7 @@ def test_downloaded_not_extracted(self, tmp_path: Path) -> None: def test_corrupted(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'mapinwild', '**', '*.zip') pathname_glob = glob.glob(pathname, recursive=True) - root = str(tmp_path) + root = tmp_path for zipfile in pathname_glob: shutil.copy(zipfile, root) splitfile = os.path.join( @@ -121,10 +121,10 @@ def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'mask.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - MapInWild(root=str(tmp_path), 
download=True, checksum=True) + MapInWild(root=tmp_path, download=True, checksum=True) def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None: - MapInWild(root=str(tmp_path), modality=dataset.modality, download=True) + MapInWild(root=tmp_path, modality=dataset.modality, download=True) def test_plot(self, dataset: MapInWild) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py index 349006ce248..8b1dcef988a 100644 --- a/tests/datasets/test_millionaid.py +++ b/tests/datasets/test_millionaid.py @@ -39,18 +39,18 @@ def test_len(self, dataset: MillionAID) -> None: def test_not_found(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - MillionAID(str(tmp_path)) + MillionAID(tmp_path) def test_not_extracted(self, tmp_path: Path) -> None: url = os.path.join('tests', 'data', 'millionaid', 'train.zip') shutil.copy(url, tmp_path) - MillionAID(str(tmp_path)) + MillionAID(tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'train.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - MillionAID(str(tmp_path), checksum=True) + MillionAID(tmp_path, checksum=True) def test_plot(self, dataset: MillionAID) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_naip.py b/tests/datasets/test_naip.py index 580b309b432..ea54ec881a3 100644 --- a/tests/datasets/test_naip.py +++ b/tests/datasets/test_naip.py @@ -51,7 +51,7 @@ def test_plot(self, dataset: NAIP) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - NAIP(str(tmp_path)) + NAIP(tmp_path) def test_invalid_query(self, dataset: NAIP) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index 588cd89174a..1697195f465 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -45,7 +45,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris: monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) md5s = ['6f4f0d2313323950e45bf3fc0c09b5de', '540cf1cf4fd2c13b609d0355abe955d7'] monkeypatch.setattr(NASAMarineDebris, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return NASAMarineDebris(root, transforms, download=True, checksum=True) @@ -63,15 +63,15 @@ def test_len(self, dataset: NASAMarineDebris) -> None: def test_already_downloaded( self, dataset: NASAMarineDebris, tmp_path: Path ) -> None: - NASAMarineDebris(root=str(tmp_path), download=True) + NASAMarineDebris(root=tmp_path, download=True) def test_already_downloaded_not_extracted( self, dataset: NASAMarineDebris, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - os.makedirs(str(tmp_path), exist_ok=True) + os.makedirs(tmp_path, exist_ok=True) Collection().download(output_dir=str(tmp_path)) - NASAMarineDebris(root=str(tmp_path), download=False) + NASAMarineDebris(root=tmp_path, download=False) def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None: filenames = NASAMarineDebris.filenames @@ -79,7 +79,7 @@ def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, filename), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'): - NASAMarineDebris(root=str(tmp_path), download=False, 
checksum=True) + NASAMarineDebris(root=tmp_path, download=False, checksum=True) def test_corrupted_new_download( self, tmp_path: Path, monkeypatch: MonkeyPatch @@ -87,11 +87,11 @@ def test_corrupted_new_download( with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'): radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_corrupted) - NASAMarineDebris(root=str(tmp_path), download=True, checksum=True) + NASAMarineDebris(root=tmp_path, download=True, checksum=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - NASAMarineDebris(str(tmp_path)) + NASAMarineDebris(tmp_path) def test_plot(self, dataset: NASAMarineDebris) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_nccm.py b/tests/datasets/test_nccm.py index 8def40c4c4e..cbb2d5646a4 100644 --- a/tests/datasets/test_nccm.py +++ b/tests/datasets/test_nccm.py @@ -22,7 +22,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -43,7 +43,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM: } monkeypatch.setattr(NCCM, 'urls', urls) transforms = nn.Identity() - root = str(tmp_path) + root = tmp_path return NCCM(root, transforms=transforms, download=True, checksum=True) def test_getitem(self, dataset: NCCM) -> None: @@ -84,7 +84,7 @@ def test_plot_prediction(self, dataset: NCCM) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - NCCM(str(tmp_path)) + NCCM(tmp_path) def test_invalid_query(self, dataset: NCCM) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_nlcd.py b/tests/datasets/test_nlcd.py index c24220100b1..098e22bedcb 100644 --- a/tests/datasets/test_nlcd.py +++ b/tests/datasets/test_nlcd.py @@ -22,7 +22,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -42,7 +42,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NLCD: ) monkeypatch.setattr(NLCD, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return NLCD( root, @@ -84,7 +84,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( 'tests', 'data', 'nlcd', 'nlcd_2019_land_cover_l48_20210604.zip' ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) NLCD(root, years=[2019]) @@ -93,7 +93,7 @@ def test_invalid_year(self, tmp_path: Path) -> None: AssertionError, match='NLCD data product only exists for the following years:', ): - NLCD(str(tmp_path), years=[1996]) + NLCD(tmp_path, years=[1996]) def test_invalid_classes(self) -> None: with pytest.raises(AssertionError): @@ -117,7 +117,7 @@ def test_plot_prediction(self, dataset: NLCD) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - NLCD(str(tmp_path)) + NLCD(tmp_path) def test_invalid_query(self, dataset: NLCD) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_openbuildings.py b/tests/datasets/test_openbuildings.py index 38610ee7195..a322e4e715e 100644 --- a/tests/datasets/test_openbuildings.py +++ 
b/tests/datasets/test_openbuildings.py @@ -26,7 +26,7 @@ class TestOpenBuildings: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> OpenBuildings: - root = str(tmp_path) + root = tmp_path shutil.copy( os.path.join('tests', 'data', 'openbuildings', 'tiles.geojson'), root ) @@ -55,7 +55,7 @@ def test_no_shapes_to_rasterize( def test_not_download(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - OpenBuildings(str(tmp_path)) + OpenBuildings(tmp_path) def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None: with open(os.path.join(tmp_path, '000_buildings.csv.gz'), 'w') as f: diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index cd1c80a443b..11001f8bfff 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -18,7 +18,7 @@ from torchgeo.datasets import OSCD, DatasetNotFoundError, RGBBandsMissingError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -63,7 +63,7 @@ def dataset( monkeypatch.setattr(OSCD, 'urls', urls) bands, split = request.param - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return OSCD( root, split, bands, transforms=transforms, download=True, checksum=True @@ -101,14 +101,14 @@ def test_already_extracted(self, dataset: OSCD) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'oscd', '*Onera*.zip') - root = str(tmp_path) + root = tmp_path for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) OSCD(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - OSCD(str(tmp_path)) + OSCD(tmp_path) def test_plot(self, dataset: OSCD) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py index 62ff5f913e6..5184eea4c6d 100644 --- a/tests/datasets/test_pastis.py +++ b/tests/datasets/test_pastis.py @@ -17,7 +17,7 @@ from torchgeo.datasets import PASTIS, DatasetNotFoundError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -38,7 +38,7 @@ def dataset( monkeypatch.setattr(PASTIS, 'md5', md5) url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip') monkeypatch.setattr(PASTIS, 'url', url) - root = str(tmp_path) + root = tmp_path folds = request.param['folds'] bands = request.param['bands'] mode = request.param['mode'] @@ -75,19 +75,19 @@ def test_already_extracted(self, dataset: PASTIS) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip') - root = str(tmp_path) + root = tmp_path shutil.copy(url, root) PASTIS(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - PASTIS(str(tmp_path)) + PASTIS(tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'PASTIS-R.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - PASTIS(root=str(tmp_path), checksum=True) + PASTIS(root=tmp_path, checksum=True) def test_invalid_fold(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_patternnet.py 
b/tests/datasets/test_patternnet.py index 915d7388bad..9c2a7585354 100644 --- a/tests/datasets/test_patternnet.py +++ b/tests/datasets/test_patternnet.py @@ -15,7 +15,7 @@ from torchgeo.datasets import DatasetNotFoundError, PatternNet -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -27,7 +27,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> PatternNet: monkeypatch.setattr(PatternNet, 'md5', md5) url = os.path.join('tests', 'data', 'patternnet', 'PatternNet.zip') monkeypatch.setattr(PatternNet, 'url', url) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return PatternNet(root, transforms, download=True, checksum=True) @@ -42,18 +42,18 @@ def test_len(self, dataset: PatternNet) -> None: assert len(dataset) == 2 def test_already_downloaded(self, dataset: PatternNet, tmp_path: Path) -> None: - PatternNet(root=str(tmp_path), download=True) + PatternNet(root=tmp_path, download=True) def test_already_downloaded_not_extracted( self, dataset: PatternNet, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - download_url(dataset.url, root=str(tmp_path)) - PatternNet(root=str(tmp_path), download=False) + download_url(dataset.url, root=tmp_path) + PatternNet(root=tmp_path, download=False) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - PatternNet(str(tmp_path)) + PatternNet(tmp_path) def test_plot(self, dataset: PatternNet) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_potsdam.py b/tests/datasets/test_potsdam.py index 4529d937690..9de329686d0 100644 --- a/tests/datasets/test_potsdam.py +++ b/tests/datasets/test_potsdam.py @@ -43,9 +43,9 @@ def test_extract(self, tmp_path: Path) -> None: root = os.path.join('tests', 'data', 'potsdam') for filename in ['4_Ortho_RGBIR.zip', '5_Labels_all.zip']: shutil.copyfile( - os.path.join(root, filename), os.path.join(str(tmp_path), filename) + os.path.join(root, filename), os.path.join(tmp_path, filename) ) - Potsdam2D(root=str(tmp_path)) + Potsdam2D(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, '4_Ortho_RGBIR.zip'), 'w') as f: @@ -53,7 +53,7 @@ def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, '5_Labels_all.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - Potsdam2D(root=str(tmp_path), checksum=True) + Potsdam2D(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -61,7 +61,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Potsdam2D(str(tmp_path)) + Potsdam2D(tmp_path) def test_plot(self, dataset: Potsdam2D) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_prisma.py b/tests/datasets/test_prisma.py index 89ab52c7275..d43af61e97f 100644 --- a/tests/datasets/test_prisma.py +++ b/tests/datasets/test_prisma.py @@ -50,7 +50,7 @@ def test_plot(self, dataset: PRISMA) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - PRISMA(str(tmp_path)) + PRISMA(tmp_path) def test_invalid_query(self, dataset: PRISMA) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git 
a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index 636e9d7a666..bddd26a7492 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -18,7 +18,7 @@ pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -32,7 +32,7 @@ def dataset( md5 = '127d0d6a1f82d517129535f50053a4c9' monkeypatch.setattr(QuakeSet, 'md5', md5) monkeypatch.setattr(QuakeSet, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return QuakeSet( @@ -50,11 +50,11 @@ def test_len(self, dataset: QuakeSet) -> None: assert len(dataset) == 8 def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None: - QuakeSet(root=str(tmp_path), download=True) + QuakeSet(root=tmp_path, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - QuakeSet(str(tmp_path)) + QuakeSet(tmp_path) def test_plot(self, dataset: QuakeSet) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py index 092e7cf2f1f..9fc0277df5d 100644 --- a/tests/datasets/test_reforestree.py +++ b/tests/datasets/test_reforestree.py @@ -15,7 +15,7 @@ from torchgeo.datasets import DatasetNotFoundError, ReforesTree -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -31,7 +31,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree: monkeypatch.setattr(ReforesTree, 'url', url) monkeypatch.setattr(ReforesTree, 'md5', md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ReforesTree( root=root, transforms=transforms, download=True, checksum=True @@ -57,17 +57,17 @@ def test_len(self, dataset: ReforesTree) -> None: def test_not_extracted(self, tmp_path: Path) -> None: url = os.path.join('tests', 'data', 'reforestree', 'reforesTree.zip') shutil.copy(url, tmp_path) - ReforesTree(root=str(tmp_path)) + ReforesTree(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'reforesTree.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - ReforesTree(root=str(tmp_path), checksum=True) + ReforesTree(root=tmp_path, checksum=True) def test_not_found(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ReforesTree(str(tmp_path)) + ReforesTree(tmp_path) def test_plot(self, dataset: ReforesTree) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py index d52d2d01194..5c064d36f8c 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -18,7 +18,7 @@ pytest.importorskip('rarfile', minversion='4') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -52,7 +52,7 @@ def dataset( 'test': '7760b1960c9a3ff46fb985810815e14d', }, ) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return RESISC45(root, split, transforms, download=True, checksum=True) @@ -68,18 +68,18 @@ def test_len(self, dataset: RESISC45) -> None: assert len(dataset) == 9 
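The change repeated across these test modules is mechanical: pytest's tmp_path fixture is already a pathlib.Path, and the dataset constructors now accept it directly, so the str(tmp_path) conversions are dropped and the monkeypatched download_url stubs are annotated root: str | Path. A minimal sketch of the pattern, using a placeholder ExampleDataset and assuming the str | pathlib.Path union that the isinstance assertions later in the patch suggest for the Path alias in torchgeo.datasets.utils:

    import pathlib
    from typing import TypeAlias

    # Assumed shape of the alias imported as `from .utils import Path`; the
    # exact definition is not shown in these hunks, only inferred from the
    # `isinstance(self.paths, str | pathlib.Path)` checks added elsewhere.
    Path: TypeAlias = str | pathlib.Path

    class ExampleDataset:
        """Placeholder standing in for LEVIRCD, RESISC45, NLCD, etc."""

        def __init__(self, root: Path = 'data') -> None:
            # Both plain strings and pathlib paths are accepted; normalise once.
            self.root = pathlib.Path(root)

    ExampleDataset(str(pathlib.Path('/tmp/data')))  # old call style: str(tmp_path)
    ExampleDataset(pathlib.Path('/tmp/data'))       # new call style: tmp_path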
def test_already_downloaded(self, dataset: RESISC45, tmp_path: Path) -> None: - RESISC45(root=str(tmp_path), download=True) + RESISC45(root=tmp_path, download=True) def test_already_downloaded_not_extracted( self, dataset: RESISC45, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - download_url(dataset.url, root=str(tmp_path)) - RESISC45(root=str(tmp_path), download=False) + download_url(dataset.url, root=tmp_path) + RESISC45(root=tmp_path, download=False) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - RESISC45(str(tmp_path)) + RESISC45(tmp_path) def test_plot(self, dataset: RESISC45) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py index ddf5b5df7fb..d08532e7507 100644 --- a/tests/datasets/test_rwanda_field_boundary.py +++ b/tests/datasets/test_rwanda_field_boundary.py @@ -33,7 +33,7 @@ def dataset( monkeypatch.setattr(RwandaFieldBoundary, 'url', url) monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1}) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return RwandaFieldBoundary(root, split, transforms=transforms, download=True) @@ -60,7 +60,7 @@ def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - RwandaFieldBoundary(str(tmp_path)) + RwandaFieldBoundary(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_seasonet.py b/tests/datasets/test_seasonet.py index 9178dcb1217..ea0d87cb6b7 100644 --- a/tests/datasets/test_seasonet.py +++ b/tests/datasets/test_seasonet.py @@ -18,7 +18,9 @@ from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, SeasoNet -def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> None: +def download_url( + url: str, root: str | Path, md5: str, *args: str, **kwargs: str +) -> None: shutil.copy(url, root) torchgeo.datasets.utils.check_integrity( os.path.join(root, os.path.basename(url)), md5 @@ -95,7 +97,7 @@ def dataset( 'url', os.path.join('tests', 'data', 'seasonet', 'meta.csv'), ) - root = str(tmp_path) + root = tmp_path split, seasons, bands, grids, concat_seasons = request.param transforms = nn.Identity() return SeasoNet( @@ -141,14 +143,14 @@ def test_already_extracted(self, dataset: SeasoNet) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: paths = os.path.join('tests', 'data', 'seasonet', '*.*') - root = str(tmp_path) + root = tmp_path for path in glob.iglob(paths): shutil.copy(path, root) SeasoNet(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SeasoNet(str(tmp_path), download=False) + SeasoNet(tmp_path, download=False) def test_out_of_bounds(self, dataset: SeasoNet) -> None: with pytest.raises(IndexError): diff --git a/tests/datasets/test_seco.py b/tests/datasets/test_seco.py index ed273a8810c..497bfe2528a 100644 --- a/tests/datasets/test_seco.py +++ b/tests/datasets/test_seco.py @@ -22,7 +22,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -56,7 +56,7 @@ def dataset( monkeypatch.setitem( SeasonalContrastS2.metadata['1m'], 'md5', 
'3bb3fcf90f5de7d5781ce0cb85fd20af' ) - root = str(tmp_path) + root = tmp_path version, seasons, bands = request.param transforms = nn.Identity() return SeasonalContrastS2( @@ -88,7 +88,7 @@ def test_already_extracted(self, dataset: SeasonalContrastS2) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'seco', '*.zip') - root = str(tmp_path) + root = tmp_path for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) SeasonalContrastS2(root) @@ -103,7 +103,7 @@ def test_invalid_band(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SeasonalContrastS2(str(tmp_path)) + SeasonalContrastS2(tmp_path) def test_plot(self, dataset: SeasonalContrastS2) -> None: x = dataset[0] diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py index 5732eaf18dd..b7ff8e8e978 100644 --- a/tests/datasets/test_sen12ms.py +++ b/tests/datasets/test_sen12ms.py @@ -66,10 +66,10 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SEN12MS(str(tmp_path), checksum=True) + SEN12MS(tmp_path, checksum=True) with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SEN12MS(str(tmp_path), checksum=False) + SEN12MS(tmp_path, checksum=False) def test_check_integrity_light(self) -> None: root = os.path.join('tests', 'data', 'sen12ms') diff --git a/tests/datasets/test_sentinel.py b/tests/datasets/test_sentinel.py index 28cf6eb1e67..ee4933b44f7 100644 --- a/tests/datasets/test_sentinel.py +++ b/tests/datasets/test_sentinel.py @@ -70,7 +70,7 @@ def test_plot(self, dataset: Sentinel2) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Sentinel1(str(tmp_path)) + Sentinel1(tmp_path) def test_empty_bands(self) -> None: with pytest.raises(AssertionError, match="'bands' cannot be an empty list"): @@ -132,7 +132,7 @@ def test_or(self, dataset: Sentinel2) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Sentinel2(str(tmp_path)) + Sentinel2(tmp_path) def test_plot(self, dataset: Sentinel2) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py index d4deb975b3b..a58cf4196c1 100644 --- a/tests/datasets/test_skippd.py +++ b/tests/datasets/test_skippd.py @@ -19,7 +19,7 @@ pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -40,7 +40,7 @@ def dataset( url = os.path.join('tests', 'data', 'skippd', '{}') monkeypatch.setattr(SKIPPD, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return SKIPPD( root=root, @@ -59,7 +59,7 @@ def test_already_downloaded(self, tmp_path: Path, task: str) -> None: pathname = os.path.join( 'tests', 'data', 'skippd', f'2017_2019_images_pv_processed_{task}.zip' ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) SKIPPD(root=root, task=task) @@ -84,7 +84,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SKIPPD(str(tmp_path)) + 
SKIPPD(tmp_path) def test_plot(self, dataset: SKIPPD) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index 1caf86b6c30..bc88662c16a 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -58,7 +58,7 @@ def test_invalid_bands(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - So2Sat(str(tmp_path)) + So2Sat(tmp_path) def test_plot(self, dataset: So2Sat) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_south_africa_crop_type.py b/tests/datasets/test_south_africa_crop_type.py index 75b014e2227..00671045572 100644 --- a/tests/datasets/test_south_africa_crop_type.py +++ b/tests/datasets/test_south_africa_crop_type.py @@ -52,7 +52,7 @@ def test_already_downloaded(self, dataset: SouthAfricaCropType) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SouthAfricaCropType(str(tmp_path)) + SouthAfricaCropType(tmp_path) def test_plot(self) -> None: path = os.path.join('tests', 'data', 'south_africa_crop_type') diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index c119dc2749b..53a37a38e30 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -21,7 +21,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -37,7 +37,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybe ) monkeypatch.setattr(SouthAmericaSoybean, 'url', url) - root = str(tmp_path) + root = tmp_path return SouthAmericaSoybean( paths=root, years=[2002, 2021], @@ -70,7 +70,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( 'tests', 'data', 'south_america_soybean', 'SouthAmerica_Soybean_2002.tif' ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) SouthAmericaSoybean(root) @@ -89,7 +89,7 @@ def test_plot_prediction(self, dataset: SouthAmericaSoybean) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SouthAmericaSoybean(str(tmp_path)) + SouthAmericaSoybean(tmp_path) def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py index 2676af497fd..38b6d2fbf21 100644 --- a/tests/datasets/test_spacenet.py +++ b/tests/datasets/test_spacenet.py @@ -68,7 +68,7 @@ def dataset( # Refer https://github.com/python/mypy/issues/1032 monkeypatch.setattr(SpaceNet1, 'collection_md5_dict', test_md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return SpaceNet1( root, image=request.param, transforms=transforms, download=True, api_key='' @@ -93,7 +93,7 @@ def test_already_downloaded(self, dataset: SpaceNet1) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet1(str(tmp_path)) + SpaceNet1(tmp_path) def test_plot(self, dataset: SpaceNet1) -> None: x = dataset[0].copy() @@ -118,7 +118,7 @@ def dataset( } monkeypatch.setattr(SpaceNet2, 'collection_md5_dict', test_md5) - root = str(tmp_path) + root = tmp_path transforms = 
nn.Identity() return SpaceNet2( root, @@ -149,7 +149,7 @@ def test_already_downloaded(self, dataset: SpaceNet2) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet2(str(tmp_path)) + SpaceNet2(tmp_path) def test_collection_checksum(self, dataset: SpaceNet2) -> None: dataset.collection_md5_dict['sn2_AOI_2_Vegas'] = 'randommd5hash123' @@ -177,7 +177,7 @@ def dataset( } monkeypatch.setattr(SpaceNet3, 'collection_md5_dict', test_md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return SpaceNet3( root, @@ -209,7 +209,7 @@ def test_already_downloaded(self, dataset: SpaceNet3) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet3(str(tmp_path)) + SpaceNet3(tmp_path) def test_collection_checksum(self, dataset: SpaceNet3) -> None: dataset.collection_md5_dict['sn3_AOI_5_Khartoum'] = 'randommd5hash123' @@ -240,7 +240,7 @@ def dataset( test_angles = ['nadir', 'off-nadir', 'very-off-nadir'] monkeypatch.setattr(SpaceNet4, 'collection_md5_dict', test_md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return SpaceNet4( root, @@ -273,7 +273,7 @@ def test_already_downloaded(self, dataset: SpaceNet4) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet4(str(tmp_path)) + SpaceNet4(tmp_path) def test_collection_checksum(self, dataset: SpaceNet4) -> None: dataset.collection_md5_dict['sn4_AOI_6_Atlanta'] = 'randommd5hash123' @@ -303,7 +303,7 @@ def dataset( } monkeypatch.setattr(SpaceNet5, 'collection_md5_dict', test_md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return SpaceNet5( root, @@ -335,7 +335,7 @@ def test_already_downloaded(self, dataset: SpaceNet5) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet5(str(tmp_path)) + SpaceNet5(tmp_path) def test_collection_checksum(self, dataset: SpaceNet5) -> None: dataset.collection_md5_dict['sn5_AOI_8_Mumbai'] = 'randommd5hash123' @@ -359,7 +359,7 @@ def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet6: monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return SpaceNet6( root, image=request.param, transforms=transforms, download=True, api_key='' @@ -405,7 +405,7 @@ def dataset( } monkeypatch.setattr(SpaceNet7, 'collection_md5_dict', test_md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return SpaceNet7( root, split=request.param, transforms=transforms, download=True, api_key='' @@ -429,7 +429,7 @@ def test_already_downloaded(self, dataset: SpaceNet4) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet7(str(tmp_path)) + SpaceNet7(tmp_path) def test_collection_checksum(self, dataset: SpaceNet4) -> None: dataset.collection_md5_dict['sn7_train_source'] = 'randommd5hash123' diff --git a/tests/datasets/test_ssl4eo.py b/tests/datasets/test_ssl4eo.py index ad45798946e..7caf14aac9f 100644 --- a/tests/datasets/test_ssl4eo.py +++ b/tests/datasets/test_ssl4eo.py @@ -18,7 +18,7 @@ from torchgeo.datasets import SSL4EOL, SSL4EOS12, DatasetNotFoundError -def download_url(url: str, root: str, *args: str, **kwargs: 
str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -61,7 +61,7 @@ def dataset( } monkeypatch.setattr(SSL4EOL, 'checksums', checksums) - root = str(tmp_path) + root = tmp_path split, seasons = request.param transforms = nn.Identity() return SSL4EOL(root, split, seasons, transforms, download=True, checksum=True) @@ -88,14 +88,14 @@ def test_already_extracted(self, dataset: SSL4EOL) -> None: def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'ssl4eo', 'l', '*.tar.gz*') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) SSL4EOL(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SSL4EOL(str(tmp_path)) + SSL4EOL(tmp_path) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -148,7 +148,7 @@ def test_extract(self, tmp_path: Path) -> None: os.path.join('tests', 'data', 'ssl4eo', 's12', filename), tmp_path / filename, ) - SSL4EOS12(str(tmp_path)) + SSL4EOS12(tmp_path) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -156,7 +156,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SSL4EOS12(str(tmp_path)) + SSL4EOS12(tmp_path) def test_plot(self, dataset: SSL4EOS12) -> None: sample = dataset[0] diff --git a/tests/datasets/test_ssl4eo_benchmark.py b/tests/datasets/test_ssl4eo_benchmark.py index db1d36f73b0..153c8738091 100644 --- a/tests/datasets/test_ssl4eo_benchmark.py +++ b/tests/datasets/test_ssl4eo_benchmark.py @@ -25,7 +25,7 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -43,7 +43,7 @@ def dataset( monkeypatch.setattr( torchgeo.datasets.ssl4eo_benchmark, 'download_url', download_url ) - root = str(tmp_path) + root = tmp_path url = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '{}.tar.gz') monkeypatch.setattr(SSL4EOLBenchmark, 'url', url) @@ -140,14 +140,14 @@ def test_already_extracted(self, dataset: SSL4EOLBenchmark) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '*.tar.gz') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) SSL4EOLBenchmark(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SSL4EOLBenchmark(str(tmp_path)) + SSL4EOLBenchmark(tmp_path) def test_plot(self, dataset: SSL4EOLBenchmark) -> None: sample = dataset[0] diff --git a/tests/datasets/test_sustainbench_crop_yield.py b/tests/datasets/test_sustainbench_crop_yield.py index 36e746aaf92..4a8fcefb533 100644 --- a/tests/datasets/test_sustainbench_crop_yield.py +++ b/tests/datasets/test_sustainbench_crop_yield.py @@ -16,7 +16,7 @@ from torchgeo.datasets import DatasetNotFoundError, SustainBenchCropYield -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -34,7 +34,7 @@ def dataset( url = os.path.join('tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip') monkeypatch.setattr(SustainBenchCropYield, 
'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path split = request.param countries = ['argentina', 'brazil', 'usa'] transforms = nn.Identity() @@ -49,7 +49,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( 'tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip' ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) SustainBenchCropYield(root) @@ -72,7 +72,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SustainBenchCropYield(str(tmp_path)) + SustainBenchCropYield(tmp_path) def test_plot(self, dataset: SustainBenchCropYield) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py index bedeb588c66..be6d8143cc1 100644 --- a/tests/datasets/test_ucmerced.py +++ b/tests/datasets/test_ucmerced.py @@ -17,7 +17,7 @@ from torchgeo.datasets import DatasetNotFoundError, UCMerced -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -51,7 +51,7 @@ def dataset( 'test': 'a01fa9f13333bb176fc1bfe26ff4c711', }, ) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return UCMerced(root, split, transforms, download=True, checksum=True) @@ -71,18 +71,18 @@ def test_add(self, dataset: UCMerced) -> None: assert len(ds) == 8 def test_already_downloaded(self, dataset: UCMerced, tmp_path: Path) -> None: - UCMerced(root=str(tmp_path), download=True) + UCMerced(root=tmp_path, download=True) def test_already_downloaded_not_extracted( self, dataset: UCMerced, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - download_url(dataset.url, root=str(tmp_path)) - UCMerced(root=str(tmp_path), download=False) + download_url(dataset.url, root=tmp_path) + UCMerced(root=tmp_path, download=False) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - UCMerced(str(tmp_path)) + UCMerced(tmp_path) def test_plot(self, dataset: UCMerced) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 0566a1f3153..75916ab7e81 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -17,7 +17,7 @@ from torchgeo.datasets import DatasetNotFoundError, USAVars -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -73,7 +73,7 @@ def dataset( } monkeypatch.setattr(USAVars, 'split_metadata', split_metadata) - root = str(tmp_path) + root = tmp_path split, labels = request.param transforms = nn.Identity() @@ -109,7 +109,7 @@ def test_already_extracted(self, dataset: USAVars) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'usavars', 'uar.zip') - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) csvs = [ 'elevation.csv', @@ -130,7 +130,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - USAVars(str(tmp_path)) + USAVars(tmp_path) def test_plot(self, dataset: USAVars) -> None: 
dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index d6c9bc15c38..f0b4a104751 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -65,7 +65,7 @@ def fetch_collection(collection_id: str, **kwargs: str) -> Collection: return Collection() -def download_url(url: str, root: str, *args: str) -> None: +def download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -85,7 +85,7 @@ def test_extract_archive(src: str, tmp_path: Path) -> None: pytest.importorskip('rarfile', minversion='4') if src.startswith('chesapeake'): pytest.importorskip('zipfile_deflate64') - extract_archive(os.path.join('tests', 'data', src), str(tmp_path)) + extract_archive(os.path.join('tests', 'data', src), tmp_path) def test_unsupported_scheme() -> None: @@ -98,8 +98,7 @@ def test_unsupported_scheme() -> None: def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) download_and_extract_archive( - os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'), - str(tmp_path), + os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'), tmp_path ) @@ -108,7 +107,7 @@ def test_download_radiant_mlhub_dataset( ) -> None: radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset) - download_radiant_mlhub_dataset('', str(tmp_path)) + download_radiant_mlhub_dataset('', tmp_path) def test_download_radiant_mlhub_collection( @@ -116,7 +115,7 @@ def test_download_radiant_mlhub_collection( ) -> None: radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) - download_radiant_mlhub_collection('', str(tmp_path)) + download_radiant_mlhub_collection('', tmp_path) class TestBoundingBox: diff --git a/tests/datasets/test_vaihingen.py b/tests/datasets/test_vaihingen.py index e4b36b99edd..7b6cbe878f9 100644 --- a/tests/datasets/test_vaihingen.py +++ b/tests/datasets/test_vaihingen.py @@ -49,9 +49,9 @@ def test_extract(self, tmp_path: Path) -> None: ] for filename in filenames: shutil.copyfile( - os.path.join(root, filename), os.path.join(str(tmp_path), filename) + os.path.join(root, filename), os.path.join(tmp_path, filename) ) - Vaihingen2D(root=str(tmp_path)) + Vaihingen2D(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: filenames = [ @@ -62,7 +62,7 @@ def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, filename), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - Vaihingen2D(root=str(tmp_path), checksum=True) + Vaihingen2D(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -70,7 +70,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Vaihingen2D(str(tmp_path)) + Vaihingen2D(tmp_path) def test_plot(self, dataset: Vaihingen2D) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index dee46c1db88..4222867ffc5 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -20,7 +20,7 @@ pytest.importorskip('rarfile', minversion='4') -def download_url(url: str, root: str, *args: str) -> None: +def 
download_url(url: str, root: str | Path, *args: str) -> None: shutil.copy(url, root) @@ -39,7 +39,7 @@ def dataset( monkeypatch.setitem(VHR10.target_meta, 'url', url) md5 = '567c4cd8c12624864ff04865de504c58' monkeypatch.setitem(VHR10.target_meta, 'md5', md5) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return VHR10(root, split, transforms, download=True, checksum=True) @@ -78,7 +78,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - VHR10(str(tmp_path)) + VHR10(tmp_path) def test_plot(self, dataset: VHR10) -> None: pytest.importorskip('skimage', minversion='0.19') diff --git a/tests/datasets/test_western_usa_live_fuel_moisture.py b/tests/datasets/test_western_usa_live_fuel_moisture.py index e2c9120ae02..02b8870f4fd 100644 --- a/tests/datasets/test_western_usa_live_fuel_moisture.py +++ b/tests/datasets/test_western_usa_live_fuel_moisture.py @@ -37,7 +37,7 @@ def dataset( monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) md5 = 'ecbc9269dd27c4efe7aa887960054351' monkeypatch.setattr(WesternUSALiveFuelMoisture, 'md5', md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return WesternUSALiveFuelMoisture( root, transforms=transforms, download=True, api_key='', checksum=True @@ -60,13 +60,13 @@ def test_already_downloaded(self, tmp_path: Path) -> None: 'western_usa_live_fuel_moisture', 'su_sar_moisture_content.tar.gz', ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) WesternUSALiveFuelMoisture(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - WesternUSALiveFuelMoisture(str(tmp_path)) + WesternUSALiveFuelMoisture(tmp_path) def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None: with pytest.raises(AssertionError, match='Invalid input variable name.'): diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py index 7689acf5f78..c54b597fadf 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview2.py @@ -61,7 +61,7 @@ def test_extract(self, tmp_path: Path) -> None: ), os.path.join(tmp_path, 'test_images_labels_targets.tar.gz'), ) - XView2(root=str(tmp_path)) + XView2(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open( @@ -73,7 +73,7 @@ def test_corrupted(self, tmp_path: Path) -> None: ) as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - XView2(root=str(tmp_path), checksum=True) + XView2(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -81,7 +81,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - XView2(str(tmp_path)) + XView2(tmp_path) def test_plot(self, dataset: XView2) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py index bea0d9e8519..7ad0dfe550f 100644 --- a/tests/datasets/test_zuericrop.py +++ b/tests/datasets/test_zuericrop.py @@ -17,7 +17,7 @@ pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: +def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -33,7 +33,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) 
-> ZueriCrop: md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b'] monkeypatch.setattr(ZueriCrop, 'urls', urls) monkeypatch.setattr(ZueriCrop, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True) @@ -67,7 +67,7 @@ def test_already_downloaded(self, dataset: ZueriCrop) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ZueriCrop(str(tmp_path)) + ZueriCrop(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(ValueError): diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index be0e5996b17..6dcd690c735 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive, lazy_import +from .utils import Path, download_and_extract_archive, lazy_import class ADVANCE(NonGeoDataset): @@ -88,7 +88,7 @@ class ADVANCE(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -151,7 +151,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -169,7 +169,7 @@ def _load_files(self, root: str) -> list[dict[str, str]]: ] return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -185,7 +185,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target audio for a single image. 
Args: diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index e9a8ac844b9..1b80e11555b 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -5,6 +5,7 @@ import json import os +import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -14,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import download_url +from .utils import Path, download_url class AbovegroundLiveWoodyBiomassDensity(RasterDataset): @@ -57,7 +58,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -105,7 +106,7 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) download_url(self.url, self.paths, self.base_filename) with open(os.path.join(self.paths, self.base_filename)) as f: diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index fd325aaa6f8..81256e1f091 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -4,6 +4,7 @@ """AgriFieldNet India Challenge dataset.""" import os +import pathlib import re from collections.abc import Callable, Iterable, Sequence from typing import Any, cast @@ -16,7 +17,7 @@ from .errors import RGBBandsMissingError from .geo import RasterDataset -from .utils import BoundingBox +from .utils import BoundingBox, Path class AgriFieldNet(RasterDataset): @@ -115,7 +116,7 @@ class AgriFieldNet(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, classes: list[int] = list(cmap.keys()), bands: Sequence[str] = all_bands, @@ -167,10 +168,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: Returns: data, label, and field ids at that index """ - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[str], [hit.object for hit in hits]) + filepaths = cast(list[Path], [hit.object for hit in hits]) if not filepaths: raise IndexError( diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py index 479d0f79ce8..c4ef23061b8 100644 --- a/torchgeo/datasets/astergdem.py +++ b/torchgeo/datasets/astergdem.py @@ -12,6 +12,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset +from .utils import Path class AsterGDEM(RasterDataset): @@ -47,7 +48,7 @@ class AsterGDEM(RasterDataset): def __init__( self, - paths: str | list[str] = 'data', + paths: Path | list[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 686d5974324..70bd76373d5 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -19,7 +19,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import which +from .utils import Path, which class BeninSmallHolderCashews(NonGeoDataset): @@ -163,7 +163,7 @@ class BeninSmallHolderCashews(NonGeoDataset): 
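The library side mirrors the tests: user-facing root parameters become Path, multi-path paths parameters widen from str | Iterable[str] to Path | Iterable[Path], and assertions that previously required a plain str now accept str | pathlib.Path. A short sketch of that signature change, again with a placeholder class rather than any specific torchgeo dataset:

    import pathlib
    from collections.abc import Iterable

    # Same assumed alias as in the sketch above.
    Path = str | pathlib.Path

    class ExampleRasterDataset:
        """Placeholder for RasterDataset subclasses such as CDL or Chesapeake."""

        def __init__(self, paths: Path | Iterable[Path] = 'data') -> None:
            self.paths = paths

        def _download_target(self) -> pathlib.Path:
            # Downloads still need a single directory, hence the widened check
            # replacing the old `assert isinstance(self.paths, str)`.
            assert isinstance(self.paths, str | pathlib.Path)
            return pathlib.Path(self.paths)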
def __init__( self, - root: str = 'data', + root: Path = 'data', chip_size: int = 256, stride: int = 128, bands: Sequence[str] = all_bands, diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index e2973d2aff7..0f4c94565e1 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, sort_sentinel2_bands +from .utils import Path, download_url, extract_archive, sort_sentinel2_bands class BigEarthNet(NonGeoDataset): @@ -267,7 +267,7 @@ class BigEarthNet(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: str = 'all', num_classes: int = 19, @@ -486,7 +486,7 @@ def _verify(self) -> None: filepath = os.path.join(self.root, filename) self._extract(filepath) - def _download(self, url: str, filename: str, md5: str) -> None: + def _download(self, url: str, filename: Path, md5: str) -> None: """Download the dataset. Args: @@ -499,13 +499,13 @@ def _download(self, url: str, filename: str, md5: str) -> None: url, self.root, filename=filename, md5=md5 if self.checksum else None ) - def _extract(self, filepath: str) -> None: + def _extract(self, filepath: Path) -> None: """Extract the dataset. Args: filepath: path to file to be extracted """ - if not filepath.endswith('.csv'): + if not str(filepath).endswith('.csv'): extract_archive(filepath) def _onehot_labels_to_names( diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index ab440648a17..dc757b96c4b 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import percentile_normalization +from .utils import Path, percentile_normalization class BioMassters(NonGeoDataset): @@ -57,7 +57,7 @@ class BioMassters(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', sensors: Sequence[str] = ['S1', 'S2'], as_time_series: bool = False, @@ -167,7 +167,7 @@ def __len__(self) -> int: """ return len(self.df['num_index'].unique()) - def _load_input(self, filenames: list[str]) -> Tensor: + def _load_input(self, filenames: list[Path]) -> Tensor: """Load the input imagery at the index. Args: @@ -186,7 +186,7 @@ def _load_input(self, filenames: list[str]) -> Tensor: arr = np.concatenate(arr_list, axis=0) return torch.tensor(arr.astype(np.int32)) - def _load_target(self, filename: str) -> Tensor: + def _load_target(self, filename: Path) -> Tensor: """Load the target mask at the index. 
Args: diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index 2c8105b21f8..f91993763dd 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -4,6 +4,7 @@ """Canadian Building Footprints dataset.""" import os +import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -13,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import check_integrity, download_and_extract_archive +from .utils import Path, check_integrity, download_and_extract_archive class CanadianBuildingFootprints(VectorDataset): @@ -62,7 +63,7 @@ class CanadianBuildingFootprints(VectorDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.00001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -104,7 +105,7 @@ def _check_integrity(self) -> bool: Returns: True if dataset files are found and/or MD5s match, else False """ - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): filepath = os.path.join(self.paths, prov_terr + '.zip') if not check_integrity(filepath, md5 if self.checksum else None): @@ -116,7 +117,7 @@ def _download(self) -> None: if self._check_integrity(): print('Files already downloaded and verified') return - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): download_and_extract_archive( self.url + prov_terr + '.zip', diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index b2e43d7a1d4..25e42d2b030 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -4,6 +4,7 @@ """CDL dataset.""" import os +import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -14,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive class CDL(RasterDataset): @@ -207,7 +208,7 @@ class CDL(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2023], @@ -294,7 +295,7 @@ def _verify(self) -> None: # Check if the zip files have already been downloaded exists = [] - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) for year in self.years: pathname = os.path.join( self.paths, self.zipfile_glob.replace('*', str(year)) @@ -327,7 +328,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) for year in self.years: zipfile_name = self.zipfile_glob.replace('*', str(year)) pathname = os.path.join(self.paths, zipfile_name) diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py index 905c2d8496e..61eefbdf35c 100644 --- a/torchgeo/datasets/chabud.py +++ b/torchgeo/datasets/chabud.py @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, lazy_import, percentile_normalization +from .utils import Path, download_url, lazy_import, percentile_normalization class ChaBuD(NonGeoDataset): @@ -75,7 +75,7 @@ class ChaBuD(NonGeoDataset): def __init__( self, - 
root: str = 'data', + root: Path = 'data', split: str = 'train', bands: list[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 55dddd02cfd..00b415e0959 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -5,6 +5,7 @@ import abc import os +import pathlib import sys from collections.abc import Callable, Iterable, Sequence from typing import Any, cast @@ -26,7 +27,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset, RasterDataset from .nlcd import NLCD -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive class Chesapeake(RasterDataset, abc.ABC): @@ -91,7 +92,7 @@ def url(self) -> str: def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -145,7 +146,7 @@ def _verify(self) -> None: return # Check if the zip file has already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) if os.path.exists(os.path.join(self.paths, self.zipfile)): self._extract() return @@ -164,7 +165,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) extract_archive(os.path.join(self.paths, self.zipfile)) def plot( @@ -510,7 +511,7 @@ class ChesapeakeCVPR(GeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', splits: Sequence[str] = ['de-train'], layers: Sequence[str] = ['naip-new', 'lc'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -668,7 +669,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" - def exists(filename: str) -> bool: + def exists(filename: Path) -> bool: return os.path.exists(os.path.join(self.root, filename)) # Check if the extracted files already exist diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index 552693684ba..7c7ed8b630c 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import which +from .utils import Path, which class CloudCoverDetection(NonGeoDataset): @@ -61,7 +61,7 @@ class CloudCoverDetection(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index a1bfd0da56e..91ddbf8c54a 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -4,6 +4,7 @@ """CMS Global Mangrove Canopy dataset.""" import os +import pathlib from collections.abc import Callable from typing import Any @@ -13,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import check_integrity, extract_archive +from .utils import Path, check_integrity, extract_archive class CMSGlobalMangroveCanopy(RasterDataset): @@ -169,7 +170,7 @@ class 
CMSGlobalMangroveCanopy(RasterDataset): def __init__( self, - paths: str | list[str] = 'data', + paths: Path | list[Path] = 'data', crs: CRS | None = None, res: float | None = None, measurement: str = 'agb', @@ -228,7 +229,7 @@ def _verify(self) -> None: return # Check if the zip file has already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, self.zipfile) if os.path.exists(pathname): if self.checksum and not check_integrity(pathname, self.md5): @@ -240,7 +241,7 @@ def _verify(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, self.zipfile) extract_archive(pathname) diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 838123d39b2..cae82e597fc 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive +from .utils import Path, check_integrity, download_and_extract_archive class COWC(NonGeoDataset, abc.ABC): @@ -65,7 +65,7 @@ def filename(self) -> str: def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py index 400b5ceb63c..30f3a43f634 100644 --- a/torchgeo/datasets/cropharvest.py +++ b/torchgeo/datasets/cropharvest.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, lazy_import +from .utils import Path, download_url, extract_archive, lazy_import class CropHarvest(NonGeoDataset): @@ -96,7 +96,7 @@ class CropHarvest(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -157,7 +157,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_features(self, root: str) -> list[dict[str, str]]: + def _load_features(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -181,7 +181,7 @@ def _load_features(self, root: str) -> list[dict[str, str]]: files.append(dict(chip=chip_path, index=index, dataset=dataset)) return files - def _load_labels(self, root: str) -> pd.DataFrame: + def _load_labels(self, root: Path) -> pd.DataFrame: """Return the paths of the files in the dataset. Args: @@ -196,7 +196,7 @@ def _load_labels(self, root: str) -> pd.DataFrame: df = pd.json_normalize(data['features']) return df - def _load_array(self, path: str) -> Tensor: + def _load_array(self, path: Path) -> Tensor: """Load an individual single pixel time series. 
Args: diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index feeb6ff0ec2..4e262d5266f 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import which +from .utils import Path, which class CV4AKenyaCropType(NonGeoDataset): @@ -104,7 +104,7 @@ class CV4AKenyaCropType(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', chip_size: int = 256, stride: int = 128, bands: Sequence[str] = all_bands, diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 747463b69de..2a21832703a 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import which +from .utils import Path, which class TropicalCyclone(NonGeoDataset): @@ -53,7 +53,7 @@ class TropicalCyclone(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index a986e43d308..fcd9fb7bac2 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -16,6 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -100,7 +101,7 @@ class DeepGlobeLandCover(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index b9cd1556f9f..dfc43927f30 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, extract_archive, percentile_normalization +from .utils import Path, check_integrity, extract_archive, percentile_normalization class DFC2022(NonGeoDataset): @@ -137,7 +137,7 @@ class DFC2022(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, @@ -224,7 +224,7 @@ def _load_files(self) -> list[dict[str, str]]: return files - def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor: + def _load_image(self, path: Path, shape: Sequence[int] | None = None) -> Tensor: """Load a single image. Args: @@ -241,7 +241,7 @@ def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor: tensor = torch.from_numpy(array) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. 
Args: diff --git a/torchgeo/datasets/eddmaps.py b/torchgeo/datasets/eddmaps.py index f30b75bcbeb..5f6fb267751 100644 --- a/torchgeo/datasets/eddmaps.py +++ b/torchgeo/datasets/eddmaps.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, disambiguate_timestamp +from .utils import BoundingBox, Path, disambiguate_timestamp class EDDMapS(GeoDataset): @@ -42,7 +42,7 @@ class EDDMapS(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = 'data') -> None: + def __init__(self, root: Path = 'data') -> None: """Initialize a new Dataset instance. Args: diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 0ca0a9bafe3..9b22d80c03e 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -23,7 +23,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive class EnviroAtlas(GeoDataset): @@ -253,7 +253,7 @@ class EnviroAtlas(GeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', splits: Sequence[str] = ['pittsburgh_pa-2010_1m-train'], layers: Sequence[str] = ['naip', 'prior'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -414,7 +414,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" - def exists(filename: str) -> bool: + def exists(filename: Path) -> bool: return os.path.exists(os.path.join(self.root, 'enviroatlas_lotp', filename)) # Check if the extracted files already exist diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 04157d28bab..ed06309f91a 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -5,6 +5,7 @@ import glob import os +import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -14,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class Esri2020(RasterDataset): @@ -69,7 +70,7 @@ class Esri2020(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -112,7 +113,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, self.zipfile) if glob.glob(pathname): self._extract() @@ -132,7 +133,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) extract_archive(os.path.join(self.paths, self.zipfile)) def plot( diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 44ab7007f9f..ebf1d91f70b 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import Path, download_and_extract_archive class ETCI2021(NonGeoDataset): @@ -81,7 +81,7 @@ class ETCI2021(NonGeoDataset): def 
__init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -152,7 +152,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -193,7 +193,7 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -210,7 +210,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index 9dc431ec1f6..63a6f916526 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -5,6 +5,7 @@ import glob import os +import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -14,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import check_integrity, extract_archive +from .utils import Path, check_integrity, extract_archive class EUDEM(RasterDataset): @@ -84,7 +85,7 @@ class EUDEM(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -125,7 +126,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, self.zipfile_glob) if glob.glob(pathname): for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index 0082dd152b9..bf0f173d4c6 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -5,6 +5,7 @@ import csv import os +import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -16,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import check_integrity, download_and_extract_archive, download_url +from .utils import Path, check_integrity, download_and_extract_archive, download_url class EuroCrops(VectorDataset): @@ -84,7 +85,7 @@ class EuroCrops(VectorDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS = CRS.from_epsg(4326), res: float = 0.00001, classes: list[str] | None = None, @@ -138,7 +139,7 @@ def _check_integrity(self) -> bool: if self.files and not self.checksum: return True - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) filepath = os.path.join(self.paths, self.hcat_fname) if not check_integrity(filepath, self.hcat_md5 if self.checksum else None): @@ -155,7 +156,7 @@ def _download(self) -> None: if self._check_integrity(): print('Files already downloaded and verified') return - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) download_url( self.base_url + self.hcat_fname, self.paths, @@ -177,7 +178,7 @@ def _load_class_map(self, classes: list[str] | 
None) -> None: (defaults to all classes) """ if not classes: - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) classes = [] filepath = os.path.join(self.paths, self.hcat_fname) with open(filepath) as f: diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 982917fcdd6..8292257404c 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoClassificationDataset -from .utils import check_integrity, download_url, extract_archive, rasterio_loader +from .utils import Path, check_integrity, download_url, extract_archive, rasterio_loader class EuroSAT(NonGeoClassificationDataset): @@ -97,7 +97,7 @@ class EuroSAT(NonGeoClassificationDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -142,7 +142,7 @@ def __init__( for fn in f: valid_fns.add(fn.strip().replace('.jpg', '.tif')) - def is_in_split(x: str) -> bool: + def is_in_split(x: Path) -> bool: return os.path.basename(x) in valid_fns super().__init__( diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index e3476c97128..6d4337f57bc 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -19,10 +19,10 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import Path, check_integrity, download_url, extract_archive -def parse_pascal_voc(path: str) -> dict[str, Any]: +def parse_pascal_voc(path: Path) -> dict[str, Any]: """Read a PASCAL VOC annotation file. Args: @@ -230,7 +230,7 @@ class FAIR1M(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -279,7 +279,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: sample = {'image': image} if self.split != 'test': - label_path = path.replace(self.image_root, self.label_root) + label_path = str(path).replace(self.image_root, self.label_root) label_path = label_path.replace('.tif', '.xml') voc = parse_pascal_voc(label_path) boxes, labels = self._load_target(voc['points'], voc['labels']) @@ -298,7 +298,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. 
Args: diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py index 9a5033c8aab..be40dfcf6d6 100644 --- a/torchgeo/datasets/fire_risk.py +++ b/torchgeo/datasets/fire_risk.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class FireRisk(NonGeoClassificationDataset): @@ -68,7 +68,7 @@ class FireRisk(NonGeoClassificationDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 1cbae17f961..0eca812bb8b 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -19,10 +19,10 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, extract_archive +from .utils import Path, check_integrity, download_and_extract_archive, extract_archive -def parse_pascal_voc(path: str) -> dict[str, Any]: +def parse_pascal_voc(path: Path) -> dict[str, Any]: """Read a PASCAL VOC annotation file. Args: @@ -106,7 +106,7 @@ class ForestDamage(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -164,7 +164,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -187,7 +187,7 @@ def _load_files(self, root: str) -> list[dict[str, str]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: diff --git a/torchgeo/datasets/gbif.py b/torchgeo/datasets/gbif.py index 259abe481ad..3e6f078faa4 100644 --- a/torchgeo/datasets/gbif.py +++ b/torchgeo/datasets/gbif.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox +from .utils import BoundingBox, Path def _disambiguate_timestamps( @@ -80,7 +80,7 @@ class GBIF(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = 'data') -> None: + def __init__(self, root: Path = 'data') -> None: """Initialize a new Dataset instance. Args: diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index d44242d8130..57718fcb4c8 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -7,6 +7,7 @@ import functools import glob import os +import pathlib import re import sys import warnings @@ -34,6 +35,7 @@ from .errors import DatasetNotFoundError from .utils import ( BoundingBox, + Path, array_to_tensor, concat_samples, disambiguate_timestamp, @@ -84,7 +86,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): dataset = landsat7 | landsat8 """ - paths: str | Iterable[str] + paths: Path | Iterable[Path] _crs = CRS.from_epsg(4326) _res = 0.0 @@ -205,7 +207,7 @@ def __setstate__( self, state: tuple[ dict[Any, Any], - list[tuple[int, tuple[float, float, float, float, float, float], str]], + list[tuple[int, tuple[float, float, float, float, float, float], Path]], ], ) -> None: """Define how to unpickle an instance. 
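The `files` property in the hunk that follows is where the `Path | Iterable[Path]` union gets disambiguated at runtime. A lone `str` or `pathlib.Path` has to be wrapped before iterating, because a `str` is itself an iterable and would otherwise be walked character by character. A self-contained sketch of that normalization (the function name here is illustrative, not the patch's own):

    import pathlib
    from collections.abc import Iterable

    def as_path_list(
        paths: str | pathlib.Path | Iterable[str | pathlib.Path],
    ) -> list[str | pathlib.Path]:
        # Wrap a bare path first; only then treat the argument as an iterable.
        if isinstance(paths, str | pathlib.Path):
            return [paths]
        return list(paths)

    assert as_path_list('data') == ['data']
    assert as_path_list([pathlib.Path('a'), 'b']) == [pathlib.Path('a'), 'b']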
@@ -288,7 +290,7 @@ def res(self, new_res: float) -> None: self._res = new_res @property - def files(self) -> list[str]: + def files(self) -> list[Path]: """A list of all files in the dataset. Returns: @@ -297,13 +299,13 @@ def files(self) -> list[str]: .. versionadded:: 0.5 """ # Make iterable - if isinstance(self.paths, str): - paths: Iterable[str] = [self.paths] + if isinstance(self.paths, str | pathlib.Path): + paths: Iterable[Path] = [self.paths] else: paths = self.paths # Using set to remove any duplicates if directories are overlapping - files: set[str] = set() + files: set[Path] = set() for path in paths: if os.path.isdir(path): pathname = os.path.join(path, '**', self.filename_glob) @@ -410,7 +412,7 @@ def resampling(self) -> Resampling: def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, @@ -516,7 +518,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[str], [hit.object for hit in hits]) + filepaths = cast(list[Path], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -559,7 +561,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _merge_files( self, - filepaths: Sequence[str], + filepaths: Sequence[Path], query: BoundingBox, band_indexes: Sequence[int] | None = None, ) -> Tensor: @@ -587,7 +589,7 @@ def _merge_files( return tensor @functools.lru_cache(maxsize=128) - def _cached_load_warp_file(self, filepath: str) -> DatasetReader: + def _cached_load_warp_file(self, filepath: Path) -> DatasetReader: """Cached version of :meth:`_load_warp_file`. Args: @@ -598,7 +600,7 @@ def _cached_load_warp_file(self, filepath: str) -> DatasetReader: """ return self._load_warp_file(filepath) - def _load_warp_file(self, filepath: str) -> DatasetReader: + def _load_warp_file(self, filepath: Path) -> DatasetReader: """Load and warp a file to the correct CRS and resolution. Args: @@ -649,7 +651,7 @@ def dtype(self) -> torch.dtype: def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.0001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -846,10 +848,10 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, - loader: Callable[[str], Any] | None = pil_loader, - is_valid_file: Callable[[str], bool] | None = None, + loader: Callable[[Path], Any] | None = pil_loader, + is_valid_file: Callable[[Path], bool] | None = None, ) -> None: """Initialize a new NonGeoClassificationDataset instance. 
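Taken together, the `geo.py` changes mean both constructor flavours now accept `pathlib.Path` objects as well as plain strings: `root` for non-geospatial datasets and `paths` for geospatial ones. An illustrative usage sketch, assuming the corresponding data already exists under the given directories (otherwise these constructors raise `DatasetNotFoundError`):

    from pathlib import Path  # the stdlib class, not the typing alias in torchgeo.datasets.utils

    from torchgeo.datasets import EuroSAT, Sentinel2

    eurosat = EuroSAT(root=Path('data') / 'eurosat')           # NonGeo dataset: root
    sentinel = Sentinel2(paths=[Path('s2/a'), Path('s2/b')])   # Geo dataset: paths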
diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 329d488e94d..078b83e6054 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import Path, download_and_extract_archive class GID15(NonGeoDataset): @@ -88,7 +88,7 @@ class GID15(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -154,7 +154,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -178,7 +178,7 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -195,7 +195,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)).float() return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index 17091b6cc3d..fc15e918dbd 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -5,6 +5,7 @@ import glob import os +import pathlib from collections.abc import Callable, Iterable from typing import Any, cast @@ -15,7 +16,13 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, check_integrity, disambiguate_timestamp, extract_archive +from .utils import ( + BoundingBox, + Path, + check_integrity, + disambiguate_timestamp, + extract_archive, +) class GlobBiomass(RasterDataset): @@ -131,7 +138,7 @@ class GlobBiomass(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, measurement: str = 'agb', @@ -186,7 +193,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[str], [hit.object for hit in hits]) + filepaths = cast(list[Path], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -195,7 +202,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask = self._merge_files(filepaths, query) - std_error_paths = [f.replace('.tif', '_err.tif') for f in filepaths] + std_error_paths = [str(f).replace('.tif', '_err.tif') for f in filepaths] std_err_mask = self._merge_files(std_error_paths, query) mask = torch.cat((mask, std_err_mask), dim=0) @@ -214,7 +221,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, f'*_{self.measurement}.zip') if glob.glob(pathname): for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 4dd067244df..67a9db80e87 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -22,7 +22,7 @@ from .errors import 
DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, lazy_import +from .utils import Path, download_url, extract_archive, lazy_import class IDTReeS(NonGeoDataset): @@ -152,7 +152,7 @@ class IDTReeS(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', task: str = 'task1', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -240,7 +240,7 @@ def __len__(self) -> int: """ return len(self.images) - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a tiff file. Args: @@ -254,7 +254,7 @@ def _load_image(self, path: str) -> Tensor: tensor = torch.from_numpy(array) return tensor - def _load_las(self, path: str) -> Tensor: + def _load_las(self, path: Path) -> Tensor: """Load a single point cloud. Args: @@ -269,7 +269,7 @@ def _load_las(self, path: str) -> Tensor: tensor = torch.from_numpy(array) return tensor - def _load_boxes(self, path: str) -> Tensor: + def _load_boxes(self, path: Path) -> Tensor: """Load object bounding boxes. Args: @@ -313,7 +313,7 @@ def _load_boxes(self, path: str) -> Tensor: tensor = torch.tensor(boxes) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load target label for a single sample. Args: @@ -333,7 +333,7 @@ def _load_target(self, path: str) -> Tensor: return tensor def _load( - self, root: str + self, root: Path ) -> tuple[list[str], dict[int, dict[str, Any]] | None, Any]: """Load files, geometries, and labels. @@ -360,7 +360,7 @@ def _load( return images, geoms, labels - def _load_labels(self, directory: str) -> Any: + def _load_labels(self, directory: Path) -> Any: """Load the csv files containing the labels. Args: @@ -380,7 +380,7 @@ def _load_labels(self, directory: str) -> Any: df.reset_index() return df - def _load_geometries(self, directory: str) -> dict[int, dict[str, Any]]: + def _load_geometries(self, directory: Path) -> dict[int, dict[str, Any]]: """Load the shape files containing the geometries. Args: diff --git a/torchgeo/datasets/inaturalist.py b/torchgeo/datasets/inaturalist.py index 478b60a1c10..06aa3e6185f 100644 --- a/torchgeo/datasets/inaturalist.py +++ b/torchgeo/datasets/inaturalist.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, disambiguate_timestamp +from .utils import BoundingBox, Path, disambiguate_timestamp class INaturalist(GeoDataset): @@ -34,7 +34,7 @@ class INaturalist(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = 'data') -> None: + def __init__(self, root: Path = 'data') -> None: """Initialize a new Dataset instance. 
Args: diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 5b3db228499..3b2a4348a96 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, extract_archive, percentile_normalization +from .utils import Path, check_integrity, extract_archive, percentile_normalization class InriaAerialImageLabeling(NonGeoDataset): @@ -59,7 +59,7 @@ class InriaAerialImageLabeling(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, checksum: bool = False, @@ -86,7 +86,7 @@ def __init__( self._verify() self.files = self._load_files(root) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -121,7 +121,7 @@ def _load_files(self, root: str) -> list[dict[str, str]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -135,7 +135,7 @@ def _load_image(self, path: str) -> Tensor: tensor = torch.from_numpy(array).float() return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Loads the target mask. Args: diff --git a/torchgeo/datasets/iobench.py b/torchgeo/datasets/iobench.py index a0ee246065a..80376df9579 100644 --- a/torchgeo/datasets/iobench.py +++ b/torchgeo/datasets/iobench.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset from .landsat import Landsat9 -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class IOBench(IntersectionDataset): @@ -50,7 +50,7 @@ class IOBench(IntersectionDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'preprocessed', crs: CRS | None = None, res: float | None = None, diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index 7153738b391..8adc3842c45 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -5,6 +5,7 @@ import glob import os +import pathlib import re from collections.abc import Callable, Iterable, Sequence from typing import Any, cast @@ -18,7 +19,13 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset, RasterDataset -from .utils import BoundingBox, disambiguate_timestamp, download_url, extract_archive +from .utils import ( + BoundingBox, + Path, + disambiguate_timestamp, + download_url, + extract_archive, +) class L7IrishImage(RasterDataset): @@ -61,7 +68,7 @@ class L7IrishMask(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, @@ -87,7 +94,7 @@ def __init__( filename_regex = re.compile(L7IrishImage.filename_regex, re.VERBOSE) index = Index(interleaved=False, properties=Property(dimension=3)) for hit in self.index.intersection(self.index.bounds, objects=True): - dirname = os.path.dirname(cast(str, hit.object)) + dirname = os.path.dirname(cast(Path, hit.object)) image = glob.glob(os.path.join(dirname, L7IrishImage.filename_glob))[0] minx, maxx, miny, maxy, mint, maxt = hit.bounds if match := re.match(filename_regex, 
os.path.basename(image)): @@ -169,7 +176,7 @@ class L7Irish(IntersectionDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = CRS.from_epsg(3857), res: float | None = None, bands: Sequence[str] = L7IrishImage.all_bands, @@ -222,7 +229,7 @@ def _merge_dataset_indices(self) -> None: def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - if not isinstance(self.paths, str): + if not isinstance(self.paths, str | pathlib.Path): return for classname in [L7IrishImage, L7IrishMask]: @@ -255,7 +262,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, '*.tar.gz') for tarfile in glob.iglob(pathname): extract_archive(tarfile) diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index c200b5c63bc..4865ec932b4 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -5,6 +5,7 @@ import glob import os +import pathlib from collections.abc import Callable, Iterable, Sequence from typing import Any @@ -16,7 +17,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset, RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive class L8BiomeImage(RasterDataset): @@ -132,7 +133,7 @@ class L8Biome(IntersectionDataset): def __init__( self, - paths: str | Iterable[str], + paths: Path | Iterable[Path], crs: CRS | None = CRS.from_epsg(3857), res: float | None = None, bands: Sequence[str] = L8BiomeImage.all_bands, @@ -173,7 +174,7 @@ def __init__( def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - if not isinstance(self.paths, str): + if not isinstance(self.paths, str | pathlib.Path): return for classname in [L8BiomeImage, L8BiomeMask]: @@ -206,7 +207,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, '*.tar.gz') for tarfile in glob.iglob(pathname): extract_archive(tarfile) diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 970a45eb1cd..e6bdd34d1c9 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -23,7 +23,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset, RasterDataset -from .utils import BoundingBox, download_url, extract_archive, working_dir +from .utils import BoundingBox, Path, download_url, extract_archive, working_dir class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): @@ -74,7 +74,7 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): } def __init__( - self, root: str = 'data', download: bool = False, checksum: bool = False + self, root: Path = 'data', download: bool = False, checksum: bool = False ) -> None: """Initialize a new LandCover.ai dataset instance. 
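One more recurring idiom, visible again in the LandCover.ai hunk that follows: wherever the original code relied on `str`-only operations such as substring `.replace` or `.endswith`, the patch casts the incoming path back to `str` rather than rewriting the logic with `pathlib` (compare the BigEarthNet, FAIR1M, and GlobBiomass hunks earlier). A small sketch of the same idiom, using an illustrative filename:

    import pathlib

    img = pathlib.Path('landcoverai') / 'images' / 'tile_0.tif'
    # pathlib.Path.replace() renames a file on disk, not substrings,
    # so the patch converts to str before doing text substitution:
    mask = str(img).replace('images', 'masks')
    assert mask.endswith('.tif') and 'masks' in mask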
@@ -205,7 +205,7 @@ class LandCoverAIGeo(LandCoverAIBase, RasterDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -254,8 +254,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - img_filepaths = cast(list[str], [hit.object for hit in hits]) - mask_filepaths = [path.replace('images', 'masks') for path in img_filepaths] + img_filepaths = cast(list[Path], [hit.object for hit in hits]) + mask_filepaths = [ + str(path).replace('images', 'masks') for path in img_filepaths + ] if not img_filepaths: raise IndexError( @@ -294,7 +296,7 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index aee28c1224d..48647f5d247 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -13,6 +13,7 @@ from .errors import RGBBandsMissingError from .geo import RasterDataset +from .utils import Path class Landsat(RasterDataset, abc.ABC): @@ -59,7 +60,7 @@ def default_bands(self) -> list[str]: def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 67209f603a8..9dbc68136db 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive, percentile_normalization +from .utils import Path, download_and_extract_archive, percentile_normalization class LEVIRCDBase(NonGeoDataset, abc.ABC): @@ -31,7 +31,7 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -94,7 +94,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -111,7 +111,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: @@ -183,7 +183,7 @@ def plot( return fig @abc.abstractmethod - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -255,7 +255,7 @@ class LEVIRCD(LEVIRCDBase): }, } - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. 
Args: @@ -338,7 +338,7 @@ class LEVIRCDPlus(LEVIRCDBase): directory = 'LEVIR-CD+' splits = ['train', 'test'] - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 58f3876a09b..8c987548f90 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import Path, download_and_extract_archive class LoveDA(NonGeoDataset): @@ -91,7 +91,7 @@ class LoveDA(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', scene: list[str] = ['urban', 'rural'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -197,7 +197,7 @@ def _load_files(self, scene_paths: list[str], split: str) -> list[dict[str, str] return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -214,7 +214,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load a single mask corresponding to image. Args: diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index 882ec260fef..66ce23ad70f 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -19,6 +19,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, download_url, extract_archive, @@ -108,7 +109,7 @@ class MapInWild(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', modality: list[str] = ['mask', 'esa_wc', 'viirs', 's2_summer'], split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -205,7 +206,7 @@ def __len__(self) -> int: """ return len(self.ids) - def _load_raster(self, filename: int, source: str) -> Tensor: + def _load_raster(self, filename: int, source: Path) -> Tensor: """Load a single raster image or target. Args: @@ -272,7 +273,7 @@ def _download(self, url: str, md5: str | None) -> None: md5=md5 if self.checksum else None, ) - def _extract(self, path: str) -> None: + def _extract(self, path: Path) -> None: """Extracts a modality. Args: diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 46eabbe19e9..1938aefbafc 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, extract_archive +from .utils import Path, check_integrity, extract_archive class MillionAID(NonGeoDataset): @@ -190,7 +190,7 @@ class MillionAID(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', task: str = 'multi-class', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -252,7 +252,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: return sample - def _load_files(self, root: str) -> list[dict[str, Any]]: + def _load_files(self, root: Path) -> list[dict[str, Any]]: """Return the paths of the files in the dataset. 
Args: @@ -295,7 +295,7 @@ def _load_files(self, root: str) -> list[dict[str, Any]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 6d7a2dcdaa5..66a2f5789ba 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -16,7 +16,12 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import ( + Path, + check_integrity, + download_radiant_mlhub_collection, + extract_archive, +) class NASAMarineDebris(NonGeoDataset): @@ -61,7 +66,7 @@ class NASAMarineDebris(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, api_key: str | None = None, @@ -123,7 +128,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -137,7 +142,7 @@ def _load_image(self, path: str) -> Tensor: tensor = torch.from_numpy(array).float() return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target bounding boxes for a single image. Args: diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py index 68e0566e28a..83163391735 100644 --- a/torchgeo/datasets/nccm.py +++ b/torchgeo/datasets/nccm.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, download_url +from .utils import BoundingBox, Path, download_url class NCCM(RasterDataset): @@ -83,7 +83,7 @@ class NCCM(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2019], diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index 13c6883801d..e7113b6709f 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -5,6 +5,7 @@ import glob import os +import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -15,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive class NLCD(RasterDataset): @@ -108,7 +109,7 @@ class NLCD(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2019], @@ -191,7 +192,7 @@ def _verify(self) -> None: exists = [] for year in self.years: zipfile_year = self.zipfile_glob.replace('*', str(year), 1) - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, '**', zipfile_year) if glob.glob(pathname, recursive=True): exists.append(True) @@ -223,7 +224,7 @@ def _extract(self) -> None: """Extract the dataset.""" for year in self.years: zipfile_name = self.zipfile_glob.replace('*', str(year), 1) - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, '**', zipfile_name) extract_archive(glob.glob(pathname, 
recursive=True)[0], self.paths) diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index ec1650ed88f..f3500f35733 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -6,6 +6,7 @@ import glob import json import os +import pathlib import sys from collections.abc import Callable, Iterable from typing import Any, cast @@ -24,7 +25,7 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import BoundingBox, check_integrity +from .utils import BoundingBox, Path, check_integrity class OpenBuildings(VectorDataset): @@ -207,7 +208,7 @@ class OpenBuildings(VectorDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.0001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -241,7 +242,7 @@ def __init__( # Create an R-tree to index the dataset using the polygon centroid as bounds self.index = Index(interleaved=False, properties=Property(dimension=3)) - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) with open(os.path.join(self.paths, 'tiles.geojson')) as f: data = json.load(f) @@ -304,7 +305,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[str], [hit.object for hit in hits]) + filepaths = cast(list[Path], [hit.object for hit in hits]) if not filepaths: raise IndexError( @@ -335,7 +336,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample def _filter_geometries( - self, query: BoundingBox, filepaths: list[str] + self, query: BoundingBox, filepaths: list[Path] ) -> list[dict[str, Any]]: """Filters a df read from the polygon csv file based on query and conf thresh. @@ -397,7 +398,7 @@ def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the zip files have already been downloaded and checksum - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) pathname = os.path.join(self.paths, self.zipfile_glob) i = 0 for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index b2ad8aef275..808ca93ab09 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -17,6 +17,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset from .utils import ( + Path, download_url, draw_semantic_segmentation_masks, extract_archive, @@ -98,7 +99,7 @@ class OSCD(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -207,7 +208,7 @@ def get_image_paths(ind: int) -> list[str]: return regions - def _load_image(self, paths: Sequence[str]) -> Tensor: + def _load_image(self, paths: Sequence[Path]) -> Tensor: """Load a single image. Args: @@ -224,7 +225,7 @@ def _load_image(self, paths: Sequence[str]) -> Tensor: tensor = torch.from_numpy(array).float() return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. 
Args: diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 0b7629bcec5..bdad515f66b 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import Path, check_integrity, download_url, extract_archive class PASTIS(NonGeoDataset): @@ -128,7 +128,7 @@ class PASTIS(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', folds: Sequence[int] = (1, 2, 3, 4, 5), bands: str = 's2', mode: str = 'semantic', diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py index 4b64a1b488e..a9385049872 100644 --- a/torchgeo/datasets/patternnet.py +++ b/torchgeo/datasets/patternnet.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class PatternNet(NonGeoClassificationDataset): @@ -85,7 +85,7 @@ class PatternNet(NonGeoClassificationDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 943a489217a..479ca3cf170 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -17,6 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -121,7 +122,7 @@ class Potsdam2D(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index ce5d9a3bd2c..3fd2501dd0f 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, lazy_import, percentile_normalization +from .utils import Path, download_url, lazy_import, percentile_normalization class QuakeSet(NonGeoDataset): @@ -66,7 +66,7 @@ class QuakeSet(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index 28c1d0135f0..bd28ab83f5d 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, extract_archive +from .utils import Path, check_integrity, download_and_extract_archive, extract_archive class ReforesTree(NonGeoDataset): @@ -64,7 +64,7 @@ class ReforesTree(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -124,7 +124,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> list[str]: + def _load_files(self, root: Path) -> list[str]: 
"""Return the paths of the files in the dataset. Args: @@ -137,7 +137,7 @@ def _load_files(self, root: str) -> list[str]: return image_paths - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -153,7 +153,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, filepath: str) -> tuple[Tensor, ...]: + def _load_target(self, filepath: Path) -> tuple[Tensor, ...]: """Load boxes and labels for a single image. Args: diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index fb066424b1a..9b99858a0c3 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class RESISC45(NonGeoClassificationDataset): @@ -119,7 +119,7 @@ class RESISC45(NonGeoClassificationDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -149,7 +149,7 @@ def __init__( for fn in f: valid_fns.add(fn.strip()) - def is_in_split(x: str) -> bool: + def is_in_split(x: Path) -> bool: return os.path.basename(x) in valid_fns super().__init__( diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 07a496ea974..510039d8dcd 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import which +from .utils import Path, which class RwandaFieldBoundary(NonGeoDataset): @@ -63,7 +63,7 @@ class RwandaFieldBoundary(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index 1bd59487af3..6d19fed7fcc 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -20,7 +20,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, percentile_normalization +from .utils import Path, download_url, extract_archive, percentile_normalization class SeasoNet(NonGeoDataset): @@ -207,7 +207,7 @@ class SeasoNet(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', seasons: Collection[str] = all_seasons, bands: Iterable[str] = all_bands, diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index ea36b974da1..74b8ebba54a 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, percentile_normalization +from .utils import Path, download_url, extract_archive, percentile_normalization class SeasonalContrastS2(NonGeoDataset): @@ -70,7 +70,7 @@ class SeasonalContrastS2(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', version: str = '100k', seasons: int = 1, bands: list[str] = rgb_bands, @@ -147,7 +147,7 @@ def 
__len__(self) -> int: """ return (10**5 if self.version == '100k' else 10**6) // 5 - def _load_patch(self, root: str, subdir: str) -> Tensor: + def _load_patch(self, root: Path, subdir: Path) -> Tensor: """Load a single image patch. Args: diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 07f49f964c1..8b8ee57803c 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, percentile_normalization +from .utils import Path, check_integrity, percentile_normalization class SEN12MS(NonGeoDataset): @@ -165,7 +165,7 @@ class SEN12MS(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 2d4dedb50fb..163c771a6b9 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -13,6 +13,7 @@ from .errors import RGBBandsMissingError from .geo import RasterDataset +from .utils import Path class Sentinel(RasterDataset): @@ -141,7 +142,7 @@ class Sentinel1(Sentinel): def __init__( self, - paths: str | list[str] = 'data', + paths: Path | list[Path] = 'data', crs: CRS | None = None, res: float = 10, bands: Sequence[str] = ['VV', 'VH'], @@ -297,7 +298,7 @@ class Sentinel2(Sentinel): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 10, bands: Sequence[str] | None = None, diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 0d111ae15b9..8a8882f3598 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, lazy_import +from .utils import Path, download_url, extract_archive, lazy_import class SKIPPD(NonGeoDataset): @@ -79,7 +79,7 @@ class SKIPPD(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'trainval', task: str = 'nowcast', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 3003031399b..e90e89b4e34 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, lazy_import, percentile_normalization +from .utils import Path, check_integrity, lazy_import, percentile_normalization class So2Sat(NonGeoDataset): @@ -194,7 +194,7 @@ class So2Sat(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', version: str = '2', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index 48dcbd8529a..54ddf4299f2 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -4,6 +4,7 @@ """South Africa Crop Type Competition Dataset.""" import os +import pathlib import re from collections.abc import Callable, Iterable from typing import Any, cast @@ -16,7 +17,7 @@ from .errors import RGBBandsMissingError from .geo import 
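Because ``paths`` on the raster datasets is now a union (or an iterable of that union), string and ``pathlib.Path`` entries can in principle be mixed in a single call. A hedged sketch, assuming Sentinel-2 scenes actually exist under the listed directories (the example will not run without data):

    import pathlib

    from torchgeo.datasets import Sentinel2

    scenes = [
        'data/sentinel2/2021',                        # plain string
        pathlib.Path('data') / 'sentinel2' / '2022',  # pathlib.Path
    ]
    ds = Sentinel2(paths=scenes, res=10)
    print(ds.crs, ds.res)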
RasterDataset -from .utils import BoundingBox +from .utils import BoundingBox, Path class SouthAfricaCropType(RasterDataset): @@ -102,7 +103,7 @@ class SouthAfricaCropType(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, classes: list[int] = list(cmap.keys()), bands: list[str] = s2_bands, @@ -148,11 +149,11 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: Returns: data and labels at that index """ - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) # Get all files matching the given query hits = self.index.intersection(tuple(query), objects=True) - filepaths = cast(list[str], [hit.object for hit in hits]) + filepaths = cast(list[Path], [hit.object for hit in hits]) if not filepaths: raise IndexError( diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index fc28c229370..1ee8d7cb79e 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -3,6 +3,7 @@ """South America Soybean Dataset.""" +import pathlib from collections.abc import Callable, Iterable from typing import Any @@ -12,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import download_url +from .utils import Path, download_url class SouthAmericaSoybean(RasterDataset): @@ -72,7 +73,7 @@ class SouthAmericaSoybean(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2021], @@ -112,7 +113,7 @@ def _verify(self) -> None: # Check if the extracted files already exist if self.files: return - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | pathlib.Path) # Check if the user requested to download the dataset if not self.download: diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 2a5283dfb5b..525faff3433 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -28,6 +28,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, download_radiant_mlhub_collection, download_radiant_mlhub_dataset, @@ -79,7 +80,7 @@ def chip_size(self) -> dict[str, tuple[int, int]]: def __init__( self, - root: str, + root: Path, image: str, collections: list[str] = [], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -124,7 +125,7 @@ def __init__( self.files = self._load_files(root) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -144,7 +145,7 @@ def _load_files(self, root: str) -> list[dict[str, str]]: files.append({'image_path': imgpath, 'label_path': lbl_path}) return files - def _load_image(self, path: str) -> tuple[Tensor, Affine, CRS]: + def _load_image(self, path: Path) -> tuple[Tensor, Affine, CRS]: """Load a single image. Args: @@ -160,7 +161,7 @@ def _load_image(self, path: str) -> tuple[Tensor, Affine, CRS]: return tensor, img.transform, img.crs def _load_mask( - self, path: str, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] + self, path: Path, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] ) -> Tensor: """Rasterizes the dataset's labels (in geojson format). 
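The ``cast(list[Path], ...)`` calls above only inform the type checker: at runtime the R-tree returns whatever object was stored when the file was indexed, which after this patch may be either a ``str`` or a ``pathlib.Path``. A minimal sketch of the idiom with a stand-in hit class:

    import pathlib
    from typing import cast

    Path = str | pathlib.Path  # mirrors the alias defined in torchgeo/datasets/utils.py

    class Hit:
        """Stand-in for an rtree hit object (illustrative only)."""

        def __init__(self, obj: Path) -> None:
            self.object = obj

    hits = [Hit('data/a.tif'), Hit(pathlib.Path('data/b.tif'))]
    # cast() performs no conversion or validation; it only tells mypy what to expect
    filepaths = cast(list[Path], [hit.object for hit in hits])
    print([str(p) for p in filepaths])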
@@ -398,7 +399,7 @@ class SpaceNet1(SpaceNet): def __init__( self, - root: str = 'data', + root: Path = 'data', image: str = 'rgb', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, @@ -513,7 +514,7 @@ class SpaceNet2(SpaceNet): def __init__( self, - root: str = 'data', + root: Path = 'data', image: str = 'PS-RGB', collections: list[str] = [], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -633,7 +634,7 @@ class SpaceNet3(SpaceNet): def __init__( self, - root: str = 'data', + root: Path = 'data', image: str = 'PS-RGB', speed_mask: bool | None = False, collections: list[str] = [], @@ -669,7 +670,7 @@ def __init__( ) def _load_mask( - self, path: str, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] + self, path: Path, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] ) -> Tensor: """Rasterizes the dataset's labels (in geojson format). @@ -884,7 +885,7 @@ class SpaceNet4(SpaceNet): def __init__( self, - root: str = 'data', + root: Path = 'data', image: str = 'PS-RGBNIR', angles: list[str] = [], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -918,7 +919,7 @@ def __init__( root, image, collections, transforms, download, api_key, checksum ) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -1052,7 +1053,7 @@ class SpaceNet5(SpaceNet3): def __init__( self, - root: str = 'data', + root: Path = 'data', image: str = 'PS-RGB', speed_mask: bool | None = False, collections: list[str] = [], @@ -1184,7 +1185,7 @@ class SpaceNet6(SpaceNet): def __init__( self, - root: str = 'data', + root: Path = 'data', image: str = 'PS-RGB', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, @@ -1282,7 +1283,7 @@ class SpaceNet7(SpaceNet): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, @@ -1326,7 +1327,7 @@ def __init__( self.files = self._load_files(root) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. 
Args: diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py index ce6558f52c4..a087cbf68dc 100644 --- a/torchgeo/datasets/ssl4eo.py +++ b/torchgeo/datasets/ssl4eo.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import Path, check_integrity, download_url, extract_archive class SSL4EO(NonGeoDataset): @@ -162,7 +162,7 @@ class _Metadata(TypedDict): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'oli_sr', seasons: int = 1, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -404,7 +404,7 @@ class _Metadata(TypedDict): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 's2c', seasons: int = 1, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 7d9edcaecb4..3f16bf33b07 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .nlcd import NLCD -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class SSL4EOLBenchmark(NonGeoDataset): @@ -107,7 +107,7 @@ class SSL4EOLBenchmark(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', sensor: str = 'oli_sr', product: str = 'cdl', split: str = 'train', @@ -297,7 +297,7 @@ def retrieve_sample_collection(self) -> list[tuple[str, str]]: sample_collection.append((img_path, mask_path)) return sample_collection - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load the input image. Args: @@ -310,7 +310,7 @@ def _load_image(self, path: str) -> Tensor: image = torch.from_numpy(src.read()).float() return image - def _load_mask(self, path: str) -> Tensor: + def _load_mask(self, path: Path) -> Tensor: """Load the mask. 
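The loaders above hand the path straight to rasterio, and ``rasterio.open`` accepts both ``str`` and ``os.PathLike`` objects, so no conversion is needed at the call site. A hedged standalone sketch (the file name is hypothetical, so the call is left commented out):

    import pathlib

    import rasterio
    import torch
    from torch import Tensor

    def load_image(path: str | pathlib.Path) -> Tensor:
        """Read every band of a raster into a float tensor."""
        with rasterio.open(path) as src:
            return torch.from_numpy(src.read()).float()

    tile = pathlib.Path('data') / 'ssl4eo_l' / 'example.tif'  # hypothetical file
    # image = load_image(tile)
    # print(image.shape)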
Args: diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 8eb410297e9..604c17bc70e 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class SustainBenchCropYield(NonGeoDataset): @@ -59,7 +59,7 @@ class SustainBenchCropYield(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', countries: list[str] = ['usa'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 686fb3b96f8..7045ee558a8 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import Path, check_integrity, download_url, extract_archive class UCMerced(NonGeoClassificationDataset): @@ -86,7 +86,7 @@ class UCMerced(NonGeoClassificationDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -117,7 +117,7 @@ def __init__( for fn in f: valid_fns.add(fn.strip()) - def is_in_split(x: str) -> bool: + def is_in_split(x: Path) -> bool: return os.path.basename(x) in valid_fns super().__init__( diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index b955e8ded68..42443c56d2d 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class USAVars(NonGeoDataset): @@ -86,7 +86,7 @@ class USAVars(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', labels: Sequence[str] = ALL_LABELS, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -170,7 +170,7 @@ def _load_files(self) -> list[str]: files = f.read().splitlines() return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 12ed2aad40b..2f72916982d 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -13,6 +13,7 @@ import importlib import lzma import os +import pathlib import shutil import subprocess import sys @@ -20,7 +21,7 @@ from collections.abc import Iterable, Iterator, Sequence from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, cast, overload +from typing import Any, TypeAlias, cast, overload import numpy as np import rasterio @@ -35,6 +36,9 @@ __all__ = ('check_integrity', 'download_url') +Path: TypeAlias = str | pathlib.Path + + class _rarfile: class RarFile: def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -72,7 +76,7 @@ def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: pass -def extract_archive(src: str, dst: str | None = None) -> None: +def extract_archive(src: Path, dst: Path | None = None) -> None: """Extract an archive. 
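The ``Path`` alias added to ``torchgeo/datasets/utils.py`` is the single definition the rest of the patch imports; note that ``typing.TypeAlias`` itself is only available from Python 3.10 onwards. A minimal sketch of declaring and consuming such an alias:

    import os
    import pathlib
    from typing import TypeAlias

    Path: TypeAlias = str | pathlib.Path

    def describe(path: Path) -> str:
        # os.fspath reduces both variants to the underlying string
        return f'{type(path).__name__}: {os.fspath(path)}'

    print(describe('data/archive.zip'))
    print(describe(pathlib.Path('data') / 'archive.zip'))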
Args: @@ -95,7 +99,7 @@ def extract_archive(src: str, dst: str | None = None) -> None: ] for suffix, extractor in suffix_and_extractor: - if src.endswith(suffix): + if str(src).endswith(suffix): with extractor(src, 'r') as f: f.extractall(dst) return @@ -107,7 +111,7 @@ def extract_archive(src: str, dst: str | None = None) -> None: ] for suffix, decompressor in suffix_and_decompressor: - if src.endswith(suffix): + if str(src).endswith(suffix): dst = os.path.join(dst, os.path.basename(src).replace(suffix, '')) with decompressor(src, 'rb') as sf, open(dst, 'wb') as df: df.write(sf.read()) @@ -118,9 +122,9 @@ def extract_archive(src: str, dst: str | None = None) -> None: def download_and_extract_archive( url: str, - download_root: str, - extract_root: str | None = None, - filename: str | None = None, + download_root: Path, + extract_root: Path | None = None, + filename: Path | None = None, md5: str | None = None, ) -> None: """Download and extract an archive. @@ -146,7 +150,7 @@ def download_and_extract_archive( def download_radiant_mlhub_dataset( - dataset_id: str, download_root: str, api_key: str | None = None + dataset_id: str, download_root: Path, api_key: str | None = None ) -> None: """Download a dataset from Radiant Earth. @@ -166,7 +170,7 @@ def download_radiant_mlhub_dataset( def download_radiant_mlhub_collection( - collection_id: str, download_root: str, api_key: str | None = None + collection_id: str, download_root: Path, api_key: str | None = None ) -> None: """Download a collection from Radiant Earth. @@ -410,7 +414,7 @@ class Executable: .. versionadded:: 0.6 """ - def __init__(self, name: str) -> None: + def __init__(self, name: Path) -> None: """Initialize a new Executable instance. Args: @@ -488,7 +492,7 @@ def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]: @contextlib.contextmanager -def working_dir(dirname: str, create: bool = False) -> Iterator[None]: +def working_dir(dirname: Path, create: bool = False) -> Iterator[None]: """Context manager for changing directories. Args: @@ -633,7 +637,7 @@ def unbind_samples(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: return _dict_list_to_list_dict(sample) -def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]: +def rasterio_loader(path: Path) -> np.typing.NDArray[np.int_]: """Load an image file using rasterio. Args: @@ -649,7 +653,7 @@ def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]: return array -def sort_sentinel2_bands(x: str) -> str: +def sort_sentinel2_bands(x: Path) -> str: """Sort Sentinel-2 band files in the correct order.""" x = os.path.basename(x).split('_')[-1] x = os.path.splitext(x)[0] @@ -744,7 +748,7 @@ def percentile_normalization( return img_normalized -def path_is_vsi(path: str) -> bool: +def path_is_vsi(path: Path) -> bool: """Checks if the given path is pointing to a Virtual File System. .. note:: @@ -758,14 +762,14 @@ def path_is_vsi(path: str) -> bool: * https://rasterio.readthedocs.io/en/latest/topics/datasets.html Args: - path: string representing a directory or file + path: a directory or file Returns: True if path is on a virtual file system, else False .. 
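The ``str(src)`` coercion in ``extract_archive`` is needed because ``pathlib.PurePath`` has no ``endswith`` method, and ``PurePath.suffix`` would only see the last component of a multi-part extension such as ``.tar.gz``. A small sketch of the same idea, using a hypothetical helper name:

    import pathlib

    ARCHIVE_SUFFIXES = ('.tar.gz', '.tgz', '.zip', '.gz')

    def matching_suffix(src: str | pathlib.Path) -> str | None:
        """Return the first archive suffix that src ends with, if any."""
        name = str(src)  # PurePath has no .endswith, so compare on the string form
        for suffix in ARCHIVE_SUFFIXES:
            if name.endswith(suffix):
                return suffix
        return None

    print(matching_suffix(pathlib.Path('data/landcover.tar.gz')))  # '.tar.gz'
    print(matching_suffix('data/landcover.csv'))                   # None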
versionadded:: 0.6 """ - return '://' in path or path.startswith('/vsi') + return '://' in str(path) or str(path).startswith('/vsi') def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor: @@ -831,7 +835,7 @@ def lazy_import(name: str) -> Any: raise DependencyNotFoundError(msg) from None -def which(name: str) -> Executable: +def which(name: Path) -> Executable: """Search for executable *name*. Args: diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index 6276dcb87cf..305eb950197 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -16,6 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -120,7 +121,7 @@ class Vaihingen2D(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index b1aae5d2a30..14793b3d44d 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -18,6 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, download_and_extract_archive, download_url, @@ -186,7 +187,7 @@ class VHR10(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'positive', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index 10602efc256..1f389154736 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_radiant_mlhub_collection, extract_archive +from .utils import Path, download_radiant_mlhub_collection, extract_archive class WesternUSALiveFuelMoisture(NonGeoDataset): @@ -200,7 +200,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', input_features: list[str] = all_variable_names, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 5716c06f593..0cb66cba9f3 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -16,7 +16,12 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive +from .utils import ( + Path, + check_integrity, + draw_semantic_segmentation_masks, + extract_archive, +) class XView2(NonGeoDataset): @@ -66,7 +71,7 @@ class XView2(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, @@ -127,7 +132,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. 
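One caveat around ``path_is_vsi``: GDAL virtual-filesystem identifiers and URLs are best kept as plain strings, because round-tripping them through ``pathlib.Path`` collapses the double slash, and the ``'://'`` test above then fails. A short illustration of the underlying pathlib behavior (shown for POSIX paths):

    import pathlib

    url = 'https://example.com/cog.tif'
    print('://' in url)                       # True
    print(str(pathlib.Path(url)))             # 'https:/example.com/cog.tif' - '//' is collapsed
    print('://' in str(pathlib.Path(url)))    # False

    vsi = '/vsizip/archive.zip/raster.tif'
    print(str(pathlib.Path(vsi)).startswith('/vsi'))  # True - the leading '/vsi' survives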
Args: @@ -152,7 +157,7 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: files.append(dict(image1=image1, image2=image2, mask1=mask1, mask2=mask2)) return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -169,7 +174,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index e1a5f4d2870..394721b5237 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import download_url, lazy_import, percentile_normalization +from .utils import Path, download_url, lazy_import, percentile_normalization class ZueriCrop(NonGeoDataset): @@ -64,7 +64,7 @@ class ZueriCrop(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', bands: Sequence[str] = band_names, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False,
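Taken together, the patch standardizes on a lightweight union alias rather than converting every input to ``pathlib.Path`` at the boundary, so existing ``os.path``-based code keeps working unchanged at the cost of the occasional ``str()``/``os.fspath()`` coercion seen above. A hedged sketch of how a caller can now treat the two styles interchangeably (assumes the xView2 data is already extracted under ``data/xview2``):

    import pathlib

    from torchgeo.datasets import XView2

    for root in ('data/xview2', pathlib.Path('data') / 'xview2'):
        ds = XView2(root=root, split='train')
        print(type(root).__name__, len(ds))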