Datasets: add support for pathlib.Path (#2173)
* Added pathlib support for ``geo.py``

* Fixed failing ruff checks.

* Fixed additional ruff formatting errors

* Added complete ``pathlib`` support

* Additional changes

* Fixed ``cyclone.py`` issues

* Additional fixes

* Fixed ``isinstance`` and ``Path`` inconsistency

* Fixed mypy errors in ``cdl.py``

* geo/utils: all paths are Paths

* datasets: all paths are Paths

* Test Paths

* Type checks only work for latest torchvision

* Fix tests

---------

Co-authored-by: Hitesh Tolani <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
3 people authored Jul 18, 2024
1 parent 67eaeba commit 44ce007
Showing 158 changed files with 658 additions and 598 deletions.
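In practice, this commit means every dataset constructor's `root` (and related path arguments) accepts `pathlib.Path` objects as well as plain strings. A minimal sketch of the new calling convention (dataset choice and paths are illustrative, and assume the data already exists under that root):

```python
from pathlib import Path

from torchgeo.datasets import CDL

# Both forms are accepted after this change; previously only str was annotated.
ds_str = CDL('data/cdl', years=[2023])             # plain string, as before
ds_path = CDL(Path('data') / 'cdl', years=[2023])  # pathlib.Path, now supported
```

The bullets "geo/utils: all paths are Paths" and "datasets: all paths are Paths" suggest the annotations were widened centrally rather than per dataset. One plausible shape for such an alias (hypothetical; the actual spelling in torchgeo may differ):

```python
import os

# Hypothetical alias for "anything accepted by os.fspath()"; the actual
# definition in torchgeo may differ.
PathLike = str | os.PathLike[str]
```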
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -32,6 +32,7 @@ repos:
- scikit-image>=0.22.0
- torch>=2.3
- torchmetrics>=0.10
+- torchvision>=0.18
exclude: (build|data|dist|logo|logs|output)/
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
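The new `torchvision>=0.18` pin under mypy's additional_dependencies lines up with the commit note "Type checks only work for latest torchvision": presumably torchvision's own download-utility annotations only began accepting `pathlib.Path` in 0.18, so older stubs would reject the `Path` roots these tests now pass. A sketch of the kind of call that needs the newer annotations (the URL is illustrative):

```python
from pathlib import Path

from torchvision.datasets.utils import download_url

root = Path('downloads')
# Runtime has long tolerated path-like roots; with torchvision>=0.18 the
# type annotations allow it too, so mypy accepts calls like this one.
download_url('https://example.com/file.zip', root)
```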
6 changes: 3 additions & 3 deletions tests/datasets/test_advance.py
@@ -17,7 +17,7 @@
pytest.importorskip('scipy', minversion='1.7.2')


-def download_url(url: str, root: str, *args: str) -> None:
+def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
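Only the annotation widens here; the body is unchanged because `shutil.copy` already accepts any `os.PathLike` at runtime. The same one-line helper recurs throughout the test modules below. A standalone sketch (paths illustrative):

```python
import shutil
from pathlib import Path

# shutil.copy accepts str and os.PathLike destinations alike, so a Path
# works without an explicit str() conversion.
shutil.copy('tests/data/advance/fixture.zip', Path('/tmp'))
```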


@@ -33,7 +33,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE:
md5s = ['43acacecebecd17a82bc2c1e719fd7e4', '039b7baa47879a8a4e32b9dd8287f6ad']
monkeypatch.setattr(ADVANCE, 'urls', urls)
monkeypatch.setattr(ADVANCE, 'md5s', md5s)
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return ADVANCE(root, transforms, download=True, checksum=True)

@@ -57,7 +57,7 @@ def test_already_downloaded(self, dataset: ADVANCE) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-ADVANCE(str(tmp_path))
+ADVANCE(tmp_path)

def test_plot(self, dataset: ADVANCE) -> None:
x = dataset[0].copy()
6 changes: 3 additions & 3 deletions tests/datasets/test_agb_live_woody_density.py
@@ -21,7 +21,7 @@
)


-def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)


@@ -42,7 +42,7 @@ def dataset(
)
monkeypatch.setattr(AbovegroundLiveWoodyBiomassDensity, 'url', url)

-root = str(tmp_path)
+root = tmp_path
return AbovegroundLiveWoodyBiomassDensity(
root, transforms=transforms, download=True
)
@@ -58,7 +58,7 @@ def test_len(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:

def test_no_dataset(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-AbovegroundLiveWoodyBiomassDensity(str(tmp_path))
+AbovegroundLiveWoodyBiomassDensity(tmp_path)

def test_already_downloaded(
self, dataset: AbovegroundLiveWoodyBiomassDensity
2 changes: 1 addition & 1 deletion tests/datasets/test_agrifieldnet.py
@@ -50,7 +50,7 @@ def test_already_downloaded(self, dataset: AgriFieldNet) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-AgriFieldNet(str(tmp_path))
+AgriFieldNet(tmp_path)

def test_plot(self, dataset: AgriFieldNet) -> None:
x = dataset[dataset.bounds]
2 changes: 1 addition & 1 deletion tests/datasets/test_airphen.py
@@ -52,7 +52,7 @@ def test_plot(self, dataset: Airphen) -> None:

def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-Airphen(str(tmp_path))
+Airphen(tmp_path)

def test_invalid_query(self, dataset: Airphen) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
4 changes: 2 additions & 2 deletions tests/datasets/test_astergdem.py
@@ -25,15 +25,15 @@ class TestAsterGDEM:
def dataset(self, tmp_path: Path) -> AsterGDEM:
zipfile = os.path.join('tests', 'data', 'astergdem', 'astergdem.zip')
shutil.unpack_archive(zipfile, tmp_path, 'zip')
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return AsterGDEM(root, transforms=transforms)

def test_datasetmissing(self, tmp_path: Path) -> None:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-AsterGDEM(str(tmp_path))
+AsterGDEM(tmp_path)

def test_getitem(self, dataset: AsterGDEM) -> None:
x = dataset[dataset.bounds]
4 changes: 2 additions & 2 deletions tests/datasets/test_benin_cashews.py
@@ -29,7 +29,7 @@ def dataset(
monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('20191105',))
monkeypatch.setattr(BeninSmallHolderCashews, 'tile_height', 2)
monkeypatch.setattr(BeninSmallHolderCashews, 'tile_width', 2)
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return BeninSmallHolderCashews(root, transforms=transforms, download=True)

@@ -54,7 +54,7 @@ def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-BeninSmallHolderCashews(str(tmp_path))
+BeninSmallHolderCashews(tmp_path)

def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
18 changes: 9 additions & 9 deletions tests/datasets/test_bigearthnet.py
@@ -16,7 +16,7 @@
from torchgeo.datasets import BigEarthNet, DatasetNotFoundError


-def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)


@@ -63,7 +63,7 @@ def dataset(
monkeypatch.setattr(BigEarthNet, 'metadata', metadata)
monkeypatch.setattr(BigEarthNet, 'splits_metadata', splits_metadata)
bands, num_classes, split = request.param
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return BigEarthNet(
root, split, bands, num_classes, transforms, download=True, checksum=True
@@ -95,7 +95,7 @@ def test_len(self, dataset: BigEarthNet) -> None:

def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
BigEarthNet(
-root=str(tmp_path),
+root=tmp_path,
bands=dataset.bands,
split=dataset.split,
num_classes=dataset.num_classes,
@@ -112,21 +112,21 @@ def test_already_downloaded_not_extracted(
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata['s2']['directory'])
)
-download_url(dataset.metadata['s1']['url'], root=str(tmp_path))
-download_url(dataset.metadata['s2']['url'], root=str(tmp_path))
+download_url(dataset.metadata['s1']['url'], root=tmp_path)
+download_url(dataset.metadata['s2']['url'], root=tmp_path)
elif dataset.bands == 's1':
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata['s1']['directory'])
)
-download_url(dataset.metadata['s1']['url'], root=str(tmp_path))
+download_url(dataset.metadata['s1']['url'], root=tmp_path)
else:
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata['s2']['directory'])
)
-download_url(dataset.metadata['s2']['url'], root=str(tmp_path))
+download_url(dataset.metadata['s2']['url'], root=tmp_path)

BigEarthNet(
-root=str(tmp_path),
+root=tmp_path,
bands=dataset.bands,
split=dataset.split,
num_classes=dataset.num_classes,
@@ -135,7 +135,7 @@ def test_already_downloaded_not_extracted(

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-BigEarthNet(str(tmp_path))
+BigEarthNet(tmp_path)

def test_plot(self, dataset: BigEarthNet) -> None:
x = dataset[0].copy()
2 changes: 1 addition & 1 deletion tests/datasets/test_biomassters.py
@@ -37,7 +37,7 @@ def test_invalid_bands(self, dataset: BioMassters) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-BioMassters(str(tmp_path))
+BioMassters(tmp_path)

def test_plot(self, dataset: BioMassters) -> None:
dataset.plot(dataset[0], suptitle='Test')
6 changes: 3 additions & 3 deletions tests/datasets/test_cbf.py
@@ -22,7 +22,7 @@
)


-def download_url(url: str, root: str, *args: str) -> None:
+def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)


@@ -41,7 +41,7 @@ def dataset(
url = os.path.join('tests', 'data', 'cbf') + os.sep
monkeypatch.setattr(CanadianBuildingFootprints, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return CanadianBuildingFootprints(
root, res=0.1, transforms=transforms, download=True, checksum=True
@@ -80,7 +80,7 @@ def test_plot_prediction(self, dataset: CanadianBuildingFootprints) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-CanadianBuildingFootprints(str(tmp_path))
+CanadianBuildingFootprints(tmp_path)

def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
query = BoundingBox(2, 2, 2, 2, 2, 2)
10 changes: 5 additions & 5 deletions tests/datasets/test_cdl.py
@@ -24,7 +24,7 @@
)


-def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)


@@ -41,7 +41,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL:
url = os.path.join('tests', 'data', 'cdl', '{}_30m_cdls.zip')
monkeypatch.setattr(CDL, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return CDL(
root,
@@ -87,7 +87,7 @@ def test_already_extracted(self, dataset: CDL) -> None:

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'cdl', '*_30m_cdls.zip')
-root = str(tmp_path)
+root = tmp_path
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
CDL(root, years=[2023, 2022])
@@ -97,7 +97,7 @@ def test_invalid_year(self, tmp_path: Path) -> None:
AssertionError,
match='CDL data product only exists for the following years:',
):
-CDL(str(tmp_path), years=[1996])
+CDL(tmp_path, years=[1996])

def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
@@ -121,7 +121,7 @@ def test_plot_prediction(self, dataset: CDL) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-CDL(str(tmp_path))
+CDL(tmp_path)

def test_invalid_query(self, dataset: CDL) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
8 changes: 5 additions & 3 deletions tests/datasets/test_chabud.py
@@ -18,7 +18,9 @@
pytest.importorskip('h5py', minversion='3.6')


-def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
+def download_url(
+    url: str, root: str | Path, filename: str, *args: str, **kwargs: str
+) -> None:
shutil.copy(url, os.path.join(root, filename))
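This variant joins `root` and `filename` before copying; it still needs no body change because `os.path.join` accepts `os.PathLike` arguments and returns a plain `str`. For example:

```python
import os
from pathlib import Path

# os.path.join calls os.fspath() on Path arguments and returns str.
dest = os.path.join(Path('/tmp'), 'chabud.hdf5')
assert dest == os.path.join('/tmp', 'chabud.hdf5')
```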


@@ -34,7 +36,7 @@ def dataset(
monkeypatch.setattr(ChaBuD, 'url', url)
monkeypatch.setattr(ChaBuD, 'md5', md5)
bands, split = request.param
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return ChaBuD(
root=root,
@@ -70,7 +72,7 @@ def test_already_downloaded(self, dataset: ChaBuD) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-ChaBuD(str(tmp_path))
+ChaBuD(tmp_path)

def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
14 changes: 7 additions & 7 deletions tests/datasets/test_chesapeake.py
@@ -26,7 +26,7 @@
pytest.importorskip('zipfile_deflate64')


-def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)


@@ -41,7 +41,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13:
)
monkeypatch.setattr(Chesapeake13, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return Chesapeake13(root, transforms=transforms, download=True, checksum=True)

@@ -69,13 +69,13 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
url = os.path.join(
'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip'
)
-root = str(tmp_path)
+root = tmp_path
shutil.copy(url, root)
Chesapeake13(root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-Chesapeake13(str(tmp_path), checksum=True)
+Chesapeake13(tmp_path, checksum=True)

def test_plot(self, dataset: Chesapeake13) -> None:
query = dataset.bounds
@@ -148,7 +148,7 @@ def dataset(
'_files',
['de_1m_2013_extended-debuffered-test_tiles', 'spatial_index.geojson'],
)
-root = str(tmp_path)
+root = tmp_path
transforms = nn.Identity()
return ChesapeakeCVPR(
root,
@@ -180,7 +180,7 @@ def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None:
ChesapeakeCVPR(root=dataset.root, download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
-root = str(tmp_path)
+root = tmp_path
shutil.copy(
os.path.join(
'tests', 'data', 'chesapeake', 'cvpr', 'cvpr_chesapeake_landcover.zip'
@@ -201,7 +201,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-ChesapeakeCVPR(str(tmp_path), checksum=True)
+ChesapeakeCVPR(tmp_path, checksum=True)

def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
4 changes: 2 additions & 2 deletions tests/datasets/test_cloud_cover.py
@@ -30,7 +30,7 @@ def dataset(
) -> CloudCoverDetection:
url = os.path.join('tests', 'data', 'ref_cloud_cover_detection_challenge_v1')
monkeypatch.setattr(CloudCoverDetection, 'url', url)
-root = str(tmp_path)
+root = tmp_path
split = request.param
transforms = nn.Identity()
return CloudCoverDetection(
@@ -55,7 +55,7 @@ def test_already_downloaded(self, dataset: CloudCoverDetection) -> None:

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
-CloudCoverDetection(str(tmp_path))
+CloudCoverDetection(tmp_path)

def test_plot(self, dataset: CloudCoverDetection) -> None:
sample = dataset[0]