Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

datamodules: Switch to kornia AugmentationSequential #2147

Merged
merged 21 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- python
- jupyter
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.13.0
hooks:
- id: mypy
args:
Expand Down
16 changes: 8 additions & 8 deletions docs/tutorials/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
"from torch.utils.data import DataLoader\n",
"\n",
"from torchgeo.datasets import EuroSAT100\n",
"from torchgeo.transforms import AugmentationSequential, indices"
"from torchgeo.transforms import indices"
]
},
{
Expand Down Expand Up @@ -515,15 +515,15 @@
},
"outputs": [],
"source": [
"transforms = AugmentationSequential(\n",
"transforms = K.AugmentationSequential(\n",
" MinMaxNormalize(mins, maxs),\n",
" indices.AppendNDBI(index_swir=11, index_nir=7),\n",
" indices.AppendNDSI(index_green=3, index_swir=11),\n",
" indices.AppendNDVI(index_nir=7, index_red=3),\n",
" indices.AppendNDWI(index_green=2, index_nir=7),\n",
" K.RandomHorizontalFlip(p=0.5),\n",
" K.RandomVerticalFlip(p=0.5),\n",
" data_keys=['image'],\n",
" data_keys=None,\n",
")\n",
"\n",
"batch = next(dataloader)\n",
Expand Down Expand Up @@ -569,7 +569,7 @@
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"transforms = AugmentationSequential(\n",
"transforms = K.AugmentationSequential(\n",
" MinMaxNormalize(mins, maxs),\n",
" indices.AppendNDBI(index_swir=11, index_nir=7),\n",
" indices.AppendNDSI(index_green=3, index_swir=11),\n",
Expand All @@ -580,10 +580,10 @@
" K.RandomAffine(degrees=(0, 90), p=0.25),\n",
" K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n",
" K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n",
" data_keys=['image'],\n",
" data_keys=None,\n",
")\n",
"\n",
"transforms_gpu = AugmentationSequential(\n",
"transforms_gpu = K.AugmentationSequential(\n",
" MinMaxNormalize(mins.to(device), maxs.to(device)),\n",
" indices.AppendNDBI(index_swir=11, index_nir=7),\n",
" indices.AppendNDSI(index_green=3, index_swir=11),\n",
Expand All @@ -594,7 +594,7 @@
" K.RandomAffine(degrees=(0, 90), p=0.25),\n",
" K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n",
" K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n",
" data_keys=['image'],\n",
" data_keys=None,\n",
").to(device)\n",
"\n",
"\n",
Expand Down Expand Up @@ -664,7 +664,7 @@
},
"outputs": [],
"source": [
"transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=['image'])\n",
"transforms = K.AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=None)\n",
"dataset = EuroSAT100(root, transforms=transforms)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ dependencies = [
"einops>=0.3",
# fiona 1.8.21+ required for Python 3.10 wheels
"fiona>=1.8.21",
# kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
"kornia>=0.7.3",
# kornia 0.7.4+ required for AugmentationSequential support for unknown keys
"kornia>=0.7.4",
# lightly 1.4.5+ required for LARS optimizer
# lightly 1.4.26 is incompatible with the version of timm required by smp
# https://github.com/microsoft/torchgeo/issues/1824
Expand Down
2 changes: 1 addition & 1 deletion requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ setuptools==61.0.0
# install
einops==0.3.0
fiona==1.8.21
kornia==0.7.3
kornia==0.7.4
lightly==1.4.5
lightning[pytorch-extra]==2.0.0
matplotlib==3.5.0
Expand Down
4 changes: 2 additions & 2 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
self.res = 1

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}

def plot(self, *args: Any, **kwargs: Any) -> Figure:
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
self.length = length

def __getitem__(self, index: int) -> dict[str, Tensor]:
return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)}
return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}

def __len__(self) -> int:
return self.length
Expand Down
5 changes: 4 additions & 1 deletion tests/datasets/test_chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def test_plot(self, dataset: ChesapeakeCVPR) -> None:
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x['prediction'] = x['mask'][:, :, 0].clone().unsqueeze(2)
if x['mask'].ndim == 2:
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
x['prediction'] = x['mask'].clone()
else:
x['prediction'] = x['mask'][0, :, :].clone()
dataset.plot(x)
plt.close()
2 changes: 0 additions & 2 deletions tests/transforms/test_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) ->
aug = K.AugmentationSequential(
RandomGrayscale(weights, p=1), keepdim=True, data_keys=None
)
# https://github.com/kornia/kornia/issues/2848
aug.keepdim = True
output = aug(sample)
assert output['image'].shape == sample['image'].shape
for i in range(1, 3):
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from ..datasets import AgriFieldNet, random_bbox_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


