diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 3a1972d3926..ac37ae2069d 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -53,12 +53,12 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image', 'mask'], + data_keys=None, ) self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), - data_keys=['image', 'mask'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index afd71521002..6b2601c851f 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch.nn as nn +from einops import rearrange from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection @@ -225,6 +226,9 @@ def training_step( Returns: The loss tensor. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') + x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -245,6 +249,8 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -289,6 +295,8 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0]