Skip to content

Commit

Permalink
datamodules: Switch to kornia AugmentationSequential (#2147)
Browse files Browse the repository at this point in the history
* Switch to kornia AugmentationSequential in datamodules

* Fix for classification & regression tests

* Fix for segmentation tests

* Revert inria batch_size

* Set keepdim=True

* Fix ssl4eo_benchmark

* Masks must not have a channel dimension

* Fix south africa dimensions

* Fix chesapeake dimensions

* Remove unused import

* mypy fixes & remove _Transform

* Fix quakeset

* mypy fix

* Fix chesapeake plot test

* mypy fix

* Fix SpaceNet aug

* Bump min version of kornia

* Remove conditional squeeze

* Fix for new datasets

* Update pyproject.toml

* Update tutorials

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
ashnair1 and adamjstewart authored Nov 6, 2024
1 parent 51c8d36 commit 69f0c70
Show file tree
Hide file tree
Showing 41 changed files with 164 additions and 191 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- python
- jupyter
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.13.0
hooks:
- id: mypy
args:
Expand Down
16 changes: 8 additions & 8 deletions docs/tutorials/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
"from torch.utils.data import DataLoader\n",
"\n",
"from torchgeo.datasets import EuroSAT100\n",
"from torchgeo.transforms import AugmentationSequential, indices"
"from torchgeo.transforms import indices"
]
},
{
Expand Down Expand Up @@ -515,15 +515,15 @@
},
"outputs": [],
"source": [
"transforms = AugmentationSequential(\n",
"transforms = K.AugmentationSequential(\n",
" MinMaxNormalize(mins, maxs),\n",
" indices.AppendNDBI(index_swir=11, index_nir=7),\n",
" indices.AppendNDSI(index_green=3, index_swir=11),\n",
" indices.AppendNDVI(index_nir=7, index_red=3),\n",
" indices.AppendNDWI(index_green=2, index_nir=7),\n",
" K.RandomHorizontalFlip(p=0.5),\n",
" K.RandomVerticalFlip(p=0.5),\n",
" data_keys=['image'],\n",
" data_keys=None,\n",
")\n",
"\n",
"batch = next(dataloader)\n",
Expand Down Expand Up @@ -569,7 +569,7 @@
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"transforms = AugmentationSequential(\n",
"transforms = K.AugmentationSequential(\n",
" MinMaxNormalize(mins, maxs),\n",
" indices.AppendNDBI(index_swir=11, index_nir=7),\n",
" indices.AppendNDSI(index_green=3, index_swir=11),\n",
Expand All @@ -580,10 +580,10 @@
" K.RandomAffine(degrees=(0, 90), p=0.25),\n",
" K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n",
" K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n",
" data_keys=['image'],\n",
" data_keys=None,\n",
")\n",
"\n",
"transforms_gpu = AugmentationSequential(\n",
"transforms_gpu = K.AugmentationSequential(\n",
" MinMaxNormalize(mins.to(device), maxs.to(device)),\n",
" indices.AppendNDBI(index_swir=11, index_nir=7),\n",
" indices.AppendNDSI(index_green=3, index_swir=11),\n",
Expand All @@ -594,7 +594,7 @@
" K.RandomAffine(degrees=(0, 90), p=0.25),\n",
" K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n",
" K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n",
" data_keys=['image'],\n",
" data_keys=None,\n",
").to(device)\n",
"\n",
"\n",
Expand Down Expand Up @@ -664,7 +664,7 @@
},
"outputs": [],
"source": [
"transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=['image'])\n",
"transforms = K.AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=None)\n",
"dataset = EuroSAT100(root, transforms=transforms)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ dependencies = [
"einops>=0.3",
# fiona 1.8.21+ required for Python 3.10 wheels
"fiona>=1.8.21",
# kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
"kornia>=0.7.3",
# kornia 0.7.4+ required for AugmentationSequential support for unknown keys
"kornia>=0.7.4",
# lightly 1.4.5+ required for LARS optimizer
# lightly 1.4.26 is incompatible with the version of timm required by smp
# https://github.com/microsoft/torchgeo/issues/1824
Expand Down
2 changes: 1 addition & 1 deletion requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ setuptools==61.0.0
# install
einops==0.3.0
fiona==1.8.21
kornia==0.7.3
kornia==0.7.4
lightly==1.4.5
lightning[pytorch-extra]==2.0.0
matplotlib==3.5.0
Expand Down
4 changes: 2 additions & 2 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
self.res = 1

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}

def plot(self, *args: Any, **kwargs: Any) -> Figure:
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
self.length = length

def __getitem__(self, index: int) -> dict[str, Tensor]:
return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)}
return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}

