From d72c56a0de5e8847bad6f81c431745e367763d10 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sat, 20 Jul 2024 16:17:54 +0400 Subject: [PATCH] Fix dataset sample --- torchgeo/datasets/cyclone.py | 3 ++- torchgeo/datasets/quakeset.py | 3 ++- torchgeo/datasets/skippd.py | 3 ++- torchgeo/datasets/sustainbench_crop_yield.py | 11 ++++------- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 70a3619fffb..069a0262f0d 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -108,7 +108,8 @@ def __getitem__(self, index: int) -> dict[str, Any]: if self.transforms is not None: sample = self.transforms(sample) - sample.update({x: features[x] for x in features if x != 'label'}) + + sample.update({x: features[x] for x in features if x != 'label'}) return sample diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index eb3a1b4ddec..2acb307a96a 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -117,7 +117,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: if self.transforms is not None: sample = self.transforms(sample) - sample['magnitude'] = magnitude + + sample['magnitude'] = magnitude return sample diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 03bbf870edd..21a88170400 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -149,7 +149,8 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]: if self.transforms is not None: sample = self.transforms(sample) - sample.update({x: features[x] for x in features if x != 'label'}) + + sample.update({x: features[x] for x in features if x != 'label'}) return sample diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 4f9b2362b4e..56d351dc343 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -153,13 +153,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: if self.transforms is not None: sample = self.transforms(sample) - sample.update( - { - x: self.features[index][x] - for x in self.features[index] - if x != 'label' - } - ) + + sample.update( + {x: self.features[index][x] for x in self.features[index] if x != 'label'} + ) return sample