From c40d8bc86c8be1d65f94aa7bf42cef98ba8aade2 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 1 Jul 2024 22:19:53 +0400 Subject: [PATCH] Fix for segmentation tests --- torchgeo/datamodules/spacenet.py | 4 ++-- torchgeo/trainers/segmentation.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 3a1972d3926..ac37ae2069d 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -53,12 +53,12 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image', 'mask'], + data_keys=None, ) self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), - data_keys=['image', 'mask'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index afd71521002..6b2601c851f 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch.nn as nn +from einops import rearrange from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection @@ -225,6 +226,9 @@ def training_step( Returns: The loss tensor. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') + x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -245,6 +249,8 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -289,6 +295,8 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0]