def __len__(self) -> int:
return self.length
Expand Down
5 changes: 4 additions & 1 deletion tests/datasets/test_chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def test_plot(self, dataset: ChesapeakeCVPR) -> None:
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x['prediction'] = x['mask'][:, :, 0].clone().unsqueeze(2)
if x['mask'].ndim == 2:
x['prediction'] = x['mask'].clone()
else:
x['prediction'] = x['mask'][0, :, :].clone()
dataset.plot(x)
plt.close()
2 changes: 0 additions & 2 deletions tests/transforms/test_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) ->
aug = K.AugmentationSequential(
RandomGrayscale(weights, p=1), keepdim=True, data_keys=None
)
# https://github.com/kornia/kornia/issues/2848
aug.keepdim = True
output = aug(sample)
assert output['image'].shape == sample['image'].shape
for i in range(1, 3):
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from ..datasets import AgriFieldNet, random_bbox_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


Expand Down Expand Up @@ -49,12 +48,13 @@ def __init__(
**kwargs,
)

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
Expand Down
11 changes: 6 additions & 5 deletions torchgeo/datamodules/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch

from ..datasets import CaFFe
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand Down Expand Up @@ -40,16 +39,18 @@ def __init__(

self.size = size

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.Resize(size),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)

self.aug = AugmentationSequential(
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.Resize(size),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)
43 changes: 5 additions & 38 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,14 @@
from typing import Any

import kornia.augmentation as K
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from ..datasets import ChesapeakeCVPR
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class _Transform(nn.Module):
"""Version of AugmentationSequential designed for samples, not batches."""

def __init__(self, aug: nn.Module) -> None:
"""Initialize a new _Transform instance.
Args:
aug: Augmentation to apply.
"""
super().__init__()
self.aug = aug

def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
"""Apply the augmentation.
Args:
sample: Input sample.
Returns:
Augmented sample.
"""
for key in ['image', 'mask']:
dtype = sample[key].dtype
# All inputs must be float
sample[key] = sample[key].float()
sample[key] = self.aug(sample[key])
sample[key] = sample[key].to(dtype)
# Kornia adds batch dimension
sample[key] = rearrange(sample[key], '() c h w -> c h w')
return sample


class ChesapeakeCVPRDataModule(GeoDataModule):
"""LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.
Expand Down Expand Up @@ -94,7 +59,9 @@ def __init__(
# This is a rough estimate of how large of a patch we will need to sample in
# EPSG:3857 in order to guarantee a large enough patch in the local CRS.
self.original_patch_size = patch_size * 3
kwargs['transforms'] = _Transform(K.CenterCrop(patch_size))
kwargs['transforms'] = K.AugmentationSequential(
K.CenterCrop(patch_size), data_keys=None, keepdim=True
)

super().__init__(
ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs
Expand Down Expand Up @@ -122,8 +89,8 @@ def __init__(
else:
self.layers = ['naip-new', 'lc']

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

def setup(self, stage: str) -> None:
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/deepglobelandcover.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from ..datasets import DeepGlobeLandCover
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _RandomNCrop
from .geo import NonGeoDataModule

Expand Down Expand Up @@ -46,10 +45,11 @@ def __init__(
self.patch_size = _to_tuple(patch_size)
self.val_split_pct = val_split_pct

self.aug = AugmentationSequential(
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)

def setup(self, stage: str) -> None:
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/fire_risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import kornia.augmentation as K

from ..datasets import FireRisk
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand All @@ -30,15 +29,16 @@ def __init__(
:class:`~torchgeo.datasets.FireRisk`.
"""
super().__init__(FireRisk, batch_size, num_workers, **kwargs)
self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomRotation(p=0.5, degrees=90),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
K.RandomErasing(p=0.1),
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
data_keys=['image'],
data_keys=None,
keepdim=True,
)

def setup(self, stage: str) -> None:
Expand Down
10 changes: 5 additions & 5 deletions torchgeo/datamodules/ftw.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch

from ..datasets import FieldsOfTheWorld
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand Down Expand Up @@ -55,16 +54,17 @@ def __init__(
self.val_countries = val_countries
self.test_countries = test_countries

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomRotation(p=0.5, degrees=90),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)
self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

def setup(self, stage: str) -> None:
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
GridGeoSampler,
RandomBatchGeoSampler,
)
from ..transforms import AugmentationSequential
from .utils import MisconfigurationException


Expand Down Expand Up @@ -70,9 +69,10 @@ def __init__(

# Data augmentation
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
self.aug: Transform = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image']
self.aug: Transform = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

self.train_aug: Transform | None = None
self.val_aug: Transform | None = None
self.test_aug: Transform | None = None
Expand Down
10 changes: 6 additions & 4 deletions torchgeo/datamodules/geonrw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from torch.utils.data import Subset

from ..datasets import GeoNRW
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import group_shuffle_split

Expand Down Expand Up @@ -38,14 +37,17 @@ def __init__(
"""
super().__init__(GeoNRW, batch_size, num_workers, **kwargs)

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Resize(size),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)

self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask'])
self.aug = K.AugmentationSequential(
K.Resize(size), data_keys=None, keepdim=True
)

self.size = size

Expand Down
Loading

0 comments on commit 69f0c70

Please sign in to comment.