From f364498fd4d971a18760b6fadbaaa1f64dd405b5 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 1 Jul 2024 22:19:53 +0400 Subject: [PATCH] Fix for segmentation tests --- tests/datamodules/test_geo.py | 4 ++-- torchgeo/datamodules/spacenet.py | 4 ++-- torchgeo/trainers/segmentation.py | 8 ++++++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index b4944ba7f6c..e0b66ab1984 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -32,7 +32,7 @@ def __init__( def __getitem__(self, query: BoundingBox) -> dict[str, Any]: image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2) - return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query} + return {'image': image, 'crs': CRS.from_epsg(4326), 'bbox': query} def plot(self, *args: Any, **kwargs: Any) -> Figure: return plt.figure() @@ -68,7 +68,7 @@ def __init__( self.length = length def __getitem__(self, index: int) -> dict[str, Tensor]: - return {"image": torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)} + return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)} def __len__(self) -> int: return self.length diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 3a1972d3926..ac37ae2069d 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -53,12 +53,12 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image', 'mask'], + data_keys=None, ) self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), - data_keys=['image', 'mask'], + data_keys=None, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index afd71521002..6b2601c851f 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch.nn as nn +from einops import rearrange from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection @@ -225,6 +226,9 @@ def training_step( Returns: The loss tensor. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') + x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -245,6 +249,8 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -289,6 +295,8 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ + if 'mask' in batch and batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0]