diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 05dba0da7c1..61e7d2a538d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,7 +13,7 @@ repos:
           - python
           - jupyter
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.9.0
+    rev: v1.13.0
     hooks:
       - id: mypy
         args:
diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb
index a7de9f32c69..38a9b825bbf 100644
--- a/docs/tutorials/transforms.ipynb
+++ b/docs/tutorials/transforms.ipynb
@@ -93,7 +93,7 @@
     "from torch.utils.data import DataLoader\n",
     "\n",
     "from torchgeo.datasets import EuroSAT100\n",
-    "from torchgeo.transforms import AugmentationSequential, indices"
+    "from torchgeo.transforms import indices"
    ]
   },
   {
@@ -515,7 +515,7 @@
   },
   "outputs": [],
   "source": [
-    "transforms = AugmentationSequential(\n",
+    "transforms = K.AugmentationSequential(\n",
     "    MinMaxNormalize(mins, maxs),\n",
     "    indices.AppendNDBI(index_swir=11, index_nir=7),\n",
     "    indices.AppendNDSI(index_green=3, index_swir=11),\n",
@@ -523,7 +523,7 @@
     "    indices.AppendNDWI(index_green=2, index_nir=7),\n",
     "    K.RandomHorizontalFlip(p=0.5),\n",
     "    K.RandomVerticalFlip(p=0.5),\n",
-    "    data_keys=['image'],\n",
+    "    data_keys=None,\n",
     ")\n",
     "\n",
     "batch = next(dataloader)\n",
@@ -569,7 +569,7 @@
   "source": [
     "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
     "\n",
-    "transforms = AugmentationSequential(\n",
+    "transforms = K.AugmentationSequential(\n",
     "    MinMaxNormalize(mins, maxs),\n",
     "    indices.AppendNDBI(index_swir=11, index_nir=7),\n",
     "    indices.AppendNDSI(index_green=3, index_swir=11),\n",
@@ -580,10 +580,10 @@
     "    K.RandomAffine(degrees=(0, 90), p=0.25),\n",
     "    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n",
     "    K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n",
-    "    data_keys=['image'],\n",
+    "    data_keys=None,\n",
     ")\n",
     "\n",
-    "transforms_gpu = AugmentationSequential(\n",
+    "transforms_gpu = K.AugmentationSequential(\n",
     "    MinMaxNormalize(mins.to(device), maxs.to(device)),\n",
     "    indices.AppendNDBI(index_swir=11, index_nir=7),\n",
     "    indices.AppendNDSI(index_green=3, index_swir=11),\n",
@@ -594,7 +594,7 @@
     "    K.RandomAffine(degrees=(0, 90), p=0.25),\n",
     "    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n",
     "    K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n",
-    "    data_keys=['image'],\n",
+    "    data_keys=None,\n",
     ").to(device)\n",
     "\n",
     "\n",
@@ -664,7 +664,7 @@
   },
   "outputs": [],
   "source": [
-    "transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=['image'])\n",
+    "transforms = K.AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=None)\n",
     "dataset = EuroSAT100(root, transforms=transforms)"
    ]
   },
diff --git a/pyproject.toml b/pyproject.toml
index b850f8b3466..f01e4933f4a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,8 +40,8 @@ dependencies = [
     "einops>=0.3",
     # fiona 1.8.21+ required for Python 3.10 wheels
     "fiona>=1.8.21",
-    # kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
-    "kornia>=0.7.3",
+    # kornia 0.7.4+ required for AugmentationSequential support for unknown keys
+    "kornia>=0.7.4",
     # lightly 1.4.5+ required for LARS optimizer
     # lightly 1.4.26 is incompatible with the version of timm required by smp
     # https://github.com/microsoft/torchgeo/issues/1824
diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old
index 52d8dcce018..c5a5f6dae58 100644
--- a/requirements/min-reqs.old
+++ b/requirements/min-reqs.old
@@ -4,7 +4,7 @@ setuptools==61.0.0
 # install
 einops==0.3.0
 fiona==1.8.21
-kornia==0.7.3
+kornia==0.7.4
 lightly==1.4.5
 lightning[pytorch-extra]==2.0.0
 matplotlib==3.5.0
diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py
index 8e5fd13d292..4e5431c684f 100644
--- a/tests/datamodules/test_geo.py
+++ b/tests/datamodules/test_geo.py
@@ -31,7 +31,7 @@ def __init__(
         self.res = 1

     def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
-        image = torch.arange(3 * 2 * 2).view(3, 2, 2)
+        image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
         return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}

     def plot(self, *args: Any, **kwargs: Any) -> Figure:
@@ -68,7 +68,7 @@ def __init__(
         self.length = length

     def __getitem__(self, index: int) -> dict[str, Tensor]:
-        return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)}
+        return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}

     def __len__(self) -> int:
         return self.length
diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py
index d4540651b87..a6bcdd81b6f 100644
--- a/tests/datasets/test_chesapeake.py
+++ b/tests/datasets/test_chesapeake.py
@@ -212,6 +212,9 @@ def test_plot(self, dataset: ChesapeakeCVPR) -> None:
         plt.close()
         dataset.plot(x, show_titles=False)
         plt.close()
-        x['prediction'] = x['mask'][:, :, 0].clone().unsqueeze(2)
+        if x['mask'].ndim == 2:
+            x['prediction'] = x['mask'].clone()
+        else:
+            x['prediction'] = x['mask'][0, :, :].clone()
         dataset.plot(x)
         plt.close()
diff --git a/tests/transforms/test_color.py b/tests/transforms/test_color.py
index 2cea90b396d..b235f7195f2 100644
--- a/tests/transforms/test_color.py
+++ b/tests/transforms/test_color.py
@@ -37,8 +37,6 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) ->
     aug = K.AugmentationSequential(
         RandomGrayscale(weights, p=1), keepdim=True, data_keys=None
     )
-    # https://github.com/kornia/kornia/issues/2848
-    aug.keepdim = True
     output = aug(sample)
     assert output['image'].shape == sample['image'].shape
     for i in range(1, 3):
diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py
index c5b92b6b01a..cbb8af25356 100644
--- a/torchgeo/datamodules/agrifieldnet.py
+++ b/torchgeo/datamodules/agrifieldnet.py
@@ -12,7 +12,6 @@
 from ..datasets import AgriFieldNet, random_bbox_assignment
 from ..samplers import GridGeoSampler, RandomBatchGeoSampler
 from ..samplers.utils import _to_tuple
-from ..transforms import AugmentationSequential
 from .geo import GeoDataModule


@@ -49,12 +48,13 @@ def __init__(
             **kwargs,
         )

-        self.train_aug = AugmentationSequential(
+        self.train_aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
             K.RandomVerticalFlip(p=0.5),
             K.RandomHorizontalFlip(p=0.5),
-            data_keys=['image', 'mask'],
+            data_keys=None,
+            keepdim=True,
             extra_args={
                 DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
             },
diff --git a/torchgeo/datamodules/caffe.py b/torchgeo/datamodules/caffe.py
index 241a34197fc..a58136df30b 100644
--- a/torchgeo/datamodules/caffe.py
+++ b/torchgeo/datamodules/caffe.py
@@ -9,7 +9,6 @@
 import torch

 from ..datasets import CaFFe
-from ..transforms import AugmentationSequential
 from .geo import NonGeoDataModule


@@ -40,16 +39,18 @@ def __init__(

         self.size = size

-        self.train_aug = AugmentationSequential(
+        self.train_aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             K.Resize(size),
             K.RandomHorizontalFlip(p=0.5),
             K.RandomVerticalFlip(p=0.5),
-            data_keys=['image', 'mask'],
+            data_keys=None,
+            keepdim=True,
         )
-        self.aug = AugmentationSequential(
+        self.aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             K.Resize(size),
-            data_keys=['image', 'mask'],
+            data_keys=None,
+            keepdim=True,
         )
diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py
index 37a0d32edbd..41e944e1af5 100644
--- a/torchgeo/datamodules/chesapeake.py
+++ b/torchgeo/datamodules/chesapeake.py
@@ -6,49 +6,14 @@
 from typing import Any

 import kornia.augmentation as K
-import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 from torch import Tensor

 from ..datasets import ChesapeakeCVPR
 from ..samplers import GridGeoSampler, RandomBatchGeoSampler
-from ..transforms import AugmentationSequential
 from .geo import GeoDataModule


-class _Transform(nn.Module):
-    """Version of AugmentationSequential designed for samples, not batches."""
-
-    def __init__(self, aug: nn.Module) -> None:
-        """Initialize a new _Transform instance.
-
-        Args:
-            aug: Augmentation to apply.
-        """
-        super().__init__()
-        self.aug = aug
-
-    def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
-        """Apply the augmentation.
-
-        Args:
-            sample: Input sample.
-
-        Returns:
-            Augmented sample.
-        """
-        for key in ['image', 'mask']:
-            dtype = sample[key].dtype
-            # All inputs must be float
-            sample[key] = sample[key].float()
-            sample[key] = self.aug(sample[key])
-            sample[key] = sample[key].to(dtype)
-            # Kornia adds batch dimension
-            sample[key] = rearrange(sample[key], '() c h w -> c h w')
-        return sample
-
-
 class ChesapeakeCVPRDataModule(GeoDataModule):
     """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.
@@ -94,7 +59,9 @@ def __init__(
         # This is a rough estimate of how large of a patch we will need to sample in
         # EPSG:3857 in order to guarantee a large enough patch in the local CRS.
         self.original_patch_size = patch_size * 3
-        kwargs['transforms'] = _Transform(K.CenterCrop(patch_size))
+        kwargs['transforms'] = K.AugmentationSequential(
+            K.CenterCrop(patch_size), data_keys=None, keepdim=True
+        )

         super().__init__(
             ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs
@@ -122,8 +89,8 @@ def __init__(
         else:
             self.layers = ['naip-new', 'lc']

-        self.aug = AugmentationSequential(
-            K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
+        self.aug = K.AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )

     def setup(self, stage: str) -> None:
diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py
index 80ea6052cb7..b3ab2d687b5 100644
--- a/torchgeo/datamodules/deepglobelandcover.py
+++ b/torchgeo/datamodules/deepglobelandcover.py
@@ -11,7 +11,6 @@

 from ..datasets import DeepGlobeLandCover
 from ..samplers.utils import _to_tuple
-from ..transforms import AugmentationSequential
 from ..transforms.transforms import _RandomNCrop
 from .geo import NonGeoDataModule

@@ -46,10 +45,11 @@ def __init__(
         self.patch_size = _to_tuple(patch_size)
         self.val_split_pct = val_split_pct

-        self.aug = AugmentationSequential(
+        self.aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             _RandomNCrop(self.patch_size, batch_size),
-            data_keys=['image', 'mask'],
+            data_keys=None,
+            keepdim=True,
         )

     def setup(self, stage: str) -> None:
diff --git a/torchgeo/datamodules/fire_risk.py b/torchgeo/datamodules/fire_risk.py
index 1a0d6c7c047..d317981cff3 100644
--- a/torchgeo/datamodules/fire_risk.py
+++ b/torchgeo/datamodules/fire_risk.py
@@ -8,7 +8,6 @@
 import kornia.augmentation as K

 from ..datasets import FireRisk
-from ..transforms import AugmentationSequential
 from .geo import NonGeoDataModule


@@ -30,7 +29,7 @@ def __init__(
                 :class:`~torchgeo.datasets.FireRisk`.
""" super().__init__(FireRisk, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -38,7 +37,8 @@ def __init__( K.RandomSharpness(p=0.5), K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/ftw.py b/torchgeo/datamodules/ftw.py index 8f27e719699..a197a789c48 100644 --- a/torchgeo/datamodules/ftw.py +++ b/torchgeo/datamodules/ftw.py @@ -9,7 +9,6 @@ import torch from ..datasets import FieldsOfTheWorld -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -55,16 +54,17 @@ def __init__( self.val_countries = val_countries self.test_countries = test_countries - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index e8e3aedd194..8721ea6e7f6 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -20,7 +20,6 @@ GridGeoSampler, RandomBatchGeoSampler, ) -from ..transforms import AugmentationSequential from .utils import MisconfigurationException @@ -70,9 +69,10 @@ def __init__( # Data augmentation Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] - self.aug: Transform = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image'] + self.aug: Transform = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + self.train_aug: Transform | None = None self.val_aug: Transform | None = None self.test_aug: Transform | None = None diff --git a/torchgeo/datamodules/geonrw.py b/torchgeo/datamodules/geonrw.py index c67753673ac..5283b0ed7ac 100644 --- a/torchgeo/datamodules/geonrw.py +++ b/torchgeo/datamodules/geonrw.py @@ -10,7 +10,6 @@ from torch.utils.data import Subset from ..datasets import GeoNRW -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule from .utils import group_shuffle_split @@ -38,14 +37,17 @@ def __init__( """ super().__init__(GeoNRW, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Resize(size), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask']) + self.aug = K.AugmentationSequential( + K.Resize(size), data_keys=None, keepdim=True + ) self.size = size diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index fc4e802c148..d33c55ec829 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -11,7 +11,6 @@ from ..datasets import GID15 from ..samplers.utils import _to_tuple -from ..transforms import 
AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -48,15 +47,17 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = self.val_aug = AugmentationSequential( + self.train_aug = self.val_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.predict_aug = AugmentationSequential( + self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 39e8ede22c5..797f5484b6a 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -9,7 +9,6 @@ from ..datasets import InriaAerialImageLabeling from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -44,22 +43,25 @@ def __init__( self.patch_size = _to_tuple(patch_size) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.predict_aug = AugmentationSequential( + self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index 35408feddbb..3a70446f90f 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -12,7 +12,6 @@ from ..datasets import L7Irish, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,12 +48,13 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index ddc802a5ce3..cf0415b34c9 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -12,7 +12,6 @@ from ..datasets import L8Biome, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,12 +48,13 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( 
             K.Normalize(mean=self.mean, std=self.std),
             K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
             K.RandomVerticalFlip(p=0.5),
             K.RandomHorizontalFlip(p=0.5),
-            data_keys=['image', 'mask'],
+            data_keys=None,
+            keepdim=True,
             extra_args={
                 DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
             },
diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py
index 4cbbf1a5e2a..9ed2a4d34d6 100644
--- a/torchgeo/datamodules/landcoverai.py
+++ b/torchgeo/datamodules/landcoverai.py
@@ -8,7 +8,6 @@
 import kornia.augmentation as K

 from ..datasets import LandCoverAI, LandCoverAI100
-from ..transforms import AugmentationSequential
 from .geo import NonGeoDataModule


@@ -31,17 +30,18 @@ def __init__(
         """
         super().__init__(LandCoverAI, batch_size, num_workers, **kwargs)

-        self.train_aug = AugmentationSequential(
+        self.train_aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             K.RandomRotation(p=0.5, degrees=90),
             K.RandomHorizontalFlip(p=0.5),
             K.RandomVerticalFlip(p=0.5),
             K.RandomSharpness(p=0.5),
             K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
-            data_keys=['image', 'mask'],
+            data_keys=None,
+            keepdim=True,
         )
-        self.aug = AugmentationSequential(
-            K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
+        self.aug = K.AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )


@@ -66,6 +66,6 @@ def __init__(
         """
         super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs)

-        self.aug = AugmentationSequential(
-            K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
+        self.aug = K.AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )
diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py
index 8488c2e58b3..0e3a124dc94 100644
--- a/torchgeo/datamodules/levircd.py
+++ b/torchgeo/datamodules/levircd.py
@@ -11,7 +11,6 @@

 from ..datasets import LEVIRCD, LEVIRCDPlus
 from ..samplers.utils import _to_tuple
-from ..transforms import AugmentationSequential
 from ..transforms.transforms import _RandomNCrop
 from .geo import NonGeoDataModule

@@ -43,18 +42,17 @@ def __init__(

         self.patch_size = _to_tuple(patch_size)

-        self.train_aug = AugmentationSequential(
+        self.train_aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             _RandomNCrop(self.patch_size, batch_size),
-            data_keys=['image1', 'image2', 'mask'],
+            data_keys=None,
+            keepdim=True,
         )
-        self.val_aug = AugmentationSequential(
-            K.Normalize(mean=self.mean, std=self.std),
-            data_keys=['image1', 'image2', 'mask'],
+        self.val_aug = K.AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )
-        self.test_aug = AugmentationSequential(
-            K.Normalize(mean=self.mean, std=self.std),
-            data_keys=['image1', 'image2', 'mask'],
+        self.test_aug = K.AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )


@@ -91,18 +89,17 @@ def __init__(
         self.patch_size = _to_tuple(patch_size)
         self.val_split_pct = val_split_pct

-        self.train_aug = AugmentationSequential(
+        self.train_aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             _RandomNCrop(self.patch_size, batch_size),
-            data_keys=['image1', 'image2', 'mask'],
+            data_keys=None,
+            keepdim=True,
         )
-        self.val_aug = AugmentationSequential(
-            K.Normalize(mean=self.mean, std=self.std),
-            data_keys=['image1', 'image2', 'mask'],
+        self.val_aug = K.AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )
-        self.test_aug = AugmentationSequential(
-            K.Normalize(mean=self.mean, std=self.std),
-            data_keys=['image1', 'image2', 'mask'],
+        self.test_aug = K.AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )

     def setup(self, stage: str) -> None:
diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py
index 1be0a64cff5..0520d0264ad 100644
--- a/torchgeo/datamodules/naip.py
+++ b/torchgeo/datamodules/naip.py
@@ -20,7 +20,6 @@
     ChesapeakeWV,
 )
 from ..samplers import GridGeoSampler, RandomBatchGeoSampler
-from ..transforms import AugmentationSequential
 from .geo import GeoDataModule


@@ -62,8 +61,8 @@ def __init__(
             NAIP, batch_size, patch_size, length, num_workers, **self.naip_kwargs
         )

-        self.aug = AugmentationSequential(
-            K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
+        self.aug = K.AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )

     def setup(self, stage: str) -> None:
diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py
index 87ae20cdf40..8db1dd7061a 100644
--- a/torchgeo/datamodules/oscd.py
+++ b/torchgeo/datamodules/oscd.py
@@ -11,7 +11,6 @@

 from ..datasets import OSCD
 from ..samplers.utils import _to_tuple
-from ..transforms import AugmentationSequential
 from ..transforms.transforms import _RandomNCrop
 from .geo import NonGeoDataModule

@@ -85,10 +84,11 @@ def __init__(
         self.mean = torch.tensor([MEAN[b] for b in self.bands])
         self.std = torch.tensor([STD[b] for b in self.bands])

-        self.aug = AugmentationSequential(
+        self.aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             _RandomNCrop(self.patch_size, batch_size),
-            data_keys=['image1', 'image2', 'mask'],
+            data_keys=None,
+            keepdim=True,
         )

     def setup(self, stage: str) -> None:
diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py
index 8011382c769..7a5495a4458 100644
--- a/torchgeo/datamodules/potsdam.py
+++ b/torchgeo/datamodules/potsdam.py
@@ -11,7 +11,6 @@

 from ..datasets import Potsdam2D
 from ..samplers.utils import _to_tuple
-from ..transforms import AugmentationSequential
 from ..transforms.transforms import _RandomNCrop
 from .geo import NonGeoDataModule

@@ -48,10 +47,11 @@ def __init__(
         self.patch_size = _to_tuple(patch_size)
         self.val_split_pct = val_split_pct

-        self.aug = AugmentationSequential(
+        self.aug = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std),
             _RandomNCrop(self.patch_size, batch_size),
-            data_keys=['image', 'mask'],
+            data_keys=None,
+            keepdim=True,
         )

     def setup(self, stage: str) -> None:
diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py
index 1a3e19a5122..03c677138e9 100644
--- a/torchgeo/datamodules/quakeset.py
+++ b/torchgeo/datamodules/quakeset.py
@@ -9,7 +9,6 @@
 import torch

 from ..datasets import QuakeSet
-from ..transforms import AugmentationSequential
 from .geo import NonGeoDataModule


@@ -34,9 +33,10 @@ def __init__(
                 :class:`~torchgeo.datasets.QuakeSet`.
""" super().__init__(QuakeSet, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=['image'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index e88e139f481..e279478f8d0 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -9,7 +9,6 @@ import torch from ..datasets import RESISC45 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -36,7 +35,7 @@ def __init__( """ super().__init__(RESISC45, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -44,5 +43,6 @@ def __init__( K.RandomSharpness(p=0.5), K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index 1160f037366..ecfeb04b288 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -10,7 +10,6 @@ from einops import repeat from ..datasets import SeasonalContrastS2 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -49,11 +48,12 @@ def __init__( _mean = repeat(_mean, 'c -> (t c)', t=seasons) _std = repeat(_std, 'c -> (t c)', t=seasons) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), K.Normalize(mean=_mean, std=_std), - data_keys=['image'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index 97c3d05392e..91af34b0ef1 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -13,7 +13,6 @@ from ..datasets import CDL, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -63,19 +62,20 @@ def __init__( CDL, batch_size, patch_size, length, num_workers, **self.cdl_kwargs ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_eurocrops.py b/torchgeo/datamodules/sentinel2_eurocrops.py index 4e0893e4f8d..8f34c2598ef 100644 --- a/torchgeo/datamodules/sentinel2_eurocrops.py +++ b/torchgeo/datamodules/sentinel2_eurocrops.py @@ -13,7 +13,6 @@ from ..datasets import EuroCrops, Sentinel2, random_grid_cell_assignment from 
..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -64,19 +63,20 @@ def __init__( **self.eurocrops_kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index 91b4f936fdc..34fde0f3153 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -13,7 +13,6 @@ from ..datasets import NCCM, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -63,19 +62,20 @@ def __init__( NCCM, batch_size, patch_size, length, num_workers, **self.nccm_kwargs ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index e3363e857f5..d3deff9e823 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -14,7 +14,6 @@ from ..datasets import Sentinel2, SouthAmericaSoybean, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -62,19 +61,20 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/southafricacroptype.py 
b/torchgeo/datamodules/southafricacroptype.py index 3f44bb61471..37fdef5e7db 100644 --- a/torchgeo/datamodules/southafricacroptype.py +++ b/torchgeo/datamodules/southafricacroptype.py @@ -12,7 +12,6 @@ from ..datasets import SouthAfricaCropType, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,19 +48,20 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index e6b0b325823..7353efbbaec 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -11,7 +11,6 @@ from torch.utils.data import random_split from ..datasets import SpaceNet, SpaceNet1, SpaceNet6 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -124,7 +123,7 @@ def __init__( SpaceNet1, batch_size, num_workers, val_split_pct, test_split_pct, **kwargs ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), K.RandomRotation(p=0.5, degrees=90), @@ -132,18 +131,20 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.predict_aug = AugmentationSequential( + self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), - data_keys=['image'], + data_keys=None, ) diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index c9eb1d2e315..02e5de917dd 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -10,7 +10,6 @@ from ..datasets import SSL4EOLBenchmark from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -40,23 +39,26 @@ def __init__( self.patch_size = _to_tuple(patch_size) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.val_aug = AugmentationSequential( + self.val_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), - 
data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.test_aug = AugmentationSequential( + self.test_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 59bb49444ee..6bb3e70eab2 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -8,7 +8,6 @@ import kornia.augmentation as K from ..datasets import UCMerced -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -31,8 +30,9 @@ def __init__( """ super().__init__(UCMerced, batch_size, num_workers, **kwargs) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(size=256), - data_keys=['image'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 98bc4945e95..4fead8c85c8 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -11,7 +11,6 @@ from ..datasets import Vaihingen2D from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -48,10 +47,11 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 7e521092aaf..af9b01f54dc 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -581,7 +581,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample['mask'] = np.concatenate(sample['mask'], axis=0) sample['image'] = torch.from_numpy(sample['image']).float() - sample['mask'] = torch.from_numpy(sample['mask']).long() + sample['mask'] = torch.from_numpy(sample['mask']).long().squeeze(0) if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index d86bc58e6ed..26a035d427d 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -554,7 +554,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if self.is_image: sample['image'] = data else: - sample['mask'] = data + sample['mask'] = data.squeeze(0) if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/geonrw.py b/torchgeo/datasets/geonrw.py index dfdadf9f815..50e05bad0fb 100644 --- a/torchgeo/datasets/geonrw.py +++ b/torchgeo/datasets/geonrw.py @@ -242,7 +242,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: # rename to torchgeo standard keys sample['image'] = sample.pop('rgb').float() - sample['mask'] = sample.pop('seg').long() + sample['mask'] = sample.pop('seg').long().squeeze(0) if self.transforms: sample = self.transforms(sample) diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index 3c4f7f895ec..a8643873c5b 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -223,7 +223,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: 
) mask_filepaths.append(file_path) - mask = self._merge_files(mask_filepaths, query) + mask = self._merge_files(mask_filepaths, query).squeeze(0) sample = { 'crs': self.crs, diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 7bc4f828974..13c5a8474c4 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -324,7 +324,7 @@ def _load_mask(self, path: Path) -> Tensor: mask """ with rasterio.open(path) as src: - mask = torch.from_numpy(src.read()).long() + mask = torch.from_numpy(src.read(1)).long() mask = self.ordinal_map[mask] return mask
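The converted datamodules above all rely on the same Kornia call pattern: K.AugmentationSequential(..., data_keys=None, keepdim=True) applied directly to a sample or batch dict. A minimal sketch of that pattern, assuming kornia>=0.7.4 is installed; the tensor shape and flip probability below are illustrative only, not taken from the patch:

import kornia.augmentation as K
import torch

# data_keys=None lets Kornia infer the role of each entry from the dict keys;
# keepdim=True keeps an unbatched (C, H, W) sample unbatched in the output.
aug = K.AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    data_keys=None,
    keepdim=True,
)
sample = {'image': torch.rand(3, 64, 64)}
out = aug(sample)
assert out['image'].shape == (3, 64, 64)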