Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Datasets: support os.PathLike #2273

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
paths: str | Path | Iterable[str | Path] | None = None,
paths: str | os.PathLike[str] | Iterable[str | os.PathLike[str]] | None = None,
) -> None:
super().__init__()
self.index.insert(0, tuple(bounds))
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/agb_live_woody_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import json
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand Down Expand Up @@ -106,7 +105,7 @@ def _verify(self) -> None:

def _download(self) -> None:
"""Download the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
download_url(self.url, self.paths, self.base_filename)

with open(os.path.join(self.paths, self.base_filename)) as f:
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datasets/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""AgriFieldNet India Challenge dataset."""

import os
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, ClassVar, cast
Expand Down Expand Up @@ -181,10 +180,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Returns:
data, label, and field ids at that index
"""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)

hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = cast(list[str], [hit.object for hit in hits])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As of #2270, we can reasonably expect `hit.object` to be a `str` by default.


if not filepaths:
raise IndexError(
Expand Down Expand Up @@ -246,7 +245,7 @@ def _verify(self) -> None:

def _download(self) -> None:
"""Download the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
os.makedirs(self.paths, exist_ok=True)
azcopy = which('azcopy')
azcopy('sync', f'{self.url}', self.paths, '--recursive=true')
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""Canadian Building Footprints dataset."""

import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand Down Expand Up @@ -105,7 +104,7 @@ def _check_integrity(self) -> bool:
Returns:
True if dataset files are found and/or MD5s match, else False
"""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
filepath = os.path.join(self.paths, prov_terr + '.zip')
if not check_integrity(filepath, md5 if self.checksum else None):
Expand All @@ -117,7 +116,7 @@ def _download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
download_and_extract_archive(
self.url + prov_terr + '.zip',
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""CDL dataset."""

import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, ClassVar

Expand Down Expand Up @@ -295,7 +294,7 @@ def _verify(self) -> None:

# Check if the zip files have already been downloaded
exists = []
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for year in self.years:
pathname = os.path.join(
self.paths, self.zipfile_glob.replace('*', str(year))
Expand Down Expand Up @@ -328,7 +327,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for year in self.years:
zipfile_name = self.zipfile_glob.replace('*', str(year))
pathname = os.path.join(self.paths, zipfile_name)
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
import sys
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
Expand Down Expand Up @@ -173,7 +172,7 @@ def _verify(self) -> None:
return

# Check if the zip file has already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
if glob.glob(os.path.join(self.paths, '**', '*.zip'), recursive=True):
self._extract()
return
Expand All @@ -195,7 +194,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
for file in glob.iglob(os.path.join(self.paths, '**', '*.zip'), recursive=True):
extract_archive(file)

Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/cms_mangrove_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""CMS Global Mangrove Canopy dataset."""

import os
import pathlib
from collections.abc import Callable
from typing import Any

Expand Down Expand Up @@ -229,7 +228,7 @@ def _verify(self) -> None:
return

# Check if the zip file has already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, self.zipfile)
if os.path.exists(pathname):
if self.checksum and not check_integrity(pathname, self.md5):
Expand All @@ -241,7 +240,7 @@ def _verify(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, self.zipfile)
extract_archive(pathname)

Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/esri2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand Down Expand Up @@ -113,7 +112,7 @@ def _verify(self) -> None:
return

# Check if the zip files have already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, self.zipfile)
if glob.glob(pathname):
self._extract()
Expand All @@ -133,7 +132,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
extract_archive(os.path.join(self.paths, self.zipfile))

def plot(
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/datasets/eudem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, ClassVar

Expand Down Expand Up @@ -117,7 +116,7 @@ def _verify(self) -> None:
return

# Check if the zip files have already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, self.zipfile_glob)
if glob.glob(pathname):
for zipfile in glob.iglob(pathname):
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datasets/eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import csv
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand Down Expand Up @@ -140,7 +139,7 @@ def _check_integrity(self) -> bool:
if self.files and not self.checksum:
return True

assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)

filepath = os.path.join(self.paths, self.hcat_fname)
if not check_integrity(filepath, self.hcat_md5 if self.checksum else None):
Expand All @@ -157,7 +156,7 @@ def _download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
download_url(
self.base_url + self.hcat_fname,
self.paths,
Expand All @@ -179,7 +178,7 @@ def _load_class_map(self, classes: list[str] | None) -> None:
(defaults to all classes)
"""
if not classes:
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
classes = []
filepath = os.path.join(self.paths, self.hcat_fname)
with open(filepath) as f:
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import functools
import glob
import os
import pathlib
import re
import sys
import warnings
Expand Down Expand Up @@ -300,7 +299,7 @@ def files(self) -> list[str]:
.. versionadded:: 0.5
"""
# Make iterable
if isinstance(self.paths, str | pathlib.Path):
if isinstance(self.paths, str | os.PathLike):
paths: Iterable[Path] = [self.paths]
else:
paths = self.paths
Expand Down Expand Up @@ -521,7 +520,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = cast(list[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
Expand Down Expand Up @@ -564,7 +563,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:

def _merge_files(
self,
filepaths: Sequence[Path],
filepaths: Sequence[str],
query: BoundingBox,
band_indexes: Sequence[int] | None = None,
) -> Tensor:
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/globbiomass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, ClassVar, cast

Expand Down Expand Up @@ -193,7 +192,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = cast(list[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
Expand Down Expand Up @@ -221,7 +220,7 @@ def _verify(self) -> None:
return

# Check if the zip files have already been downloaded
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, f'*_{self.measurement}.zip')
if glob.glob(pathname):
for zipfile in glob.iglob(pathname):
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datasets/l7irish.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, ClassVar, cast
Expand Down Expand Up @@ -94,7 +93,7 @@ def __init__(
filename_regex = re.compile(L7IrishImage.filename_regex, re.VERBOSE)
index = Index(interleaved=False, properties=Property(dimension=3))
for hit in self.index.intersection(self.index.bounds, objects=True):
dirname = os.path.dirname(cast(Path, hit.object))
dirname = os.path.dirname(cast(str, hit.object))
image = glob.glob(os.path.join(dirname, L7IrishImage.filename_glob))[0]
minx, maxx, miny, maxy, mint, maxt = hit.bounds
if match := re.match(filename_regex, os.path.basename(image)):
Expand Down Expand Up @@ -229,7 +228,7 @@ def _merge_dataset_indices(self) -> None:
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if not isinstance(self.paths, str | pathlib.Path):
if not isinstance(self.paths, str | os.PathLike):
return

for classname in [L7IrishImage, L7IrishMask]:
Expand Down Expand Up @@ -262,7 +261,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, '*.tar.gz')
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/l8biome.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable, Sequence
from typing import Any, ClassVar

Expand Down Expand Up @@ -174,7 +173,7 @@ def __init__(
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if not isinstance(self.paths, str | pathlib.Path):
if not isinstance(self.paths, str | os.PathLike):
return

for classname in [L8BiomeImage, L8BiomeMask]:
Expand Down Expand Up @@ -207,7 +206,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, '*.tar.gz')
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
img_filepaths = cast(list[Path], [hit.object for hit in hits])
img_filepaths = cast(list[str], [hit.object for hit in hits])
mask_filepaths = [
str(path).replace('images', 'masks') for path in img_filepaths
]
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/nlcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, ClassVar

Expand Down Expand Up @@ -192,7 +191,7 @@ def _verify(self) -> None:
exists = []
for year in self.years:
zipfile_year = self.zipfile_glob.replace('*', str(year), 1)
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, '**', zipfile_year)
if glob.glob(pathname, recursive=True):
exists.append(True)
Expand Down Expand Up @@ -224,7 +223,7 @@ def _extract(self) -> None:
"""Extract the dataset."""
for year in self.years:
zipfile_name = self.zipfile_glob.replace('*', str(year), 1)
assert isinstance(self.paths, str | pathlib.Path)
assert isinstance(self.paths, str | os.PathLike)
pathname = os.path.join(self.paths, '**', zipfile_name)
extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)

Expand Down
Loading