Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for pathlib #2154

Closed
Wants to merge 14 commits (author, source branch, and target branch were not captured)
6 changes: 3 additions & 3 deletions torchgeo/datasets/advance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_and_extract_archive, lazy_import
from .utils import Path, download_and_extract_archive, lazy_import


class ADVANCE(NonGeoDataset):
Expand Down Expand Up @@ -88,7 +88,7 @@ class ADVANCE(NonGeoDataset):

def __init__(
self,
root: str = 'data',
root: Path = 'data',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
Expand Down Expand Up @@ -151,7 +151,7 @@ def __len__(self) -> int:
"""
return len(self.files)

def _load_files(self, root: str) -> list[dict[str, str]]:
def _load_files(self, root: Path) -> list[dict[str, str]]:
"""Return the paths of the files in the dataset.

Args:
Expand Down
7 changes: 4 additions & 3 deletions torchgeo/datasets/agb_live_woody_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import json
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand All @@ -14,7 +15,7 @@

from .errors import DatasetNotFoundError
from .geo import RasterDataset
from .utils import download_url
from .utils import Path, download_url


class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
Expand Down Expand Up @@ -57,7 +58,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):

def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
Expand Down Expand Up @@ -105,7 +106,7 @@ def _verify(self) -> None:

def _download(self) -> None:
"""Download the dataset."""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
download_url(self.url, self.paths, self.base_filename)

with open(os.path.join(self.paths, self.base_filename)) as f:
Expand Down
7 changes: 4 additions & 3 deletions torchgeo/datasets/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""AgriFieldNet India Challenge dataset."""

import os
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
Expand All @@ -16,7 +17,7 @@

from .errors import RGBBandsMissingError
from .geo import RasterDataset
from .utils import BoundingBox
from .utils import BoundingBox, Path


class AgriFieldNet(RasterDataset):
Expand Down Expand Up @@ -115,7 +116,7 @@ class AgriFieldNet(RasterDataset):

def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
classes: list[int] = list(cmap.keys()),
bands: Sequence[str] = all_bands,
Expand Down Expand Up @@ -167,7 +168,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Returns:
data, label, and field ids at that index
"""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)

hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[str], [hit.object for hit in hits])
Expand Down
3 changes: 2 additions & 1 deletion torchgeo/datasets/astergdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .errors import DatasetNotFoundError
from .geo import RasterDataset
from .utils import Path


class AsterGDEM(RasterDataset):
Expand Down Expand Up @@ -47,7 +48,7 @@ class AsterGDEM(RasterDataset):