Expand Down Expand Up @@ -49,12 +48,13 @@ def __init__(
**kwargs,
)

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
Expand Down
11 changes: 6 additions & 5 deletions torchgeo/datamodules/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch

from ..datasets import CaFFe
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand Down Expand Up @@ -40,16 +39,18 @@ def __init__(

self.size = size

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.Resize(size),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)

self.aug = AugmentationSequential(
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.Resize(size),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)
43 changes: 5 additions & 38 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,14 @@
from typing import Any

import kornia.augmentation as K
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from ..datasets import ChesapeakeCVPR
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class _Transform(nn.Module):
"""Version of AugmentationSequential designed for samples, not batches."""

def __init__(self, aug: nn.Module) -> None:
"""Initialize a new _Transform instance.

Args:
aug: Augmentation to apply.
"""
super().__init__()
self.aug = aug

def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
"""Apply the augmentation.

Args:
sample: Input sample.

Returns:
Augmented sample.
"""
for key in ['image', 'mask']:
dtype = sample[key].dtype
# All inputs must be float
sample[key] = sample[key].float()
sample[key] = self.aug(sample[key])
sample[key] = sample[key].to(dtype)
# Kornia adds batch dimension
sample[key] = rearrange(sample[key], '() c h w -> c h w')
return sample


class ChesapeakeCVPRDataModule(GeoDataModule):
"""LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.

Expand Down Expand Up @@ -94,7 +59,9 @@ def __init__(
# This is a rough estimate of how large of a patch we will need to sample in
# EPSG:3857 in order to guarantee a large enough patch in the local CRS.
self.original_patch_size = patch_size * 3
kwargs['transforms'] = _Transform(K.CenterCrop(patch_size))
kwargs['transforms'] = K.AugmentationSequential(
K.CenterCrop(patch_size), data_keys=None, keepdim=True
)

super().__init__(
ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs
Expand Down Expand Up @@ -122,8 +89,8 @@ def __init__(
else:
self.layers = ['naip-new', 'lc']

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

def setup(self, stage: str) -> None:
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/deepglobelandcover.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from ..datasets import DeepGlobeLandCover
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _RandomNCrop
from .geo import NonGeoDataModule

Expand Down Expand Up @@ -46,10 +45,11 @@ def __init__(
self.patch_size = _to_tuple(patch_size)
self.val_split_pct = val_split_pct

self.aug = AugmentationSequential(
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)

def setup(self, stage: str) -> None:
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/fire_risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import kornia.augmentation as K

from ..datasets import FireRisk
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand All @@ -30,15 +29,16 @@ def __init__(
:class:`~torchgeo.datasets.FireRisk`.
"""
super().__init__(FireRisk, batch_size, num_workers, **kwargs)
self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomRotation(p=0.5, degrees=90),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
K.RandomErasing(p=0.1),
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
data_keys=['image'],
data_keys=None,
keepdim=True,
)

def setup(self, stage: str) -> None:
Expand Down
10 changes: 5 additions & 5 deletions torchgeo/datamodules/ftw.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch

from ..datasets import FieldsOfTheWorld
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand Down Expand Up @@ -55,16 +54,17 @@ def __init__(
self.val_countries = val_countries
self.test_countries = test_countries

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomRotation(p=0.5, degrees=90),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)
self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

def setup(self, stage: str) -> None:
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
GridGeoSampler,
RandomBatchGeoSampler,
)
from ..transforms import AugmentationSequential
from .utils import MisconfigurationException


Expand Down Expand Up @@ -70,9 +69,10 @@ def __init__(

# Data augmentation
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
self.aug: Transform = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image']
self.aug: Transform = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

self.train_aug: Transform | None = None
self.val_aug: Transform | None = None
self.test_aug: Transform | None = None
Expand Down
10 changes: 6 additions & 4 deletions torchgeo/datamodules/geonrw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from torch.utils.data import Subset

from ..datasets import GeoNRW
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import group_shuffle_split

Expand Down Expand Up @@ -38,14 +37,17 @@ def __init__(
"""
super().__init__(GeoNRW, batch_size, num_workers, **kwargs)

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Resize(size),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)

self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask'])
self.aug = K.AugmentationSequential(
K.Resize(size), data_keys=None, keepdim=True
)

self.size = size

Expand Down
Loading