Skip to content

Commit

Permalink
resolved some ruff issues
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuld committed Jan 19, 2025
1 parent d1f062f commit aebe183
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
20 changes: 10 additions & 10 deletions tests/datasets/test_substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def dataset(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> Generator[Substation, None, None]:
"""Fixture for the Substation."""
root = os.path.join(os.getcwd(), 'tests', 'data', 'substation')
root = os.path.join(os.getcwd(), 'tests', 'data', 'substation')

yield Substation(
root=root,
bands=[1,2,3],
bands=[1, 2, 3],
use_timepoints=True,
mask_2d=True,
timepoint_aggregation='median',
Expand All @@ -35,45 +35,45 @@ def dataset(
@pytest.mark.parametrize(
'config',
[
{'bands': [1,2,3], 'use_timepoints': False, 'mask_2d': True},
{'bands': [1, 2, 3], 'use_timepoints': False, 'mask_2d': True},
{
'bands': [1,2,3],
'bands': [1, 2, 3],
'use_timepoints': True,
'timepoint_aggregation': 'concat',
'num_of_timepoints': 4,
'mask_2d': False,
},
{
'bands': [1,2,3],
'bands': [1, 2, 3],
'use_timepoints': True,
'timepoint_aggregation': 'median',
'num_of_timepoints': 4,
'mask_2d': True,
},
{
'bands': [1,2,3],
'bands': [1, 2, 3],
'use_timepoints': True,
'timepoint_aggregation': 'first',
'num_of_timepoints': 4,
'mask_2d': False,
},
{
'bands': [1,2,3],
'bands': [1, 2, 3],
'use_timepoints': True,
'timepoint_aggregation': 'random',
'num_of_timepoints': 4,
'mask_2d': True,
},
{'bands': [1,2,3], 'use_timepoints': False, 'mask_2d': False},
{'bands': [1, 2, 3], 'use_timepoints': False, 'mask_2d': False},
{
'bands': [1,2,3],
'bands': [1, 2, 3],
'use_timepoints': False,
'timepoint_aggregation': 'first',
'num_of_timepoints': 4,
'mask_2d': False,
},
{
'bands': [1,2,3],
'bands': [1, 2, 3],
'use_timepoints': False,
'timepoint_aggregation': 'random',
'num_of_timepoints': 4,
Expand Down
8 changes: 7 additions & 1 deletion torchgeo/datamodules/substation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Substation datamodule."""

from typing import Any

import torch
Expand Down Expand Up @@ -52,6 +57,7 @@ def __init__(
color_transforms: Color transformations to apply to the image.
image_resize: Resizing function for the image.
mask_resize: Resizing function for the mask.
num_of_timepoints: Number of timepoints to use in the dataset.
**kwargs: Additional arguments passed to Substation.
"""
super().__init__(Substation, batch_size, num_workers, **kwargs)
Expand Down Expand Up @@ -86,7 +92,7 @@ def setup(self, stage: str) -> None:
download=True,
checksum=False,
)

generator = torch.Generator().manual_seed(0)
total_len = len(dataset)
val_len = int(total_len * self.val_split_pct)
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
use_timepoints: Whether to use multiple timepoints for each image.
mask_2d: Whether to use a 2D mask.
timepoint_aggregation: How to aggregate multiple timepoints.
num_of_timepoints: Number of timepoints to use for each image.
download: Whether to download the dataset if it is not found.
checksum: Whether to verify the dataset after downloading.
"""
Expand Down Expand Up @@ -126,7 +127,6 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
mask = torch.from_numpy(mask).long()
mask = mask.unsqueeze(dim=0)


if self.mask_2d:
mask_0 = 1.0 - mask
mask = torch.concat([mask_0, mask], dim=0)
Expand Down

0 comments on commit aebe183

Please sign in to comment.