From c73cfe6f350a8be0a65e080563160b88adc99d5e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 3 Feb 2025 12:58:46 +0100 Subject: [PATCH 1/4] Drop support for Python 3.10 --- .github/workflows/tests.yaml | 4 +- pyproject.toml | 65 +++++++++++++------------- requirements/min-reqs.old | 30 ++++++------ tests/datasets/test_advance.py | 2 +- tests/datasets/test_cabuar.py | 2 +- tests/datasets/test_chabud.py | 2 +- tests/datasets/test_cropharvest.py | 2 +- tests/datasets/test_digital_typhoon.py | 2 +- tests/datasets/test_landcoverai.py | 2 +- tests/datasets/test_mmearth.py | 2 +- tests/datasets/test_quakeset.py | 2 +- tests/datasets/test_skippd.py | 2 +- tests/datasets/test_so2sat.py | 6 +-- tests/datasets/test_vhr10.py | 2 +- tests/datasets/test_zuericrop.py | 2 +- 15 files changed, 62 insertions(+), 65 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 96fdb226f3c..900e9ea7a96 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -18,7 +18,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.10', '3.11', '3.12', '3.13'] + python-version: ['3.11', '3.12', '3.13'] steps: - name: Clone repo uses: actions/checkout@v4.2.2 @@ -59,7 +59,7 @@ jobs: id: setup-python uses: actions/setup-python@v5.3.0 with: - python-version: '3.10' + python-version: '3.11' - name: Cache dependencies uses: actions/cache@v4.2.0 id: cache diff --git a/pyproject.toml b/pyproject.toml index df0c51cd2b7..55f2134f07f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta" name = "torchgeo" description = "TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" license = {file = "LICENSE"} authors = [ {name = "Adam J. Stewart", email = "ajstewart426@gmail.com"}, @@ -29,7 +29,6 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", @@ -39,8 +38,8 @@ classifiers = [ dependencies = [ # einops 0.3+ required for einops.repeat "einops>=0.3", - # fiona 1.8.21+ required for Python 3.10 wheels - "fiona>=1.8.21", + # fiona 1.8.22+ required for Python 3.11 wheels + "fiona>=1.8.22", # kornia 0.7.4+ required for AugmentationSequential support for unknown keys "kornia>=0.7.4", # lightly 1.4.5+ required for LARS optimizer @@ -51,53 +50,53 @@ dependencies = [ # lightning 2.3 contains known bugs related to YAML parsing # https://github.com/Lightning-AI/pytorch-lightning/issues/19977 "lightning[pytorch-extra]>=2,!=2.3.*,!=2.5.0", - # matplotlib 3.5+ required for Python 3.10 wheels - "matplotlib>=3.5", - # numpy 1.21.2+ required by Python 3.10 wheels - "numpy>=1.21.2", - # pandas 1.3.3+ required for Python 3.10 wheels - "pandas>=1.3.3", - # pillow 8.4+ required for Python 3.10 wheels - "pillow>=8.4", - # pyproj 3.3+ required for Python 3.10 wheels - "pyproj>=3.3", - # rasterio 1.3+ required for Python 3.10 wheels + # matplotlib 3.6+ required for Python 3.11 wheels + "matplotlib>=3.6", + # numpy 1.23.2+ required by Python 3.11 wheels + "numpy>=1.23.2", + # pandas 1.5+ required for Python 3.11 wheels + "pandas>=1.5", + # pillow 9.2+ required for Python 3.11 wheels + "pillow>=9.2", + # pyproj 3.4+ required for Python 3.11 wheels + "pyproj>=3.4", + # rasterio 1.3.3+ required for Python 3.11 wheels # rasterio 1.4.0-1.4.2 lack support for merging WarpedVRT objects # https://github.com/rasterio/rasterio/issues/3196 - "rasterio>=1.3,!=1.4.0,!=1.4.1,!=1.4.2", - # rtree 1+ required for Python 3.10 wheels - "rtree>=1", + "rasterio>=1.3.3,!=1.4.0,!=1.4.1,!=1.4.2", + # rtree 1.0.1+ required for Python 3.11 wheels + "rtree>=1.0.1", # segmentation-models-pytorch 0.2+ required for smp.losses module "segmentation-models-pytorch>=0.2", - # shapely 1.8+ required for Python 3.10 wheels - "shapely>=1.8", + # shapely 1.8.5+ required for Python 3.11 wheels + "shapely>=1.8.5", # timm 0.4.12 required by segmentation-models-pytorch "timm>=0.4.12", - # torch 1.13+ required by torchvision - "torch>=1.13", + # torch 2+ required for Python 3.11 wheels + "torch>=2", # torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics "torchmetrics>=0.10", - # torchvision 0.14+ required for torchvision.models.swin_v2_b - "torchvision>=0.14", + # torchvision 0.15.1+ required for Python 3.11 wheels + "torchvision>=0.15.1", ] dynamic = ["version"] [project.optional-dependencies] datasets = [ - # h5py 3.6+ required for Python 3.10 wheels - "h5py>=3.6", + # h5py 3.8+ required for Python 3.11 wheels + "h5py>=3.8", # laspy 2+ required for laspy.read "laspy>=2", - # opencv-python 4.5.4+ required for Python 3.10 wheels - "opencv-python>=4.5.4", + # opencv-python 4.5.5+ required for Python 3.11 wheels + "opencv-python>=4.5.5", # pandas 2+ required for parquet extra "pandas[parquet]>=2", - # pycocotools 2.0.7+ required for wheels + # pycocotools 2.0.7+ required for Python 3.11 wheels "pycocotools>=2.0.7", - # scikit-image 0.19+ required for Python 3.10 wheels - "scikit-image>=0.19", - # scipy 1.7.2+ required for Python 3.10 wheels - "scipy>=1.7.2", + # scikit-image 0.20+ required for Python 3.11 wheels + "scikit-image>=0.20", + # scipy 1.9.2+ required for Python 3.11 wheels + "scipy>=1.9.2", ] docs = [ # ipywidgets 7+ required by nbsphinx diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index e1058142ac6..755abf6f685 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -3,32 +3,32 @@ setuptools==61.0.0 # install einops==0.3.0 -fiona==1.8.21 +fiona==1.8.22 kornia==0.7.4 lightly==1.4.5 lightning[pytorch-extra]==2.0.0 -matplotlib==3.5.0 -numpy==1.21.2 -pandas==1.3.3 -pillow==8.4.0 -pyproj==3.3.0 -rasterio==1.3.0.post1 -rtree==1.0.0 +matplotlib==3.6.0 +numpy==1.23.2 +pandas==1.5.0 +pillow==9.2.0 +pyproj==3.4.0 +rasterio==1.3.3 +rtree==1.0.1 segmentation-models-pytorch==0.2.0 -shapely==1.8.0 +shapely==1.8.5 timm==0.4.12 -torch==1.13.0 +torch==2.0.0 torchmetrics==0.10.0 -torchvision==0.14.0 +torchvision==0.15.1 # datasets -h5py==3.6.0 +h5py==3.8.0 laspy==2.0.0 -opencv-python==4.5.4.58 +opencv-python==4.5.5.64 pycocotools==2.0.7 pyarrow==15.0.0 # Remove when we upgrade min version of pandas to `pandas[parquet]>=2` -scikit-image==0.19.0 -scipy==1.7.2 +scikit-image==0.20.0 +scipy==1.9.2 # tests pytest==7.3.0 diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index 12d20c0fd76..a2348f793a7 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -12,7 +12,7 @@ from torchgeo.datasets import ADVANCE, DatasetNotFoundError -pytest.importorskip('scipy', minversion='1.7.2') +pytest.importorskip('scipy', minversion='1.9.2') class TestADVANCE: diff --git a/tests/datasets/test_cabuar.py b/tests/datasets/test_cabuar.py index 967f43ee4d3..89eb3a8b33f 100644 --- a/tests/datasets/test_cabuar.py +++ b/tests/datasets/test_cabuar.py @@ -14,7 +14,7 @@ from torchgeo.datasets import CaBuAr, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') class TestCaBuAr: diff --git a/tests/datasets/test_chabud.py b/tests/datasets/test_chabud.py index fed0aed5087..cdd76709414 100644 --- a/tests/datasets/test_chabud.py +++ b/tests/datasets/test_chabud.py @@ -13,7 +13,7 @@ from torchgeo.datasets import ChaBuD, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') class TestChaBuD: diff --git a/tests/datasets/test_cropharvest.py b/tests/datasets/test_cropharvest.py index 3d77ac2b5fe..a8e4100d4a6 100644 --- a/tests/datasets/test_cropharvest.py +++ b/tests/datasets/test_cropharvest.py @@ -13,7 +13,7 @@ from torchgeo.datasets import CropHarvest, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') class TestCropHarvest: diff --git a/tests/datasets/test_digital_typhoon.py b/tests/datasets/test_digital_typhoon.py index c3df283ec35..b105145c021 100644 --- a/tests/datasets/test_digital_typhoon.py +++ b/tests/datasets/test_digital_typhoon.py @@ -14,7 +14,7 @@ from torchgeo.datasets import DatasetNotFoundError, DigitalTyphoon -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') class TestDigitalTyphoon: diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index cda68604599..b64b1d5e090 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -72,7 +72,7 @@ def test_plot(self, dataset: LandCoverAIGeo) -> None: class TestLandCoverAI: - pytest.importorskip('cv2', minversion='4.5.4') + pytest.importorskip('cv2', minversion='4.5.5') @pytest.fixture( params=product([LandCoverAI100, LandCoverAI], ['train', 'val', 'test']) diff --git a/tests/datasets/test_mmearth.py b/tests/datasets/test_mmearth.py index c25c2a1dece..12117080e0e 100644 --- a/tests/datasets/test_mmearth.py +++ b/tests/datasets/test_mmearth.py @@ -12,7 +12,7 @@ from torchgeo.datasets import DatasetNotFoundError, MMEarth -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') data_dir_dict = { 'MMEarth': os.path.join('tests', 'data', 'mmearth', 'data_1M_v001'), diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index fbb6ea29234..ade1b124866 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -13,7 +13,7 @@ from torchgeo.datasets import DatasetNotFoundError, QuakeSet -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') class TestQuakeSet: diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py index 68f4e889df8..dc4ca4fedfc 100644 --- a/tests/datasets/test_skippd.py +++ b/tests/datasets/test_skippd.py @@ -15,7 +15,7 @@ from torchgeo.datasets import SKIPPD, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') class TestSKIPPD: diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index bc88662c16a..11b608b3ec4 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -13,7 +13,7 @@ from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, So2Sat -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') class TestSo2Sat: @@ -43,9 +43,7 @@ def test_len(self, dataset: So2Sat) -> None: assert len(dataset) == 2 def test_out_of_bounds(self, dataset: So2Sat) -> None: - # h5py at version 2.10.0 raises a ValueError instead of an IndexError so we - # check for both here - with pytest.raises((IndexError, ValueError)): + with pytest.raises(IndexError): dataset[2] def test_invalid_split(self) -> None: diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index aa0920d69e7..2e342d6846e 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -72,7 +72,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None: VHR10(tmp_path) def test_plot(self, dataset: VHR10) -> None: - pytest.importorskip('skimage', minversion='0.19') + pytest.importorskip('skimage', minversion='0.20') x = dataset[1].copy() dataset.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py index 6d4cdc8844c..00b176c48d1 100644 --- a/tests/datasets/test_zuericrop.py +++ b/tests/datasets/test_zuericrop.py @@ -12,7 +12,7 @@ from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriCrop -pytest.importorskip('h5py', minversion='3.6') +pytest.importorskip('h5py', minversion='3.8') class TestZueriCrop: From f83597afc4b3308f901f46136d6db9464d1f35f4 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 3 Feb 2025 13:49:16 +0100 Subject: [PATCH 2/4] Increase coverage, simplify type hints --- experiments/ssl4eo/class_imbalance.py | 2 +- experiments/ssl4eo/compute_dataset_statistics.py | 2 +- experiments/ssl4eo/download_ssl4eo.py | 10 ++++------ experiments/ssl4eo/landsat/chip_landsat_benchmark.py | 2 +- torchgeo/datasets/bigearthnet.py | 2 +- torchgeo/datasets/cyclone.py | 6 +----- torchgeo/datasets/enviroatlas.py | 2 +- torchgeo/datasets/eurocrops.py | 4 +--- torchgeo/datasets/mapinwild.py | 2 +- torchgeo/datasets/mmearth.py | 6 +++--- torchgeo/datasets/oscd.py | 2 +- torchgeo/datasets/satlas.py | 7 +------ torchgeo/datasets/seco.py | 6 +----- torchgeo/datasets/utils.py | 2 +- torchgeo/models/rcf.py | 4 ++-- torchgeo/trainers/moco.py | 6 +----- 16 files changed, 22 insertions(+), 43 deletions(-) diff --git a/experiments/ssl4eo/class_imbalance.py b/experiments/ssl4eo/class_imbalance.py index faa280432b7..8f77399aba9 100755 --- a/experiments/ssl4eo/class_imbalance.py +++ b/experiments/ssl4eo/class_imbalance.py @@ -35,7 +35,7 @@ parser.add_argument('--num-workers', type=int, default=10, help='number of threads') args = parser.parse_args() - def class_counts(path: str) -> 'np.typing.NDArray[np.float64]': + def class_counts(path: str) -> np.typing.NDArray[np.float64]: """Calculate the number of values in each class. Args: diff --git a/experiments/ssl4eo/compute_dataset_statistics.py b/experiments/ssl4eo/compute_dataset_statistics.py index 05ff623f1e0..8321f23c375 100755 --- a/experiments/ssl4eo/compute_dataset_statistics.py +++ b/experiments/ssl4eo/compute_dataset_statistics.py @@ -19,7 +19,7 @@ parser.add_argument('--num-workers', type=int, default=10, help='number of threads') args = parser.parse_args() - def compute(path: str) -> tuple['np.typing.NDArray[np.float32]', int]: + def compute(path: str) -> tuple[np.typing.NDArray[np.float32], int]: """Compute the min, max, mean, and std dev of a single image. Args: diff --git a/experiments/ssl4eo/download_ssl4eo.py b/experiments/ssl4eo/download_ssl4eo.py index 283e6e39c82..d21c38ffcbc 100755 --- a/experiments/ssl4eo/download_ssl4eo.py +++ b/experiments/ssl4eo/download_ssl4eo.py @@ -131,8 +131,8 @@ def filter_collection( def center_crop( - img: 'np.typing.NDArray[np.float32]', out_size: int -) -> 'np.typing.NDArray[np.float32]': + img: np.typing.NDArray[np.float32], out_size: int +) -> np.typing.NDArray[np.float32]: image_height, image_width = img.shape[:2] crop_height = crop_width = out_size pad_height = max(crop_height - image_height, 0) @@ -253,9 +253,7 @@ def get_random_patches_match( def save_geotiff( - img: 'np.typing.NDArray[np.float32]', - coords: list[tuple[float, float]], - filename: str, + img: np.typing.NDArray[np.float32], coords: list[tuple[float, float]], filename: str ) -> None: height, width, channels = img.shape xres = (coords[1][0] - coords[0][0]) / width @@ -278,7 +276,7 @@ def save_geotiff( def save_patch( - raster: dict[int, 'np.typing.NDArray[np.float32]'], + raster: dict[int, np.typing.NDArray[np.float32]], coords: list[tuple[float, float]], metadata: dict[str, Any], bands: list[str], diff --git a/experiments/ssl4eo/landsat/chip_landsat_benchmark.py b/experiments/ssl4eo/landsat/chip_landsat_benchmark.py index b02015201be..67abca36ed8 100755 --- a/experiments/ssl4eo/landsat/chip_landsat_benchmark.py +++ b/experiments/ssl4eo/landsat/chip_landsat_benchmark.py @@ -17,7 +17,7 @@ def retrieve_mask_chip( img_src: DatasetReader, mask_src: DatasetReader -) -> 'np.typing.NDArray[np.uint8]': +) -> np.typing.NDArray[np.uint8]: """Retrieve the mask for a given landsat image. Args: diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 8900bc9c991..d1b40c5671c 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -509,7 +509,7 @@ def _extract(self, filepath: Path) -> None: extract_archive(filepath) def _onehot_labels_to_names( - self, label_mask: 'np.typing.NDArray[np.bool_]' + self, label_mask: np.typing.NDArray[np.bool_] ) -> list[str]: """Gets a list of class names given a label mask. diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 2a21832703a..76b1097f009 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -131,11 +131,7 @@ def _load_image(self, image_id: str) -> Tensor: filename = os.path.join(self.root, self.split, f'{image_id}.jpg') with Image.open(filename) as img: if img.height != self.size or img.width != self.size: - # Moved in PIL 9.1.0 - try: - resample = Image.Resampling.BILINEAR - except AttributeError: - resample = Image.BILINEAR # type: ignore[attr-defined] + resample = Image.Resampling.BILINEAR img = img.resize(size=(self.size, self.size), resample=resample) array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index b8af4aef70e..71acbe44103 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -124,7 +124,7 @@ class EnviroAtlas(GeoDataset): # used to convert the 10 high-res classes labeled as [0, 10, 20, 30, 40, 52, 70, 80, # 82, 91, 92] to sequential labels [0, ..., 10] - raw_enviroatlas_to_idx_map: 'np.typing.NDArray[np.uint8]' = np.array( + raw_enviroatlas_to_idx_map: np.typing.NDArray[np.uint8] = np.array( [ 0, 0, diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index 5f438143c87..c49a9714e9a 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -244,9 +244,7 @@ def plot( fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4)) - def apply_cmap( - arr: 'np.typing.NDArray[Any]', - ) -> 'np.typing.NDArray[np.float64]': + def apply_cmap(arr: np.typing.NDArray[Any]) -> np.typing.NDArray[np.float64]: # Color 0 as black, while applying default color map for the class indices. cmap = plt.get_cmap('viridis') im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map)) diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index cd294014318..b771e1d0321 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -314,7 +314,7 @@ def _merge_parts(self, modality: str) -> None: def _convert_to_color( self, arr_2d: Tensor, cmap: dict[int, tuple[int, int, int]] - ) -> 'np.typing.NDArray[np.uint8]': + ) -> np.typing.NDArray[np.uint8]: """Numeric labels to RGB-color encoding. Args: diff --git a/torchgeo/datasets/mmearth.py b/torchgeo/datasets/mmearth.py index b940537d8b4..8c1fa36638a 100644 --- a/torchgeo/datasets/mmearth.py +++ b/torchgeo/datasets/mmearth.py @@ -493,7 +493,7 @@ def _select_indices_for_modality( def _preprocess_modality( self, - data: 'np.typing.NDArray[Any]', + data: np.typing.NDArray[Any], modality: str, tile_info: dict[str, Any], bands: list[str], @@ -575,8 +575,8 @@ def _preprocess_modality( return tensor def _normalize_modality( - self, data: 'np.typing.NDArray[Any]', modality: str, bands: list[str] - ) -> 'np.typing.NDArray[np.float64]': + self, data: np.typing.NDArray[Any], modality: str, bands: list[str] + ) -> np.typing.NDArray[np.float64]: """Normalize a single modality. Args: diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 28f7714a7c6..d9dc78798cf 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -309,7 +309,7 @@ def plot( except ValueError as e: raise RGBBandsMissingError() from e - def get_masked(img: Tensor) -> 'np.typing.NDArray[np.uint8]': + def get_masked(img: Tensor) -> np.typing.NDArray[np.uint8]: rgb_img = img[rgb_indices].float().numpy() per02 = np.percentile(rgb_img, 2) per98 = np.percentile(rgb_img, 98) diff --git a/torchgeo/datasets/satlas.py b/torchgeo/datasets/satlas.py index 2f6e79c6c81..8c1cff85800 100644 --- a/torchgeo/datasets/satlas.py +++ b/torchgeo/datasets/satlas.py @@ -639,12 +639,6 @@ def _load_image( row: Web Mercator row. directories: Directories that may contain the image. """ - # Moved in PIL 9.1.0 - try: - resample = Image.Resampling.BILINEAR - except AttributeError: - resample = Image.BILINEAR # type: ignore[attr-defined] - # Find directories that match image product good_directories: list[str] = [] for directory in directories: @@ -659,6 +653,7 @@ def _load_image( sample[f'time_{image}'] = torch.tensor(time) # Load all bands + resample = Image.Resampling.BILINEAR channels = [] for band in self.bands[image]: path = os.path.join(self.root, image, directory, band, f'{col}_{row}.png') diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index c67fecb9c8e..a16c0fb0877 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -171,11 +171,7 @@ def _load_patch(self, root: Path, subdir: Path) -> Tensor: # slowdown here from converting to/from a PIL Image just to resize. # https://gist.github.com/calebrob6/748045ac8d844154067b2eefa47de92f pil_image = Image.fromarray(band_data) - # Moved in PIL 9.1.0 - try: - resample = Image.Resampling.BILINEAR - except AttributeError: - resample = Image.BILINEAR # type: ignore[attr-defined] + resample = Image.Resampling.BILINEAR band_data = np.array( pil_image.resize((264, 264), resample=resample) ) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 7ddbe08e597..4a5336fa6d3 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -547,7 +547,7 @@ def draw_semantic_segmentation_masks( image=image.byte(), masks=class_masks, alpha=alpha, colors=colors ) img = img.permute((1, 2, 0)).numpy().astype(np.uint8) - return cast('np.typing.NDArray[np.uint8]', img) + return cast(np.typing.NDArray[np.uint8], img) def rgb_to_mask( diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py index cc330a0192c..171ad653070 100644 --- a/torchgeo/models/rcf.py +++ b/torchgeo/models/rcf.py @@ -123,10 +123,10 @@ def __init__( def _normalize( self, - patches: 'np.typing.NDArray[np.float32]', + patches: np.typing.NDArray[np.float32], min_divisor: float = 1e-8, zca_bias: float = 0.001, - ) -> 'np.typing.NDArray[np.float32]': + ) -> np.typing.NDArray[np.float32]: """Does ZCA whitening on a set of input patches. Copied from https://github.com/Global-Policy-Lab/mosaiks-paper/blob/7efb09ed455505562d6bb04c2aaa242ef59f0a82/code/mosaiks/featurization.py#L120 diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 2e7e6907e37..4105211d893 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -23,6 +23,7 @@ from torch.optim.lr_scheduler import ( CosineAnnealingLR, LinearLR, + LRScheduler, MultiStepLR, SequentialLR, ) @@ -34,11 +35,6 @@ from . import utils from .base import BaseTask -try: - from torch.optim.lr_scheduler import LRScheduler -except ImportError: - from torch.optim.lr_scheduler import _LRScheduler as LRScheduler - def moco_augmentations( version: int, size: int, weights: Tensor From 460d2b7d010da534b6f7df290a543fcef3f934db Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 3 Feb 2025 18:05:46 +0100 Subject: [PATCH 3/4] Revert "Increase coverage, simplify type hints" This reverts commit f83597afc4b3308f901f46136d6db9464d1f35f4. --- experiments/ssl4eo/class_imbalance.py | 2 +- experiments/ssl4eo/compute_dataset_statistics.py | 2 +- experiments/ssl4eo/download_ssl4eo.py | 10 ++++++---- experiments/ssl4eo/landsat/chip_landsat_benchmark.py | 2 +- torchgeo/datasets/bigearthnet.py | 2 +- torchgeo/datasets/cyclone.py | 6 +++++- torchgeo/datasets/enviroatlas.py | 2 +- torchgeo/datasets/eurocrops.py | 4 +++- torchgeo/datasets/mapinwild.py | 2 +- torchgeo/datasets/mmearth.py | 6 +++--- torchgeo/datasets/oscd.py | 2 +- torchgeo/datasets/satlas.py | 7 ++++++- torchgeo/datasets/seco.py | 6 +++++- torchgeo/datasets/utils.py | 2 +- torchgeo/models/rcf.py | 4 ++-- torchgeo/trainers/moco.py | 6 +++++- 16 files changed, 43 insertions(+), 22 deletions(-) diff --git a/experiments/ssl4eo/class_imbalance.py b/experiments/ssl4eo/class_imbalance.py index 8f77399aba9..faa280432b7 100755 --- a/experiments/ssl4eo/class_imbalance.py +++ b/experiments/ssl4eo/class_imbalance.py @@ -35,7 +35,7 @@ parser.add_argument('--num-workers', type=int, default=10, help='number of threads') args = parser.parse_args() - def class_counts(path: str) -> np.typing.NDArray[np.float64]: + def class_counts(path: str) -> 'np.typing.NDArray[np.float64]': """Calculate the number of values in each class. Args: diff --git a/experiments/ssl4eo/compute_dataset_statistics.py b/experiments/ssl4eo/compute_dataset_statistics.py index 8321f23c375..05ff623f1e0 100755 --- a/experiments/ssl4eo/compute_dataset_statistics.py +++ b/experiments/ssl4eo/compute_dataset_statistics.py @@ -19,7 +19,7 @@ parser.add_argument('--num-workers', type=int, default=10, help='number of threads') args = parser.parse_args() - def compute(path: str) -> tuple[np.typing.NDArray[np.float32], int]: + def compute(path: str) -> tuple['np.typing.NDArray[np.float32]', int]: """Compute the min, max, mean, and std dev of a single image. Args: diff --git a/experiments/ssl4eo/download_ssl4eo.py b/experiments/ssl4eo/download_ssl4eo.py index d21c38ffcbc..283e6e39c82 100755 --- a/experiments/ssl4eo/download_ssl4eo.py +++ b/experiments/ssl4eo/download_ssl4eo.py @@ -131,8 +131,8 @@ def filter_collection( def center_crop( - img: np.typing.NDArray[np.float32], out_size: int -) -> np.typing.NDArray[np.float32]: + img: 'np.typing.NDArray[np.float32]', out_size: int +) -> 'np.typing.NDArray[np.float32]': image_height, image_width = img.shape[:2] crop_height = crop_width = out_size pad_height = max(crop_height - image_height, 0) @@ -253,7 +253,9 @@ def get_random_patches_match( def save_geotiff( - img: np.typing.NDArray[np.float32], coords: list[tuple[float, float]], filename: str + img: 'np.typing.NDArray[np.float32]', + coords: list[tuple[float, float]], + filename: str, ) -> None: height, width, channels = img.shape xres = (coords[1][0] - coords[0][0]) / width @@ -276,7 +278,7 @@ def save_geotiff( def save_patch( - raster: dict[int, np.typing.NDArray[np.float32]], + raster: dict[int, 'np.typing.NDArray[np.float32]'], coords: list[tuple[float, float]], metadata: dict[str, Any], bands: list[str], diff --git a/experiments/ssl4eo/landsat/chip_landsat_benchmark.py b/experiments/ssl4eo/landsat/chip_landsat_benchmark.py index 67abca36ed8..b02015201be 100755 --- a/experiments/ssl4eo/landsat/chip_landsat_benchmark.py +++ b/experiments/ssl4eo/landsat/chip_landsat_benchmark.py @@ -17,7 +17,7 @@ def retrieve_mask_chip( img_src: DatasetReader, mask_src: DatasetReader -) -> np.typing.NDArray[np.uint8]: +) -> 'np.typing.NDArray[np.uint8]': """Retrieve the mask for a given landsat image. Args: diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index d1b40c5671c..8900bc9c991 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -509,7 +509,7 @@ def _extract(self, filepath: Path) -> None: extract_archive(filepath) def _onehot_labels_to_names( - self, label_mask: np.typing.NDArray[np.bool_] + self, label_mask: 'np.typing.NDArray[np.bool_]' ) -> list[str]: """Gets a list of class names given a label mask. diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 76b1097f009..2a21832703a 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -131,7 +131,11 @@ def _load_image(self, image_id: str) -> Tensor: filename = os.path.join(self.root, self.split, f'{image_id}.jpg') with Image.open(filename) as img: if img.height != self.size or img.width != self.size: - resample = Image.Resampling.BILINEAR + # Moved in PIL 9.1.0 + try: + resample = Image.Resampling.BILINEAR + except AttributeError: + resample = Image.BILINEAR # type: ignore[attr-defined] img = img.resize(size=(self.size, self.size), resample=resample) array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 71acbe44103..b8af4aef70e 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -124,7 +124,7 @@ class EnviroAtlas(GeoDataset): # used to convert the 10 high-res classes labeled as [0, 10, 20, 30, 40, 52, 70, 80, # 82, 91, 92] to sequential labels [0, ..., 10] - raw_enviroatlas_to_idx_map: np.typing.NDArray[np.uint8] = np.array( + raw_enviroatlas_to_idx_map: 'np.typing.NDArray[np.uint8]' = np.array( [ 0, 0, diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index c49a9714e9a..5f438143c87 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -244,7 +244,9 @@ def plot( fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4)) - def apply_cmap(arr: np.typing.NDArray[Any]) -> np.typing.NDArray[np.float64]: + def apply_cmap( + arr: 'np.typing.NDArray[Any]', + ) -> 'np.typing.NDArray[np.float64]': # Color 0 as black, while applying default color map for the class indices. cmap = plt.get_cmap('viridis') im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map)) diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index b771e1d0321..cd294014318 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -314,7 +314,7 @@ def _merge_parts(self, modality: str) -> None: def _convert_to_color( self, arr_2d: Tensor, cmap: dict[int, tuple[int, int, int]] - ) -> np.typing.NDArray[np.uint8]: + ) -> 'np.typing.NDArray[np.uint8]': """Numeric labels to RGB-color encoding. Args: diff --git a/torchgeo/datasets/mmearth.py b/torchgeo/datasets/mmearth.py index 8c1fa36638a..b940537d8b4 100644 --- a/torchgeo/datasets/mmearth.py +++ b/torchgeo/datasets/mmearth.py @@ -493,7 +493,7 @@ def _select_indices_for_modality( def _preprocess_modality( self, - data: np.typing.NDArray[Any], + data: 'np.typing.NDArray[Any]', modality: str, tile_info: dict[str, Any], bands: list[str], @@ -575,8 +575,8 @@ def _preprocess_modality( return tensor def _normalize_modality( - self, data: np.typing.NDArray[Any], modality: str, bands: list[str] - ) -> np.typing.NDArray[np.float64]: + self, data: 'np.typing.NDArray[Any]', modality: str, bands: list[str] + ) -> 'np.typing.NDArray[np.float64]': """Normalize a single modality. Args: diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index d9dc78798cf..28f7714a7c6 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -309,7 +309,7 @@ def plot( except ValueError as e: raise RGBBandsMissingError() from e - def get_masked(img: Tensor) -> np.typing.NDArray[np.uint8]: + def get_masked(img: Tensor) -> 'np.typing.NDArray[np.uint8]': rgb_img = img[rgb_indices].float().numpy() per02 = np.percentile(rgb_img, 2) per98 = np.percentile(rgb_img, 98) diff --git a/torchgeo/datasets/satlas.py b/torchgeo/datasets/satlas.py index 8c1cff85800..2f6e79c6c81 100644 --- a/torchgeo/datasets/satlas.py +++ b/torchgeo/datasets/satlas.py @@ -639,6 +639,12 @@ def _load_image( row: Web Mercator row. directories: Directories that may contain the image. """ + # Moved in PIL 9.1.0 + try: + resample = Image.Resampling.BILINEAR + except AttributeError: + resample = Image.BILINEAR # type: ignore[attr-defined] + # Find directories that match image product good_directories: list[str] = [] for directory in directories: @@ -653,7 +659,6 @@ def _load_image( sample[f'time_{image}'] = torch.tensor(time) # Load all bands - resample = Image.Resampling.BILINEAR channels = [] for band in self.bands[image]: path = os.path.join(self.root, image, directory, band, f'{col}_{row}.png') diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index a16c0fb0877..c67fecb9c8e 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -171,7 +171,11 @@ def _load_patch(self, root: Path, subdir: Path) -> Tensor: # slowdown here from converting to/from a PIL Image just to resize. # https://gist.github.com/calebrob6/748045ac8d844154067b2eefa47de92f pil_image = Image.fromarray(band_data) - resample = Image.Resampling.BILINEAR + # Moved in PIL 9.1.0 + try: + resample = Image.Resampling.BILINEAR + except AttributeError: + resample = Image.BILINEAR # type: ignore[attr-defined] band_data = np.array( pil_image.resize((264, 264), resample=resample) ) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 4a5336fa6d3..7ddbe08e597 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -547,7 +547,7 @@ def draw_semantic_segmentation_masks( image=image.byte(), masks=class_masks, alpha=alpha, colors=colors ) img = img.permute((1, 2, 0)).numpy().astype(np.uint8) - return cast(np.typing.NDArray[np.uint8], img) + return cast('np.typing.NDArray[np.uint8]', img) def rgb_to_mask( diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py index 171ad653070..cc330a0192c 100644 --- a/torchgeo/models/rcf.py +++ b/torchgeo/models/rcf.py @@ -123,10 +123,10 @@ def __init__( def _normalize( self, - patches: np.typing.NDArray[np.float32], + patches: 'np.typing.NDArray[np.float32]', min_divisor: float = 1e-8, zca_bias: float = 0.001, - ) -> np.typing.NDArray[np.float32]: + ) -> 'np.typing.NDArray[np.float32]': """Does ZCA whitening on a set of input patches. Copied from https://github.com/Global-Policy-Lab/mosaiks-paper/blob/7efb09ed455505562d6bb04c2aaa242ef59f0a82/code/mosaiks/featurization.py#L120 diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 4105211d893..2e7e6907e37 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -23,7 +23,6 @@ from torch.optim.lr_scheduler import ( CosineAnnealingLR, LinearLR, - LRScheduler, MultiStepLR, SequentialLR, ) @@ -35,6 +34,11 @@ from . import utils from .base import BaseTask +try: + from torch.optim.lr_scheduler import LRScheduler +except ImportError: + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + def moco_augmentations( version: int, size: int, weights: Tensor From 2acce5ebca4d3f75ceccdaccf3f4658145639cca Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 3 Feb 2025 18:08:35 +0100 Subject: [PATCH 4/4] Increase coverage --- torchgeo/datasets/cyclone.py | 6 +----- torchgeo/datasets/satlas.py | 7 +------ torchgeo/datasets/seco.py | 6 +----- torchgeo/trainers/moco.py | 6 +----- 4 files changed, 4 insertions(+), 21 deletions(-) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 2a21832703a..76b1097f009 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -131,11 +131,7 @@ def _load_image(self, image_id: str) -> Tensor: filename = os.path.join(self.root, self.split, f'{image_id}.jpg') with Image.open(filename) as img: if img.height != self.size or img.width != self.size: - # Moved in PIL 9.1.0 - try: - resample = Image.Resampling.BILINEAR - except AttributeError: - resample = Image.BILINEAR # type: ignore[attr-defined] + resample = Image.Resampling.BILINEAR img = img.resize(size=(self.size, self.size), resample=resample) array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) diff --git a/torchgeo/datasets/satlas.py b/torchgeo/datasets/satlas.py index 2f6e79c6c81..8c1cff85800 100644 --- a/torchgeo/datasets/satlas.py +++ b/torchgeo/datasets/satlas.py @@ -639,12 +639,6 @@ def _load_image( row: Web Mercator row. directories: Directories that may contain the image. """ - # Moved in PIL 9.1.0 - try: - resample = Image.Resampling.BILINEAR - except AttributeError: - resample = Image.BILINEAR # type: ignore[attr-defined] - # Find directories that match image product good_directories: list[str] = [] for directory in directories: @@ -659,6 +653,7 @@ def _load_image( sample[f'time_{image}'] = torch.tensor(time) # Load all bands + resample = Image.Resampling.BILINEAR channels = [] for band in self.bands[image]: path = os.path.join(self.root, image, directory, band, f'{col}_{row}.png') diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index c67fecb9c8e..a16c0fb0877 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -171,11 +171,7 @@ def _load_patch(self, root: Path, subdir: Path) -> Tensor: # slowdown here from converting to/from a PIL Image just to resize. # https://gist.github.com/calebrob6/748045ac8d844154067b2eefa47de92f pil_image = Image.fromarray(band_data) - # Moved in PIL 9.1.0 - try: - resample = Image.Resampling.BILINEAR - except AttributeError: - resample = Image.BILINEAR # type: ignore[attr-defined] + resample = Image.Resampling.BILINEAR band_data = np.array( pil_image.resize((264, 264), resample=resample) ) diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 2e7e6907e37..4105211d893 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -23,6 +23,7 @@ from torch.optim.lr_scheduler import ( CosineAnnealingLR, LinearLR, + LRScheduler, MultiStepLR, SequentialLR, ) @@ -34,11 +35,6 @@ from . import utils from .base import BaseTask -try: - from torch.optim.lr_scheduler import LRScheduler -except ImportError: - from torch.optim.lr_scheduler import _LRScheduler as LRScheduler - def moco_augmentations( version: int, size: int, weights: Tensor