From cee5d2ad40711c2f66c2b158eb0ebd3d042f01af Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 1 Jul 2024 15:57:22 +0400 Subject: [PATCH 01/21] Switch to kornia AugmentationSequential in datamodules --- tests/datamodules/test_geo.py | 4 +-- torchgeo/datamodules/agrifieldnet.py | 5 ++-- torchgeo/datamodules/chesapeake.py | 5 ++-- torchgeo/datamodules/deepglobelandcover.py | 5 ++-- torchgeo/datamodules/fire_risk.py | 5 ++-- torchgeo/datamodules/geo.py | 5 ++-- torchgeo/datamodules/gid15.py | 9 +++--- torchgeo/datamodules/inria.py | 15 +++++----- torchgeo/datamodules/l7irish.py | 5 ++-- torchgeo/datamodules/l8biome.py | 5 ++-- torchgeo/datamodules/landcoverai.py | 13 ++++----- torchgeo/datamodules/levircd.py | 29 ++++++++----------- torchgeo/datamodules/naip.py | 5 ++-- torchgeo/datamodules/oscd.py | 5 ++-- torchgeo/datamodules/potsdam.py | 5 ++-- torchgeo/datamodules/resisc45.py | 5 ++-- torchgeo/datamodules/seco.py | 5 ++-- torchgeo/datamodules/sentinel2_cdl.py | 9 +++--- torchgeo/datamodules/sentinel2_nccm.py | 9 +++--- .../sentinel2_south_america_soybean.py | 9 +++--- torchgeo/datamodules/spacenet.py | 4 +-- torchgeo/datamodules/ssl4eo_benchmark.py | 13 ++++----- torchgeo/datamodules/ucmerced.py | 5 ++-- torchgeo/datamodules/vaihingen.py | 5 ++-- 24 files changed, 79 insertions(+), 105 deletions(-) diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index 8e5fd13d292..81ad845d104 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -31,7 +31,7 @@ def __init__( self.res = 1 def __getitem__(self, query: BoundingBox) -> dict[str, Any]: - image = torch.arange(3 * 2 * 2).view(3, 2, 2) + image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2) return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query} def plot(self, *args: Any, **kwargs: Any) -> Figure: @@ -68,7 +68,7 @@ def __init__( self.length = length def __getitem__(self, index: int) -> dict[str, Tensor]: - return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)} + return {"image": torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)} def __len__(self) -> int: return self.length diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index c5b92b6b01a..0c9cbe25d98 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -12,7 +12,6 @@ from ..datasets import AgriFieldNet, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,12 +48,12 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 37a0d32edbd..2b480352d72 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -13,7 +13,6 @@ from ..datasets import ChesapeakeCVPR from ..samplers import GridGeoSampler, RandomBatchGeoSampler -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -122,8 +121,8 @@ def __init__( else: self.layers = ['naip-new', 'lc'] - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 80ea6052cb7..3112290dea9 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -11,7 +11,6 @@ from ..datasets import DeepGlobeLandCover from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -46,10 +45,10 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/fire_risk.py b/torchgeo/datamodules/fire_risk.py index 1a0d6c7c047..cfa7452a6d7 100644 --- a/torchgeo/datamodules/fire_risk.py +++ b/torchgeo/datamodules/fire_risk.py @@ -8,7 +8,6 @@ import kornia.augmentation as K from ..datasets import FireRisk -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -30,7 +29,7 @@ def __init__( :class:`~torchgeo.datasets.FireRisk`. """ super().__init__(FireRisk, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -38,7 +37,7 @@ def __init__( K.RandomSharpness(p=0.5), K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index e8e3aedd194..2dd94bc7057 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -20,7 +20,6 @@ GridGeoSampler, RandomBatchGeoSampler, ) -from ..transforms import AugmentationSequential from .utils import MisconfigurationException @@ -70,8 +69,8 @@ def __init__( # Data augmentation Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] - self.aug: Transform = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image'] + self.aug: Transform = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) self.train_aug: Transform | None = None self.val_aug: Transform | None = None diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index fc4e802c148..06985a73bb0 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -11,7 +11,6 @@ from ..datasets import GID15 from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -48,15 +47,15 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = self.val_aug = AugmentationSequential( + self.train_aug = self.val_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, ) - self.predict_aug = AugmentationSequential( + self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 39e8ede22c5..6fc233dd407 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -9,7 +9,6 @@ from ..datasets import InriaAerialImageLabeling from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -40,26 +39,26 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.InriaAerialImageLabeling`. """ - super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs) + super().__init__(InriaAerialImageLabeling, batch_size, num_workers, **kwargs) self.patch_size = _to_tuple(patch_size) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, ) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, ) - self.predict_aug = AugmentationSequential( + self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index 35408feddbb..aca9693ea9a 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -12,7 +12,6 @@ from ..datasets import L7Irish, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,12 +48,12 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index ddc802a5ce3..e94db30d392 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -12,7 +12,6 @@ from ..datasets import L8Biome, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,12 +48,12 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 4cbbf1a5e2a..396613249a0 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -8,7 +8,6 @@ import kornia.augmentation as K from ..datasets import LandCoverAI, LandCoverAI100 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -31,17 +30,17 @@ def __init__( """ super().__init__(LandCoverAI, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image', 'mask'], + data_keys=None, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) @@ -66,6 +65,6 @@ def __init__( """ super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 8488c2e58b3..3f0343010a9 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -11,7 +11,6 @@ from ..datasets import LEVIRCD, LEVIRCDPlus from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -43,18 +42,16 @@ def __init__( self.patch_size = _to_tuple(patch_size) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image1', 'image2', 'mask'], + data_keys=None, ) - self.val_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=['image1', 'image2', 'mask'], + self.val_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) - self.test_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=['image1', 'image2', 'mask'], + self.test_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) @@ -91,18 +88,16 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image1', 'image2', 'mask'], + data_keys=None, ) - self.val_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=['image1', 'image2', 'mask'], + self.val_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) - self.test_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=['image1', 'image2', 'mask'], + self.test_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 1be0a64cff5..0236b848d5a 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -20,7 +20,6 @@ ChesapeakeWV, ) from ..samplers import GridGeoSampler, RandomBatchGeoSampler -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -62,8 +61,8 @@ def __init__( NAIP, batch_size, patch_size, length, num_workers, **self.naip_kwargs ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 87ae20cdf40..f30c47c76bb 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -11,7 +11,6 @@ from ..datasets import OSCD from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -85,10 +84,10 @@ def __init__( self.mean = torch.tensor([MEAN[b] for b in self.bands]) self.std = torch.tensor([STD[b] for b in self.bands]) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image1', 'image2', 'mask'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 8011382c769..452c5c47844 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -11,7 +11,6 @@ from ..datasets import Potsdam2D from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -48,10 +47,10 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index e88e139f481..d908bb840c1 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -9,7 +9,6 @@ import torch from ..datasets import RESISC45 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -36,7 +35,7 @@ def __init__( """ super().__init__(RESISC45, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -44,5 +43,5 @@ def __init__( K.RandomSharpness(p=0.5), K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image'], + data_keys=None, ) diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index 1160f037366..1947ca56913 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -10,7 +10,6 @@ from einops import repeat from ..datasets import SeasonalContrastS2 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -49,11 +48,11 @@ def __init__( _mean = repeat(_mean, 'c -> (t c)', t=seasons) _std = repeat(_std, 'c -> (t c)', t=seasons) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), K.Normalize(mean=_mean, std=_std), - data_keys=['image'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index 97c3d05392e..d0fad7caaac 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -13,7 +13,6 @@ from ..datasets import CDL, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -63,19 +62,19 @@ def __init__( CDL, batch_size, patch_size, length, num_workers, **self.cdl_kwargs ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index 91b4f936fdc..6d306058dea 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -13,7 +13,6 @@ from ..datasets import NCCM, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -63,19 +62,19 @@ def __init__( NCCM, batch_size, patch_size, length, num_workers, **self.nccm_kwargs ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index e3363e857f5..a7723a2ca5f 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -14,7 +14,6 @@ from ..datasets import Sentinel2, SouthAmericaSoybean, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -62,19 +61,19 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index e6b0b325823..a9769204235 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -124,7 +124,7 @@ def __init__( SpaceNet1, batch_size, num_workers, val_split_pct, test_split_pct, **kwargs ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), K.RandomRotation(p=0.5, degrees=90), @@ -134,7 +134,7 @@ def __init__( K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=['image', 'mask'], ) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), data_keys=['image', 'mask'], diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index c9eb1d2e315..2b4aba14422 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -10,7 +10,6 @@ from ..datasets import SSL4EOLBenchmark from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -40,23 +39,23 @@ def __init__( self.patch_size = _to_tuple(patch_size) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.val_aug = AugmentationSequential( + self.val_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), - data_keys=['image', 'mask'], + data_keys=None, ) - self.test_aug = AugmentationSequential( + self.test_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), - data_keys=['image', 'mask'], + data_keys=None, ) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 59bb49444ee..34537cda365 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -8,7 +8,6 @@ import kornia.augmentation as K from ..datasets import UCMerced -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -31,8 +30,8 @@ def __init__( """ super().__init__(UCMerced, batch_size, num_workers, **kwargs) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(size=256), - data_keys=['image'], + data_keys=None, ) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 98bc4945e95..35eec7821ae 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -11,7 +11,6 @@ from ..datasets import Vaihingen2D from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -48,10 +47,10 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, ) def setup(self, stage: str) -> None: From 3a5c8b11f01982587c313558ce0ccdb6c138a6ec Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 1 Jul 2024 17:17:16 +0400 Subject: [PATCH 02/21] Fix for classification & regression tests --- torchgeo/datasets/cyclone.py | 6 ++++-- torchgeo/datasets/inria.py | 5 ++++- torchgeo/datasets/quakeset.py | 3 ++- torchgeo/datasets/skippd.py | 4 +++- torchgeo/datasets/sustainbench_crop_yield.py | 9 ++++++++- 5 files changed, 21 insertions(+), 6 deletions(-) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 2a21832703a..1ff65c86406 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -96,17 +96,19 @@ def __getitem__(self, index: int) -> dict[str, Any]: Returns: data, labels, field ids, and metadata at that index """ - sample = { + features = { 'relative_time': torch.tensor(self.features.iat[index, 2]), 'ocean': torch.tensor(self.features.iat[index, 3]), 'label': torch.tensor(self.labels.iat[index, 1]), } image_id = self.labels.iat[index, 0] - sample['image'] = self._load_image(image_id) + sample = {'image': self._load_image(image_id)} + sample['label'] = features['label'] if self.transforms is not None: sample = self.transforms(sample) + sample.update({x: features[x] for x in features if x != 'label'}) return sample diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 3b2a4348a96..50208d012a3 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -217,7 +217,10 @@ def plot( show_predictions = 'prediction' in sample if show_mask: - mask = sample['mask'].numpy() + mask = sample['mask'] + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = mask.numpy() ncols += 1 if show_predictions: diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 811d79cff08..899b7b664d2 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -117,10 +117,11 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: label = torch.tensor(self.data[index]['label']) magnitude = torch.tensor(self.data[index]['magnitude']) - sample = {'image': image, 'label': label, 'magnitude': magnitude} + sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) + sample['magnitude'] = magnitude return sample diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 3accf32d2af..a67ee709f9e 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -144,10 +144,12 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]: data and label at that index """ sample: dict[str, str | Tensor] = {'image': self._load_image(index)} - sample.update(self._load_features(index)) + features = self._load_features(index) + sample['label'] = features['label'] if self.transforms is not None: sample = self.transforms(sample) + sample.update({x: features[x] for x in features if x != 'label'}) return sample diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index eec9be57ab3..002bf87565d 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -149,10 +149,17 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ sample: dict[str, Tensor] = {'image': self.images[index]} - sample.update(self.features[index]) + sample['label'] = self.features[index]['label'] if self.transforms is not None: sample = self.transforms(sample) + sample.update( + { + x: self.features[index][x] + for x in self.features[index] + if x != 'label' + } + ) return sample From d46fe01e848c53ea380455d5753e7648ea2691cc Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 1 Jul 2024 22:19:53 +0400 Subject: [PATCH 03/21] Fix for segmentation tests --- tests/datamodules/test_geo.py | 2 +- torchgeo/datamodules/spacenet.py | 4 ++-- torchgeo/trainers/segmentation.py | 8 ++++++++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index 81ad845d104..4e5431c684f 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -68,7 +68,7 @@ def __init__( self.length = length def __getitem__(self, index: int) -> dict[str, Tensor]: - return {"image": torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)} + return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)} def __len__(self) -> int: return self.length diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index a9769204235..0f7b4fb062c 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -132,12 +132,12 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image', 'mask'], + data_keys=None, ) self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), - data_keys=['image', 'mask'], + data_keys=None, ) self.predict_aug = AugmentationSequential( diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index f8e519fa493..5bdfceb2867 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch.nn as nn +from einops import rearrange from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection @@ -225,6 +226,9 @@ def training_step( Returns: The loss tensor. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') + x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -245,6 +249,8 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -289,6 +295,8 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0] From 25ae34108541059972232d54e7edd67c0837dc73 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 2 Jul 2024 22:09:48 +0400 Subject: [PATCH 04/21] Revert inria batch_size --- torchgeo/datamodules/inria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 6fc233dd407..7bd4d0ae165 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -39,7 +39,7 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.InriaAerialImageLabeling`. """ - super().__init__(InriaAerialImageLabeling, batch_size, num_workers, **kwargs) + super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs) self.patch_size = _to_tuple(patch_size) From d21b1934bf9b15c0433377a30058462cf23c3441 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sun, 14 Jul 2024 23:17:29 +0400 Subject: [PATCH 05/21] Set keepdim=True --- torchgeo/datamodules/agrifieldnet.py | 3 +++ torchgeo/datamodules/chesapeake.py | 4 +++- torchgeo/datamodules/deepglobelandcover.py | 3 +++ torchgeo/datamodules/fire_risk.py | 3 +++ torchgeo/datamodules/geo.py | 4 +++- torchgeo/datamodules/gid15.py | 6 ++++++ torchgeo/datamodules/inria.py | 8 ++++++++ torchgeo/datamodules/l7irish.py | 3 +++ torchgeo/datamodules/l8biome.py | 3 +++ torchgeo/datamodules/landcoverai.py | 11 ++++++++-- torchgeo/datamodules/levircd.py | 20 +++++++++++++++---- torchgeo/datamodules/naip.py | 4 +++- torchgeo/datamodules/oscd.py | 3 +++ torchgeo/datamodules/potsdam.py | 3 +++ torchgeo/datamodules/quakeset.py | 3 +++ torchgeo/datamodules/resisc45.py | 3 +++ torchgeo/datamodules/seco.py | 3 +++ torchgeo/datamodules/sentinel2_cdl.py | 7 ++++++- torchgeo/datamodules/sentinel2_eurocrops.py | 14 ++++++++----- torchgeo/datamodules/sentinel2_nccm.py | 7 ++++++- .../sentinel2_south_america_soybean.py | 7 ++++++- torchgeo/datamodules/southafricacroptype.py | 14 ++++++++----- torchgeo/datamodules/spacenet.py | 10 +++++++++- torchgeo/datamodules/ssl4eo_benchmark.py | 8 ++++++++ torchgeo/datamodules/ucmerced.py | 3 +++ torchgeo/datamodules/vaihingen.py | 3 +++ torchgeo/trainers/segmentation.py | 8 -------- 27 files changed, 137 insertions(+), 31 deletions(-) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index 0c9cbe25d98..d31eff85320 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -54,10 +54,13 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 2b480352d72..0243010b179 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -122,8 +122,10 @@ def __init__( self.layers = ['naip-new', 'lc'] self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 3112290dea9..920cc1644d2 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -49,7 +49,10 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/fire_risk.py b/torchgeo/datamodules/fire_risk.py index cfa7452a6d7..58855d26fae 100644 --- a/torchgeo/datamodules/fire_risk.py +++ b/torchgeo/datamodules/fire_risk.py @@ -38,7 +38,10 @@ def __init__( K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 2dd94bc7057..0f78fcf5a1e 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -70,8 +70,10 @@ def __init__( # Data augmentation Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] self.aug: Transform = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] self.train_aug: Transform | None = None self.val_aug: Transform | None = None self.test_aug: Transform | None = None diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 06985a73bb0..a7e48226a1a 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -51,13 +51,19 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.predict_aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 7bd4d0ae165..de369fe67ed 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -49,18 +49,26 @@ def __init__( K.RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + self.predict_aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index aca9693ea9a..0f3ec561990 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -54,10 +54,13 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index e94db30d392..2c2ab4c2152 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -54,10 +54,13 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 396613249a0..3c78183ad49 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -38,11 +38,15 @@ def __init__( K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=None, + keepdim=True, ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] class LandCoverAI100DataModule(NonGeoDataModule): """LightningDataModule implementation for the LandCoverAI100 dataset. @@ -66,5 +70,8 @@ def __init__( super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 3f0343010a9..8a5bc9e8012 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -46,14 +46,20 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.val_aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) self.test_aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.val_aug.keepdim = True # type: ignore[attr-defined] + self.test_aug.keepdim = True # type: ignore[attr-defined] + class LEVIRCDPlusDataModule(NonGeoDataModule): """LightningDataModule implementation for the LEVIR-CD+ dataset. @@ -92,14 +98,20 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.val_aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) self.test_aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.val_aug.keepdim = True # type: ignore[attr-defined] + self.test_aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 0236b848d5a..019d672cc3d 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -62,8 +62,10 @@ def __init__( ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index f30c47c76bb..2c42e877c94 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -88,7 +88,10 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 452c5c47844..0526d5b6026 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -51,7 +51,10 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py index 1a3e19a5122..2316c23b4ae 100644 --- a/torchgeo/datamodules/quakeset.py +++ b/torchgeo/datamodules/quakeset.py @@ -39,4 +39,7 @@ def __init__( K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), data_keys=['image'], + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index d908bb840c1..0a046a5efbb 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -44,4 +44,7 @@ def __init__( K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index 1947ca56913..b0f59bfe127 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -53,7 +53,10 @@ def __init__( K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), K.Normalize(mean=_mean, std=_std), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index d0fad7caaac..b84dd9e755c 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -68,15 +68,20 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_eurocrops.py b/torchgeo/datamodules/sentinel2_eurocrops.py index 4e0893e4f8d..9eeea37d5cd 100644 --- a/torchgeo/datamodules/sentinel2_eurocrops.py +++ b/torchgeo/datamodules/sentinel2_eurocrops.py @@ -13,7 +13,6 @@ from ..datasets import EuroCrops, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -64,21 +63,26 @@ def __init__( **self.eurocrops_kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index 6d306058dea..4897c267c16 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -68,15 +68,20 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index a7723a2ca5f..b3963b40c6b 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -67,15 +67,20 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/southafricacroptype.py b/torchgeo/datamodules/southafricacroptype.py index 3f44bb61471..45c0534815f 100644 --- a/torchgeo/datamodules/southafricacroptype.py +++ b/torchgeo/datamodules/southafricacroptype.py @@ -12,7 +12,6 @@ from ..datasets import SouthAfricaCropType, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,21 +48,26 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 0f7b4fb062c..7fccf774279 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -133,19 +133,27 @@ def __init__( K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=None, + keepdim=True, ) self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), data_keys=None, + keepdim=True, ) - self.predict_aug = AugmentationSequential( + self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), data_keys=['image'], ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.predict_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + + class SpaceNet6DataModule(SpaceNetBaseDataModule): """LightningDataModule implementation for the SpaceNet6 dataset. diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index 2b4aba14422..46b57eba0ce 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -45,6 +45,7 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, @@ -53,9 +54,16 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), data_keys=None, + keepdim=True, ) self.test_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), data_keys=None, + keepdim=True, ) + + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.val_aug.keepdim = True # type: ignore[attr-defined] + self.val_aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 34537cda365..abc30b29c37 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -34,4 +34,7 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), K.Resize(size=256), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 35eec7821ae..49414358cd1 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -51,7 +51,10 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 5bdfceb2867..f8e519fa493 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -9,7 +9,6 @@ import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch.nn as nn -from einops import rearrange from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection @@ -226,9 +225,6 @@ def training_step( Returns: The loss tensor. """ - if 'mask' in batch and batch['mask'].shape[1] == 1: - batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') - x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -249,8 +245,6 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - if 'mask' in batch and batch['mask'].shape[1] == 1: - batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -295,8 +289,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - if 'mask' in batch and batch['mask'].shape[1] == 1: - batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0] From 9cd87a0f00057f243265f05e6d010fbce80087e9 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 18 Jul 2024 14:35:49 +0200 Subject: [PATCH 06/21] Fix ssl4eo_benchmark --- torchgeo/datamodules/ssl4eo_benchmark.py | 2 +- torchgeo/datasets/ssl4eo_benchmark.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index 46b57eba0ce..f63e5861550 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -66,4 +66,4 @@ def __init__( # https://github.com/kornia/kornia/issues/2848 self.train_aug.keepdim = True # type: ignore[attr-defined] self.val_aug.keepdim = True # type: ignore[attr-defined] - self.val_aug.keepdim = True # type: ignore[attr-defined] + self.test_aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 7bc4f828974..13c5a8474c4 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -324,7 +324,7 @@ def _load_mask(self, path: Path) -> Tensor: mask """ with rasterio.open(path) as src: - mask = torch.from_numpy(src.read()).long() + mask = torch.from_numpy(src.read(1)).long() mask = self.ordinal_map[mask] return mask From 76875ceac7d5646687ce273783941bbbedf73544 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 18 Jul 2024 14:45:38 +0200 Subject: [PATCH 07/21] Masks must not have a channel dimension --- torchgeo/datasets/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index d86bc58e6ed..26a035d427d 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -554,7 +554,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if self.is_image: sample['image'] = data else: - sample['mask'] = data + sample['mask'] = data.squeeze(0) if self.transforms is not None: sample = self.transforms(sample) From aaf425de309d45c520ee80d2e7878ac235b77370 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 18 Jul 2024 15:16:25 +0200 Subject: [PATCH 08/21] Fix south africa dimensions --- torchgeo/datasets/south_africa_crop_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index 3c4f7f895ec..a8643873c5b 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -223,7 +223,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: ) mask_filepaths.append(file_path) - mask = self._merge_files(mask_filepaths, query) + mask = self._merge_files(mask_filepaths, query).squeeze(0) sample = { 'crs': self.crs, From 5106559a975eff8e246ef2963603cf4e77cda7ca Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 18 Jul 2024 15:18:12 +0200 Subject: [PATCH 09/21] Fix chesapeake dimensions --- torchgeo/datamodules/chesapeake.py | 4 +--- torchgeo/datasets/chesapeake.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 0243010b179..6683085e439 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -43,8 +43,6 @@ def forward(self, sample: dict[str, Any]) -> dict[str, Any]: sample[key] = sample[key].float() sample[key] = self.aug(sample[key]) sample[key] = sample[key].to(dtype) - # Kornia adds batch dimension - sample[key] = rearrange(sample[key], '() c h w -> c h w') return sample @@ -93,7 +91,7 @@ def __init__( # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. self.original_patch_size = patch_size * 3 - kwargs['transforms'] = _Transform(K.CenterCrop(patch_size)) + kwargs['transforms'] = _Transform(K.CenterCrop(patch_size, keepdim=True)) super().__init__( ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 7e521092aaf..af9b01f54dc 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -581,7 +581,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample['mask'] = np.concatenate(sample['mask'], axis=0) sample['image'] = torch.from_numpy(sample['image']).float() - sample['mask'] = torch.from_numpy(sample['mask']).long() + sample['mask'] = torch.from_numpy(sample['mask']).long().squeeze(0) if self.transforms is not None: sample = self.transforms(sample) From e8d2164a690d38fc608962c2e2fcbdbb99eb1813 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 18 Jul 2024 15:19:45 +0200 Subject: [PATCH 10/21] Remove unused import --- torchgeo/datamodules/chesapeake.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 6683085e439..b2e9848d4f5 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -8,7 +8,6 @@ import kornia.augmentation as K import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from torch import Tensor from ..datasets import ChesapeakeCVPR From 51d09575518985f290a0f8cf16f1c7d52bd6d0b3 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sat, 20 Jul 2024 17:14:39 +0400 Subject: [PATCH 11/21] mypy fixes & remove _Transform --- .pre-commit-config.yaml | 2 +- torchgeo/datamodules/agrifieldnet.py | 2 +- torchgeo/datamodules/chesapeake.py | 39 +++---------------- torchgeo/datamodules/deepglobelandcover.py | 2 +- torchgeo/datamodules/fire_risk.py | 2 +- torchgeo/datamodules/geo.py | 2 +- torchgeo/datamodules/gid15.py | 4 +- torchgeo/datamodules/inria.py | 6 +-- torchgeo/datamodules/l7irish.py | 2 +- torchgeo/datamodules/l8biome.py | 2 +- torchgeo/datamodules/levircd.py | 12 +++--- torchgeo/datamodules/naip.py | 2 +- torchgeo/datamodules/oscd.py | 2 +- torchgeo/datamodules/potsdam.py | 2 +- torchgeo/datamodules/quakeset.py | 5 +-- torchgeo/datamodules/resisc45.py | 2 +- torchgeo/datamodules/seco.py | 2 +- torchgeo/datamodules/sentinel2_cdl.py | 4 +- torchgeo/datamodules/sentinel2_eurocrops.py | 4 +- torchgeo/datamodules/sentinel2_nccm.py | 4 +- .../sentinel2_south_america_soybean.py | 4 +- torchgeo/datamodules/southafricacroptype.py | 4 +- torchgeo/datamodules/ssl4eo_benchmark.py | 6 +-- torchgeo/datamodules/ucmerced.py | 2 +- torchgeo/datamodules/vaihingen.py | 2 +- 25 files changed, 46 insertions(+), 74 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 05dba0da7c1..c0f445a9684 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - python - jupyter - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.9.0 + rev: v1.11.0 hooks: - id: mypy args: diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index d31eff85320..9e069e11ad6 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -60,7 +60,7 @@ def __init__( }, ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index b2e9848d4f5..3a9082a7f89 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -6,7 +6,6 @@ from typing import Any import kornia.augmentation as K -import torch.nn as nn import torch.nn.functional as F from torch import Tensor @@ -15,36 +14,6 @@ from .geo import GeoDataModule -class _Transform(nn.Module): - """Version of AugmentationSequential designed for samples, not batches.""" - - def __init__(self, aug: nn.Module) -> None: - """Initialize a new _Transform instance. - - Args: - aug: Augmentation to apply. - """ - super().__init__() - self.aug = aug - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Apply the augmentation. - - Args: - sample: Input sample. - - Returns: - Augmented sample. - """ - for key in ['image', 'mask']: - dtype = sample[key].dtype - # All inputs must be float - sample[key] = sample[key].float() - sample[key] = self.aug(sample[key]) - sample[key] = sample[key].to(dtype) - return sample - - class ChesapeakeCVPRDataModule(GeoDataModule): """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. @@ -90,7 +59,11 @@ def __init__( # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. self.original_patch_size = patch_size * 3 - kwargs['transforms'] = _Transform(K.CenterCrop(patch_size, keepdim=True)) + kwargs['transforms'] = K.AugmentationSequential( + K.CenterCrop(patch_size), data_keys=None, keepdim=True + ) + # https://github.com/kornia/kornia/issues/2848 + kwargs['transforms'].keepdim = True super().__init__( ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs @@ -122,7 +95,7 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 920cc1644d2..e40dfe99ae5 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -52,7 +52,7 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/fire_risk.py b/torchgeo/datamodules/fire_risk.py index 58855d26fae..41732afbbe3 100644 --- a/torchgeo/datamodules/fire_risk.py +++ b/torchgeo/datamodules/fire_risk.py @@ -41,7 +41,7 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 0f78fcf5a1e..686c9533b5d 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -73,7 +73,7 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True self.train_aug: Transform | None = None self.val_aug: Transform | None = None self.test_aug: Transform | None = None diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index a7e48226a1a..6937cdfa228 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -61,8 +61,8 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.predict_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.predict_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index de369fe67ed..9a31ef481ae 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -65,9 +65,9 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.aug.keepdim = True # type: ignore[attr-defined] - self.predict_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.aug.keepdim = True + self.predict_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index 0f3ec561990..b07c92e08e6 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -60,7 +60,7 @@ def __init__( }, ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index 2c2ab4c2152..e65ecc1c661 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -60,7 +60,7 @@ def __init__( }, ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 8a5bc9e8012..c171f547850 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -56,9 +56,9 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.val_aug.keepdim = True # type: ignore[attr-defined] - self.test_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.val_aug.keepdim = True + self.test_aug.keepdim = True class LEVIRCDPlusDataModule(NonGeoDataModule): @@ -108,9 +108,9 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.val_aug.keepdim = True # type: ignore[attr-defined] - self.test_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.val_aug.keepdim = True + self.test_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 019d672cc3d..100c9fe7c93 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -65,7 +65,7 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 2c42e877c94..2a81f54bcb7 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -91,7 +91,7 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 0526d5b6026..54d21e871ec 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -54,7 +54,7 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py index 2316c23b4ae..a200da9ceda 100644 --- a/torchgeo/datamodules/quakeset.py +++ b/torchgeo/datamodules/quakeset.py @@ -9,7 +9,6 @@ import torch from ..datasets import QuakeSet -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -34,7 +33,7 @@ def __init__( :class:`~torchgeo.datasets.QuakeSet`. """ super().__init__(QuakeSet, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), @@ -42,4 +41,4 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 0a046a5efbb..871284b13ca 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -47,4 +47,4 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index b0f59bfe127..0a02bbc3fb6 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -56,7 +56,7 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index b84dd9e755c..66fa90c6349 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -79,8 +79,8 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_eurocrops.py b/torchgeo/datamodules/sentinel2_eurocrops.py index 9eeea37d5cd..245b34addc5 100644 --- a/torchgeo/datamodules/sentinel2_eurocrops.py +++ b/torchgeo/datamodules/sentinel2_eurocrops.py @@ -80,8 +80,8 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index 4897c267c16..a9faf4a1af2 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -79,8 +79,8 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index b3963b40c6b..9692a8bb2b2 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -78,8 +78,8 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/southafricacroptype.py b/torchgeo/datamodules/southafricacroptype.py index 45c0534815f..16afccd189a 100644 --- a/torchgeo/datamodules/southafricacroptype.py +++ b/torchgeo/datamodules/southafricacroptype.py @@ -65,8 +65,8 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index f63e5861550..834f282d9c0 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -64,6 +64,6 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.val_aug.keepdim = True # type: ignore[attr-defined] - self.test_aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.val_aug.keepdim = True + self.test_aug.keepdim = True diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index abc30b29c37..3c785affe6d 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -37,4 +37,4 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 49414358cd1..1bf4240f73d 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -54,7 +54,7 @@ def __init__( keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. From 1e4656086ac9be8b5c3d2ec48dda2dd4671ab870 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sat, 20 Jul 2024 17:23:49 +0400 Subject: [PATCH 12/21] Fix quakeset --- torchgeo/datamodules/quakeset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py index a200da9ceda..4cfaf90799a 100644 --- a/torchgeo/datamodules/quakeset.py +++ b/torchgeo/datamodules/quakeset.py @@ -37,7 +37,7 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=['image'], + data_keys=None, keepdim=True, ) # https://github.com/kornia/kornia/issues/2848 From c0ceb379020c0b5f157b37421ecfdd02799cff30 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 8 Aug 2024 01:04:56 +0400 Subject: [PATCH 13/21] mypy fix --- torchgeo/datasets/inria.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 50208d012a3..70f26ed161b 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -220,7 +220,7 @@ def plot( mask = sample['mask'] if mask.ndim == 3 and mask.shape[0] == 1: mask = mask.squeeze(0) - mask = mask.numpy() + mask_arr = mask.numpy() ncols += 1 if show_predictions: @@ -236,7 +236,7 @@ def plot( axs[0].set_title('Image') if show_mask: - axs[1].imshow(mask, interpolation='none') + axs[1].imshow(mask_arr, interpolation='none') axs[1].axis('off') if show_titles: axs[1].set_title('Label') From d048cc2f6ffe9feb3e03b261916a123418f7c103 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Fri, 9 Aug 2024 19:15:53 +0400 Subject: [PATCH 14/21] Fix chesapeake plot test --- tests/datasets/test_chesapeake.py | 5 ++++- torchgeo/datamodules/landcoverai.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index d4540651b87..a6bcdd81b6f 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -212,6 +212,9 @@ def test_plot(self, dataset: ChesapeakeCVPR) -> None: plt.close() dataset.plot(x, show_titles=False) plt.close() - x['prediction'] = x['mask'][:, :, 0].clone().unsqueeze(2) + if x['mask'].ndim == 2: + x['prediction'] = x['mask'].clone() + else: + x['prediction'] = x['mask'][0, :, :].clone() dataset.plot(x) plt.close() diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 3c78183ad49..1aa26212f25 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -48,6 +48,7 @@ def __init__( self.train_aug.keepdim = True # type: ignore[attr-defined] self.aug.keepdim = True # type: ignore[attr-defined] + class LandCoverAI100DataModule(NonGeoDataModule): """LightningDataModule implementation for the LandCoverAI100 dataset. From 25b898a949b8ac4594bd3121cb309e39b3735cb4 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 9 Oct 2024 19:01:55 +0400 Subject: [PATCH 15/21] mypy fix --- torchgeo/datamodules/landcoverai.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 1aa26212f25..64e5d23b694 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -45,8 +45,8 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.aug.keepdim = True # type: ignore[attr-defined] + self.train_aug.keepdim = True + self.aug.keepdim = True class LandCoverAI100DataModule(NonGeoDataModule): @@ -75,4 +75,4 @@ def __init__( ) # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True From 2ebd334875d445cdb86954180db2648bd1d8bf7a Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 5 Nov 2024 19:36:08 +0400 Subject: [PATCH 16/21] Fix SpaceNet aug --- torchgeo/datamodules/spacenet.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 7fccf774279..ee7288ecdb0 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -11,7 +11,6 @@ from torch.utils.data import random_split from ..datasets import SpaceNet, SpaceNet1, SpaceNet6 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -145,16 +144,15 @@ def __init__( self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), - data_keys=['image'], + data_keys=None, ) # https://github.com/kornia/kornia/issues/2848 self.train_aug.keepdim = True # type: ignore[attr-defined] - self.predict_aug.keepdim = True # type: ignore[attr-defined] + self.predict_aug.keepdim = True # type: ignore[attr-defined] self.aug.keepdim = True # type: ignore[attr-defined] - class SpaceNet6DataModule(SpaceNetBaseDataModule): """LightningDataModule implementation for the SpaceNet6 dataset. From 2f4904fc5a963d4a5525363c19abc18149bf32f7 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 5 Nov 2024 20:22:08 +0400 Subject: [PATCH 17/21] Bump min version of kornia --- .pre-commit-config.yaml | 2 +- pyproject.toml | 4 ++-- requirements/min-reqs.old | 2 +- tests/transforms/test_color.py | 2 -- torchgeo/datamodules/agrifieldnet.py | 2 -- torchgeo/datamodules/chesapeake.py | 4 ---- torchgeo/datamodules/deepglobelandcover.py | 2 -- torchgeo/datamodules/fire_risk.py | 2 -- torchgeo/datamodules/geo.py | 3 +-- torchgeo/datamodules/gid15.py | 4 ---- torchgeo/datamodules/inria.py | 5 ----- torchgeo/datamodules/l7irish.py | 2 -- torchgeo/datamodules/l8biome.py | 2 -- torchgeo/datamodules/landcoverai.py | 7 ------- torchgeo/datamodules/levircd.py | 10 ---------- torchgeo/datamodules/naip.py | 2 -- torchgeo/datamodules/oscd.py | 2 -- torchgeo/datamodules/potsdam.py | 2 -- torchgeo/datamodules/quakeset.py | 2 -- torchgeo/datamodules/resisc45.py | 2 -- torchgeo/datamodules/seco.py | 2 -- torchgeo/datamodules/sentinel2_cdl.py | 4 ---- torchgeo/datamodules/sentinel2_eurocrops.py | 4 ---- torchgeo/datamodules/sentinel2_nccm.py | 4 ---- .../datamodules/sentinel2_south_america_soybean.py | 4 ---- torchgeo/datamodules/southafricacroptype.py | 4 ---- torchgeo/datamodules/spacenet.py | 5 ----- torchgeo/datamodules/ssl4eo_benchmark.py | 5 ----- torchgeo/datamodules/ucmerced.py | 2 -- torchgeo/datamodules/vaihingen.py | 2 -- torchgeo/datasets/cyclone.py | 6 ++---- torchgeo/datasets/quakeset.py | 3 +-- torchgeo/datasets/skippd.py | 4 +--- torchgeo/datasets/sustainbench_crop_yield.py | 9 +-------- 34 files changed, 10 insertions(+), 111 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c0f445a9684..61e7d2a538d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - python - jupyter - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.0 + rev: v1.13.0 hooks: - id: mypy args: diff --git a/pyproject.toml b/pyproject.toml index b850f8b3466..878a9217e3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ dependencies = [ "einops>=0.3", # fiona 1.8.21+ required for Python 3.10 wheels "fiona>=1.8.21", - # kornia 0.7.3+ required for instance segmentation support in AugmentationSequential - "kornia>=0.7.3", + # kornia 0.7.4+ required for using native AugmentationSequential, allowing us to replace our custom implementation + "kornia>=0.7.4", # lightly 1.4.5+ required for LARS optimizer # lightly 1.4.26 is incompatible with the version of timm required by smp # https://github.com/microsoft/torchgeo/issues/1824 diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 52d8dcce018..c5a5f6dae58 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -4,7 +4,7 @@ setuptools==61.0.0 # install einops==0.3.0 fiona==1.8.21 -kornia==0.7.3 +kornia==0.7.4 lightly==1.4.5 lightning[pytorch-extra]==2.0.0 matplotlib==3.5.0 diff --git a/tests/transforms/test_color.py b/tests/transforms/test_color.py index 2cea90b396d..b235f7195f2 100644 --- a/tests/transforms/test_color.py +++ b/tests/transforms/test_color.py @@ -37,8 +37,6 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> aug = K.AugmentationSequential( RandomGrayscale(weights, p=1), keepdim=True, data_keys=None ) - # https://github.com/kornia/kornia/issues/2848 - aug.keepdim = True output = aug(sample) assert output['image'].shape == sample['image'].shape for i in range(1, 3): diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index 9e069e11ad6..cbb8af25356 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -59,8 +59,6 @@ def __init__( DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 3a9082a7f89..41e944e1af5 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -62,8 +62,6 @@ def __init__( kwargs['transforms'] = K.AugmentationSequential( K.CenterCrop(patch_size), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - kwargs['transforms'].keepdim = True super().__init__( ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs @@ -94,8 +92,6 @@ def __init__( self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index e40dfe99ae5..b3ab2d687b5 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -51,8 +51,6 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/fire_risk.py b/torchgeo/datamodules/fire_risk.py index 41732afbbe3..d317981cff3 100644 --- a/torchgeo/datamodules/fire_risk.py +++ b/torchgeo/datamodules/fire_risk.py @@ -40,8 +40,6 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 686c9533b5d..8721ea6e7f6 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -72,8 +72,7 @@ def __init__( self.aug: Transform = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True + self.train_aug: Transform | None = None self.val_aug: Transform | None = None self.test_aug: Transform | None = None diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 6937cdfa228..d33c55ec829 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -60,10 +60,6 @@ def __init__( keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.predict_aug.keepdim = True - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 9a31ef481ae..797f5484b6a 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -64,11 +64,6 @@ def __init__( keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.aug.keepdim = True - self.predict_aug.keepdim = True - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index b07c92e08e6..3a70446f90f 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -59,8 +59,6 @@ def __init__( DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index e65ecc1c661..cf0415b34c9 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -59,8 +59,6 @@ def __init__( DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 64e5d23b694..9ed2a4d34d6 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -44,10 +44,6 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.aug.keepdim = True - class LandCoverAI100DataModule(NonGeoDataModule): """LightningDataModule implementation for the LandCoverAI100 dataset. @@ -73,6 +69,3 @@ def __init__( self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index c171f547850..0e3a124dc94 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -55,11 +55,6 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.val_aug.keepdim = True - self.test_aug.keepdim = True - class LEVIRCDPlusDataModule(NonGeoDataModule): """LightningDataModule implementation for the LEVIR-CD+ dataset. @@ -107,11 +102,6 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.val_aug.keepdim = True - self.test_aug.keepdim = True - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 100c9fe7c93..0520d0264ad 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -64,8 +64,6 @@ def __init__( self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 2a81f54bcb7..8db1dd7061a 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -90,8 +90,6 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 54d21e871ec..7a5495a4458 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -53,8 +53,6 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py index 4cfaf90799a..03c677138e9 100644 --- a/torchgeo/datamodules/quakeset.py +++ b/torchgeo/datamodules/quakeset.py @@ -40,5 +40,3 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 871284b13ca..e279478f8d0 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -46,5 +46,3 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index 0a02bbc3fb6..ecfeb04b288 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -55,8 +55,6 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index 66fa90c6349..91af34b0ef1 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -78,10 +78,6 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.aug.keepdim = True - def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_eurocrops.py b/torchgeo/datamodules/sentinel2_eurocrops.py index 245b34addc5..8f34c2598ef 100644 --- a/torchgeo/datamodules/sentinel2_eurocrops.py +++ b/torchgeo/datamodules/sentinel2_eurocrops.py @@ -79,10 +79,6 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.aug.keepdim = True - def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index a9faf4a1af2..34fde0f3153 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -78,10 +78,6 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.aug.keepdim = True - def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index 9692a8bb2b2..d3deff9e823 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -77,10 +77,6 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.aug.keepdim = True - def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/southafricacroptype.py b/torchgeo/datamodules/southafricacroptype.py index 16afccd189a..37fdef5e7db 100644 --- a/torchgeo/datamodules/southafricacroptype.py +++ b/torchgeo/datamodules/southafricacroptype.py @@ -64,10 +64,6 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.aug.keepdim = True - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index ee7288ecdb0..7353efbbaec 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -147,11 +147,6 @@ def __init__( data_keys=None, ) - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True # type: ignore[attr-defined] - self.predict_aug.keepdim = True # type: ignore[attr-defined] - self.aug.keepdim = True # type: ignore[attr-defined] - class SpaceNet6DataModule(SpaceNetBaseDataModule): """LightningDataModule implementation for the SpaceNet6 dataset. diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index 834f282d9c0..02e5de917dd 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -62,8 +62,3 @@ def __init__( data_keys=None, keepdim=True, ) - - # https://github.com/kornia/kornia/issues/2848 - self.train_aug.keepdim = True - self.val_aug.keepdim = True - self.test_aug.keepdim = True diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 3c785affe6d..6bb3e70eab2 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -36,5 +36,3 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 1bf4240f73d..4fead8c85c8 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -53,8 +53,6 @@ def __init__( data_keys=None, keepdim=True, ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 1ff65c86406..2a21832703a 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -96,19 +96,17 @@ def __getitem__(self, index: int) -> dict[str, Any]: Returns: data, labels, field ids, and metadata at that index """ - features = { + sample = { 'relative_time': torch.tensor(self.features.iat[index, 2]), 'ocean': torch.tensor(self.features.iat[index, 3]), 'label': torch.tensor(self.labels.iat[index, 1]), } image_id = self.labels.iat[index, 0] - sample = {'image': self._load_image(image_id)} - sample['label'] = features['label'] + sample['image'] = self._load_image(image_id) if self.transforms is not None: sample = self.transforms(sample) - sample.update({x: features[x] for x in features if x != 'label'}) return sample diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 899b7b664d2..811d79cff08 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -117,11 +117,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: label = torch.tensor(self.data[index]['label']) magnitude = torch.tensor(self.data[index]['magnitude']) - sample = {'image': image, 'label': label} + sample = {'image': image, 'label': label, 'magnitude': magnitude} if self.transforms is not None: sample = self.transforms(sample) - sample['magnitude'] = magnitude return sample diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index a67ee709f9e..3accf32d2af 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -144,12 +144,10 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]: data and label at that index """ sample: dict[str, str | Tensor] = {'image': self._load_image(index)} - features = self._load_features(index) - sample['label'] = features['label'] + sample.update(self._load_features(index)) if self.transforms is not None: sample = self.transforms(sample) - sample.update({x: features[x] for x in features if x != 'label'}) return sample diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 002bf87565d..eec9be57ab3 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -149,17 +149,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ sample: dict[str, Tensor] = {'image': self.images[index]} - sample['label'] = self.features[index]['label'] + sample.update(self.features[index]) if self.transforms is not None: sample = self.transforms(sample) - sample.update( - { - x: self.features[index][x] - for x in self.features[index] - if x != 'label' - } - ) return sample From c7dfa97571f41e78e5687f70258fcfe9a45bca4f Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 5 Nov 2024 21:00:57 +0400 Subject: [PATCH 18/21] Remove conditional squeeze --- torchgeo/datasets/inria.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 70f26ed161b..3b2a4348a96 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -217,10 +217,7 @@ def plot( show_predictions = 'prediction' in sample if show_mask: - mask = sample['mask'] - if mask.ndim == 3 and mask.shape[0] == 1: - mask = mask.squeeze(0) - mask_arr = mask.numpy() + mask = sample['mask'].numpy() ncols += 1 if show_predictions: @@ -236,7 +233,7 @@ def plot( axs[0].set_title('Image') if show_mask: - axs[1].imshow(mask_arr, interpolation='none') + axs[1].imshow(mask, interpolation='none') axs[1].axis('off') if show_titles: axs[1].set_title('Label') From 4f42b918c67652d7ecd75fb683831bc02b23a752 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 5 Nov 2024 22:24:06 +0400 Subject: [PATCH 19/21] Fix for new datasets --- torchgeo/datamodules/caffe.py | 11 ++++++----- torchgeo/datamodules/ftw.py | 10 +++++----- torchgeo/datamodules/geonrw.py | 10 ++++++---- torchgeo/datasets/geonrw.py | 2 +- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/torchgeo/datamodules/caffe.py b/torchgeo/datamodules/caffe.py index 241a34197fc..a58136df30b 100644 --- a/torchgeo/datamodules/caffe.py +++ b/torchgeo/datamodules/caffe.py @@ -9,7 +9,6 @@ import torch from ..datasets import CaFFe -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -40,16 +39,18 @@ def __init__( self.size = size - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(size), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/ftw.py b/torchgeo/datamodules/ftw.py index 8f27e719699..a197a789c48 100644 --- a/torchgeo/datamodules/ftw.py +++ b/torchgeo/datamodules/ftw.py @@ -9,7 +9,6 @@ import torch from ..datasets import FieldsOfTheWorld -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -55,16 +54,17 @@ def __init__( self.val_countries = val_countries self.test_countries = test_countries - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/geonrw.py b/torchgeo/datamodules/geonrw.py index c67753673ac..5283b0ed7ac 100644 --- a/torchgeo/datamodules/geonrw.py +++ b/torchgeo/datamodules/geonrw.py @@ -10,7 +10,6 @@ from torch.utils.data import Subset from ..datasets import GeoNRW -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule from .utils import group_shuffle_split @@ -38,14 +37,17 @@ def __init__( """ super().__init__(GeoNRW, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Resize(size), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask']) + self.aug = K.AugmentationSequential( + K.Resize(size), data_keys=None, keepdim=True + ) self.size = size diff --git a/torchgeo/datasets/geonrw.py b/torchgeo/datasets/geonrw.py index dfdadf9f815..50e05bad0fb 100644 --- a/torchgeo/datasets/geonrw.py +++ b/torchgeo/datasets/geonrw.py @@ -242,7 +242,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: # rename to torchgeo standard keys sample['image'] = sample.pop('rgb').float() - sample['mask'] = sample.pop('seg').long() + sample['mask'] = sample.pop('seg').long().squeeze(0) if self.transforms: sample = self.transforms(sample) From 151c343326e665a586a0d67e22c7605bfc523b5b Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 6 Nov 2024 19:37:29 +0400 Subject: [PATCH 20/21] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 878a9217e3f..f01e4933f4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "einops>=0.3", # fiona 1.8.21+ required for Python 3.10 wheels "fiona>=1.8.21", - # kornia 0.7.4+ required for using native AugmentationSequential, allowing us to replace our custom implementation + # kornia 0.7.4+ required for AugmentationSequential support for unknown keys "kornia>=0.7.4", # lightly 1.4.5+ required for LARS optimizer # lightly 1.4.26 is incompatible with the version of timm required by smp From 2beb9c655f777beb9c06402fea81f2b555ab3b63 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 6 Nov 2024 19:44:49 +0400 Subject: [PATCH 21/21] Update tutorials --- docs/tutorials/transforms.ipynb | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index a7de9f32c69..38a9b825bbf 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -93,7 +93,7 @@ "from torch.utils.data import DataLoader\n", "\n", "from torchgeo.datasets import EuroSAT100\n", - "from torchgeo.transforms import AugmentationSequential, indices" + "from torchgeo.transforms import indices" ] }, { @@ -515,7 +515,7 @@ }, "outputs": [], "source": [ - "transforms = AugmentationSequential(\n", + "transforms = K.AugmentationSequential(\n", " MinMaxNormalize(mins, maxs),\n", " indices.AppendNDBI(index_swir=11, index_nir=7),\n", " indices.AppendNDSI(index_green=3, index_swir=11),\n", @@ -523,7 +523,7 @@ " indices.AppendNDWI(index_green=2, index_nir=7),\n", " K.RandomHorizontalFlip(p=0.5),\n", " K.RandomVerticalFlip(p=0.5),\n", - " data_keys=['image'],\n", + " data_keys=None,\n", ")\n", "\n", "batch = next(dataloader)\n", @@ -569,7 +569,7 @@ "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", - "transforms = AugmentationSequential(\n", + "transforms = K.AugmentationSequential(\n", " MinMaxNormalize(mins, maxs),\n", " indices.AppendNDBI(index_swir=11, index_nir=7),\n", " indices.AppendNDSI(index_green=3, index_swir=11),\n", @@ -580,10 +580,10 @@ " K.RandomAffine(degrees=(0, 90), p=0.25),\n", " K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n", " K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n", - " data_keys=['image'],\n", + " data_keys=None,\n", ")\n", "\n", - "transforms_gpu = AugmentationSequential(\n", + "transforms_gpu = K.AugmentationSequential(\n", " MinMaxNormalize(mins.to(device), maxs.to(device)),\n", " indices.AppendNDBI(index_swir=11, index_nir=7),\n", " indices.AppendNDSI(index_green=3, index_swir=11),\n", @@ -594,7 +594,7 @@ " K.RandomAffine(degrees=(0, 90), p=0.25),\n", " K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n", " K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n", - " data_keys=['image'],\n", + " data_keys=None,\n", ").to(device)\n", "\n", "\n", @@ -664,7 +664,7 @@ }, "outputs": [], "source": [ - "transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=['image'])\n", + "transforms = K.AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=None)\n", "dataset = EuroSAT100(root, transforms=transforms)" ] },