Skip to content

Commit

Permalink
mypy fixes & remove _Transform
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jul 20, 2024
1 parent 3ff4cba commit d2b5b01
Show file tree
Hide file tree
Showing 27 changed files with 50 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- python
- jupyter
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.11.0
hooks:
- id: mypy
args:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
},
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
39 changes: 6 additions & 33 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any

import kornia.augmentation as K
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

Expand All @@ -15,36 +14,6 @@
from .geo import GeoDataModule


class _Transform(nn.Module):
"""Version of AugmentationSequential designed for samples, not batches."""

def __init__(self, aug: nn.Module) -> None:
"""Initialize a new _Transform instance.
Args:
aug: Augmentation to apply.
"""
super().__init__()
self.aug = aug

def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
"""Apply the augmentation.
Args:
sample: Input sample.
Returns:
Augmented sample.
"""
for key in ['image', 'mask']:
dtype = sample[key].dtype
# All inputs must be float
sample[key] = sample[key].float()
sample[key] = self.aug(sample[key])
sample[key] = sample[key].to(dtype)
return sample


class ChesapeakeCVPRDataModule(GeoDataModule):
"""LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.
Expand Down Expand Up @@ -90,7 +59,11 @@ def __init__(
# This is a rough estimate of how large of a patch we will need to sample in
# EPSG:3857 in order to guarantee a large enough patch in the local CRS.
self.original_patch_size = patch_size * 3
kwargs['transforms'] = _Transform(K.CenterCrop(patch_size, keepdim=True))
kwargs['transforms'] = K.AugmentationSequential(
K.CenterCrop(patch_size), data_keys=None, keepdim=True
)
# https://github.com/kornia/kornia/issues/2848
kwargs['transforms'].keepdim = True

super().__init__(
ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs
Expand Down Expand Up @@ -122,7 +95,7 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/deepglobelandcover.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/fire_risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True
self.train_aug: Transform | None = None
self.val_aug: Transform | None = None
self.test_aug: Transform | None = None
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/gid15.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.predict_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.predict_aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.predict_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.aug.keepdim = True
self.predict_aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/l7irish.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
},
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/l8biome.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
},
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.aug.keepdim = True
12 changes: 6 additions & 6 deletions torchgeo/datamodules/levircd.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.val_aug.keepdim = True # type: ignore[attr-defined]
self.test_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.val_aug.keepdim = True
self.test_aug.keepdim = True


class LEVIRCDPlusDataModule(NonGeoDataModule):
Expand Down Expand Up @@ -108,9 +108,9 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.val_aug.keepdim = True # type: ignore[attr-defined]
self.test_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.val_aug.keepdim = True
self.test_aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/naip.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/potsdam.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datamodules/quakeset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch

from ..datasets import QuakeSet
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand All @@ -34,12 +33,12 @@ def __init__(
:class:`~torchgeo.datasets.QuakeSet`.
"""
super().__init__(QuakeSet, batch_size, num_workers, **kwargs)
self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image'],
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
2 changes: 1 addition & 1 deletion torchgeo/datamodules/resisc45.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ def __init__(
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
2 changes: 1 addition & 1 deletion torchgeo/datamodules/seco.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/sentinel2_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/sentinel2_eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/sentinel2_nccm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/sentinel2_south_america_soybean.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/southafricacroptype.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/spacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/ssl4eo_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,6 @@ def __init__(
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.val_aug.keepdim = True # type: ignore[attr-defined]
self.test_aug.keepdim = True # type: ignore[attr-defined]
self.train_aug.keepdim = True
self.val_aug.keepdim = True
self.test_aug.keepdim = True
2 changes: 1 addition & 1 deletion torchgeo/datamodules/ucmerced.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ def __init__(
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True
2 changes: 1 addition & 1 deletion torchgeo/datamodules/vaihingen.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down

0 comments on commit d2b5b01

Please sign in to comment.