From 4f42b918c67652d7ecd75fb683831bc02b23a752 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 5 Nov 2024 22:24:06 +0400 Subject: [PATCH] Fix for new datasets --- torchgeo/datamodules/caffe.py | 11 ++++++----- torchgeo/datamodules/ftw.py | 10 +++++----- torchgeo/datamodules/geonrw.py | 10 ++++++---- torchgeo/datasets/geonrw.py | 2 +- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/torchgeo/datamodules/caffe.py b/torchgeo/datamodules/caffe.py index 241a34197fc..a58136df30b 100644 --- a/torchgeo/datamodules/caffe.py +++ b/torchgeo/datamodules/caffe.py @@ -9,7 +9,6 @@ import torch from ..datasets import CaFFe -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -40,16 +39,18 @@ def __init__( self.size = size - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(size), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/ftw.py b/torchgeo/datamodules/ftw.py index 8f27e719699..a197a789c48 100644 --- a/torchgeo/datamodules/ftw.py +++ b/torchgeo/datamodules/ftw.py @@ -9,7 +9,6 @@ import torch from ..datasets import FieldsOfTheWorld -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -55,16 +54,17 @@ def __init__( self.val_countries = val_countries self.test_countries = test_countries - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/geonrw.py b/torchgeo/datamodules/geonrw.py index c67753673ac..5283b0ed7ac 100644 --- a/torchgeo/datamodules/geonrw.py +++ b/torchgeo/datamodules/geonrw.py @@ -10,7 +10,6 @@ from torch.utils.data import Subset from ..datasets import GeoNRW -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule from .utils import group_shuffle_split @@ -38,14 +37,17 @@ def __init__( """ super().__init__(GeoNRW, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Resize(size), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask']) + self.aug = K.AugmentationSequential( + K.Resize(size), data_keys=None, keepdim=True + ) self.size = size diff --git a/torchgeo/datasets/geonrw.py b/torchgeo/datasets/geonrw.py index dfdadf9f815..50e05bad0fb 100644 --- a/torchgeo/datasets/geonrw.py +++ b/torchgeo/datasets/geonrw.py @@ -242,7 +242,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: # rename to torchgeo standard keys sample['image'] = sample.pop('rgb').float() - sample['mask'] = sample.pop('seg').long() + sample['mask'] = sample.pop('seg').long().squeeze(0) if self.transforms: sample = self.transforms(sample)