From aebe18315db8cc2ba4c880a7a00b0f1fd42ff268 Mon Sep 17 00:00:00 2001 From: rijuld Date: Sun, 19 Jan 2025 03:55:56 -0500 Subject: [PATCH] resolved some ruff issues --- tests/datasets/test_substation.py | 20 ++++++++++---------- torchgeo/datamodules/substation.py | 8 +++++++- torchgeo/datasets/substation.py | 2 +- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/datasets/test_substation.py b/tests/datasets/test_substation.py index 026b7c30ed3..ba362a19540 100644 --- a/tests/datasets/test_substation.py +++ b/tests/datasets/test_substation.py @@ -21,11 +21,11 @@ def dataset( self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> Generator[Substation, None, None]: """Fixture for the Substation.""" - root = os.path.join(os.getcwd(), 'tests', 'data', 'substation') + root = os.path.join(os.getcwd(), 'tests', 'data', 'substation') yield Substation( root=root, - bands=[1,2,3], + bands=[1, 2, 3], use_timepoints=True, mask_2d=True, timepoint_aggregation='median', @@ -35,45 +35,45 @@ def dataset( @pytest.mark.parametrize( 'config', [ - {'bands': [1,2,3], 'use_timepoints': False, 'mask_2d': True}, + {'bands': [1, 2, 3], 'use_timepoints': False, 'mask_2d': True}, { - 'bands': [1,2,3], + 'bands': [1, 2, 3], 'use_timepoints': True, 'timepoint_aggregation': 'concat', 'num_of_timepoints': 4, 'mask_2d': False, }, { - 'bands': [1,2,3], + 'bands': [1, 2, 3], 'use_timepoints': True, 'timepoint_aggregation': 'median', 'num_of_timepoints': 4, 'mask_2d': True, }, { - 'bands': [1,2,3], + 'bands': [1, 2, 3], 'use_timepoints': True, 'timepoint_aggregation': 'first', 'num_of_timepoints': 4, 'mask_2d': False, }, { - 'bands': [1,2,3], + 'bands': [1, 2, 3], 'use_timepoints': True, 'timepoint_aggregation': 'random', 'num_of_timepoints': 4, 'mask_2d': True, }, - {'bands': [1,2,3], 'use_timepoints': False, 'mask_2d': False}, + {'bands': [1, 2, 3], 'use_timepoints': False, 'mask_2d': False}, { - 'bands': [1,2,3], + 'bands': [1, 2, 3], 'use_timepoints': False, 'timepoint_aggregation': 'first', 'num_of_timepoints': 4, 'mask_2d': False, }, { - 'bands': [1,2,3], + 'bands': [1, 2, 3], 'use_timepoints': False, 'timepoint_aggregation': 'random', 'num_of_timepoints': 4, diff --git a/torchgeo/datamodules/substation.py b/torchgeo/datamodules/substation.py index cb01205c419..37c72d6e952 100644 --- a/torchgeo/datamodules/substation.py +++ b/torchgeo/datamodules/substation.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Substation datamodule.""" + from typing import Any import torch @@ -52,6 +57,7 @@ def __init__( color_transforms: Color transformations to apply to the image. image_resize: Resizing function for the image. mask_resize: Resizing function for the mask. + num_of_timepoints: Number of timepoints to use in the dataset. **kwargs: Additional arguments passed to Substation. """ super().__init__(Substation, batch_size, num_workers, **kwargs) @@ -86,7 +92,7 @@ def setup(self, stage: str) -> None: download=True, checksum=False, ) - + generator = torch.Generator().manual_seed(0) total_len = len(dataset) val_len = int(total_len * self.val_split_pct) diff --git a/torchgeo/datasets/substation.py b/torchgeo/datasets/substation.py index 0a1d71591b7..b3d1e21b7cc 100644 --- a/torchgeo/datasets/substation.py +++ b/torchgeo/datasets/substation.py @@ -62,6 +62,7 @@ def __init__( use_timepoints: Whether to use multiple timepoints for each image. mask_2d: Whether to use a 2D mask. timepoint_aggregation: How to aggregate multiple timepoints. + num_of_timepoints: Number of timepoints to use for each image. download: Whether to download the dataset if it is not found. checksum: Whether to verify the dataset after downloading. """ @@ -126,7 +127,6 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: mask = torch.from_numpy(mask).long() mask = mask.unsqueeze(dim=0) - if self.mask_2d: mask_0 = 1.0 - mask mask = torch.concat([mask_0, mask], dim=0)