def __init__(
self,
paths: str | list[str] = 'data',
paths: Path | list[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import which
from .utils import Path, which


class BeninSmallHolderCashews(NonGeoDataset):
Expand Down Expand Up @@ -163,7 +163,7 @@ class BeninSmallHolderCashews(NonGeoDataset):

def __init__(
self,
root: str = 'data',
root: Path = 'data',
chip_size: int = 256,
stride: int = 128,
bands: Sequence[str] = all_bands,
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive, sort_sentinel2_bands
from .utils import Path, download_url, extract_archive, sort_sentinel2_bands


class BigEarthNet(NonGeoDataset):
Expand Down Expand Up @@ -267,7 +267,7 @@ class BigEarthNet(NonGeoDataset):

def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
bands: str = 'all',
num_classes: int = 19,
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import percentile_normalization
from .utils import Path, percentile_normalization


class BioMassters(NonGeoDataset):
Expand Down Expand Up @@ -57,7 +57,7 @@ class BioMassters(NonGeoDataset):

def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
sensors: Sequence[str] = ['S1', 'S2'],
as_time_series: bool = False,
Expand Down
9 changes: 5 additions & 4 deletions torchgeo/datasets/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Canadian Building Footprints dataset."""

import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand All @@ -13,7 +14,7 @@

from .errors import DatasetNotFoundError
from .geo import VectorDataset
from .utils import check_integrity, download_and_extract_archive
from .utils import Path, check_integrity, download_and_extract_archive


class CanadianBuildingFootprints(VectorDataset):
Expand Down Expand Up @@ -62,7 +63,7 @@ class CanadianBuildingFootprints(VectorDataset):

def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float = 0.00001,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
Expand Down Expand Up @@ -104,7 +105,7 @@ def _check_integrity(self) -> bool:
Returns:
True if dataset files are found and/or MD5s match, else False
"""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
filepath = os.path.join(self.paths, prov_terr + '.zip')
if not check_integrity(filepath, md5 if self.checksum else None):
Expand All @@ -116,7 +117,7 @@ def _download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
download_and_extract_archive(
self.url + prov_terr + '.zip',
Expand Down
11 changes: 6 additions & 5 deletions torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""CDL dataset."""

import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

Expand All @@ -14,7 +15,7 @@

from .errors import DatasetNotFoundError
from .geo import RasterDataset
from .utils import BoundingBox, download_url, extract_archive
from .utils import BoundingBox, Path, download_url, extract_archive


class CDL(RasterDataset):
Expand Down Expand Up @@ -207,7 +208,7 @@ class CDL(RasterDataset):

def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
years: list[int] = [2023],
Expand Down Expand Up @@ -294,7 +295,7 @@ def _verify(self) -> None:

# Check if the zip files have already been downloaded
exists = []
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
for year in self.years:
pathname = os.path.join(
self.paths, self.zipfile_glob.replace('*', str(year))
Expand Down Expand Up @@ -327,11 +328,11 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
for year in self.years:
zipfile_name = self.zipfile_glob.replace('*', str(year))
pathname = os.path.join(self.paths, zipfile_name)
extract_archive(pathname, self.paths)
extract_archive(pathname, str(self.paths))

def plot(
self,
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/chabud.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, lazy_import, percentile_normalization
from .utils import Path, download_url, lazy_import, percentile_normalization


class ChaBuD(NonGeoDataset):
Expand Down Expand Up @@ -75,7 +75,7 @@ class ChaBuD(NonGeoDataset):

def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
bands: list[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
Expand Down
11 changes: 6 additions & 5 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import abc
import os
import pathlib
import sys
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
Expand All @@ -26,7 +27,7 @@
from .errors import DatasetNotFoundError
from .geo import GeoDataset, RasterDataset
from .nlcd import NLCD
from .utils import BoundingBox, download_url, extract_archive
from .utils import BoundingBox, Path, download_url, extract_archive


class Chesapeake(RasterDataset, abc.ABC):
Expand Down Expand Up @@ -91,7 +92,7 @@ def url(self) -> str:

def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
Expand Down Expand Up @@ -145,7 +146,7 @@ def _verify(self) -> None:
return

# Check if the zip file has already been downloaded
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
if os.path.exists(os.path.join(self.paths, self.zipfile)):
self._extract()
return
Expand All @@ -164,7 +165,7 @@ def _download(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
extract_archive(os.path.join(self.paths, self.zipfile))

def plot(
Expand Down Expand Up @@ -510,7 +511,7 @@ class ChesapeakeCVPR(GeoDataset):

def __init__(
self,
root: str = 'data',
root: Path = 'data',
splits: Sequence[str] = ['de-train'],
layers: Sequence[str] = ['naip-new', 'lc'],
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/cloud_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import which
from .utils import Path, which


class CloudCoverDetection(NonGeoDataset):
Expand Down Expand Up @@ -61,7 +61,7 @@ class CloudCoverDetection(NonGeoDataset):

def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
bands: Sequence[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
Expand Down
9 changes: 5 additions & 4 deletions torchgeo/datasets/cms_mangrove_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""CMS Global Mangrove Canopy dataset."""

import os
import pathlib
from collections.abc import Callable
from typing import Any

Expand All @@ -13,7 +14,7 @@

from .errors import DatasetNotFoundError
from .geo import RasterDataset
from .utils import check_integrity, extract_archive
from .utils import Path, check_integrity, extract_archive


class CMSGlobalMangroveCanopy(RasterDataset):
Expand Down Expand Up @@ -169,7 +170,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):

def __init__(
self,
paths: str | list[str] = 'data',
paths: Path | list[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
measurement: str = 'agb',
Expand Down Expand Up @@ -228,7 +229,7 @@ def _verify(self) -> None:
return

# Check if the zip file has already been downloaded
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
pathname = os.path.join(self.paths, self.zipfile)
if os.path.exists(pathname):
if self.checksum and not check_integrity(pathname, self.md5):
Expand All @@ -240,7 +241,7 @@ def _verify(self) -> None:

def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
pathname = os.path.join(self.paths, self.zipfile)
extract_archive(pathname)

Expand Down
Loading
Loading