Skip to content

Commit aebe183

Browse files
committed
resolved some ruff issues
1 parent d1f062f commit aebe183

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

tests/datasets/test_substation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ def dataset(
2121
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
2222
) -> Generator[Substation, None, None]:
2323
"""Fixture for the Substation."""
24-
root = os.path.join(os.getcwd(), 'tests', 'data', 'substation')
24+
root = os.path.join(os.getcwd(), 'tests', 'data', 'substation')
2525

2626
yield Substation(
2727
root=root,
28-
bands=[1,2,3],
28+
bands=[1, 2, 3],
2929
use_timepoints=True,
3030
mask_2d=True,
3131
timepoint_aggregation='median',
@@ -35,45 +35,45 @@ def dataset(
3535
@pytest.mark.parametrize(
3636
'config',
3737
[
38-
{'bands': [1,2,3], 'use_timepoints': False, 'mask_2d': True},
38+
{'bands': [1, 2, 3], 'use_timepoints': False, 'mask_2d': True},
3939
{
40-
'bands': [1,2,3],
40+
'bands': [1, 2, 3],
4141
'use_timepoints': True,
4242
'timepoint_aggregation': 'concat',
4343
'num_of_timepoints': 4,
4444
'mask_2d': False,
4545
},
4646
{
47-
'bands': [1,2,3],
47+
'bands': [1, 2, 3],
4848
'use_timepoints': True,
4949
'timepoint_aggregation': 'median',
5050
'num_of_timepoints': 4,
5151
'mask_2d': True,
5252
},
5353
{
54-
'bands': [1,2,3],
54+
'bands': [1, 2, 3],
5555
'use_timepoints': True,
5656
'timepoint_aggregation': 'first',
5757
'num_of_timepoints': 4,
5858
'mask_2d': False,
5959
},
6060
{
61-
'bands': [1,2,3],
61+
'bands': [1, 2, 3],
6262
'use_timepoints': True,
6363
'timepoint_aggregation': 'random',
6464
'num_of_timepoints': 4,
6565
'mask_2d': True,
6666
},
67-
{'bands': [1,2,3], 'use_timepoints': False, 'mask_2d': False},
67+
{'bands': [1, 2, 3], 'use_timepoints': False, 'mask_2d': False},
6868
{
69-
'bands': [1,2,3],
69+
'bands': [1, 2, 3],
7070
'use_timepoints': False,
7171
'timepoint_aggregation': 'first',
7272
'num_of_timepoints': 4,
7373
'mask_2d': False,
7474
},
7575
{
76-
'bands': [1,2,3],
76+
'bands': [1, 2, 3],
7777
'use_timepoints': False,
7878
'timepoint_aggregation': 'random',
7979
'num_of_timepoints': 4,

torchgeo/datamodules/substation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
"""Substation datamodule."""
5+
16
from typing import Any
27

38
import torch
@@ -52,6 +57,7 @@ def __init__(
5257
color_transforms: Color transformations to apply to the image.
5358
image_resize: Resizing function for the image.
5459
mask_resize: Resizing function for the mask.
60+
num_of_timepoints: Number of timepoints to use in the dataset.
5561
**kwargs: Additional arguments passed to Substation.
5662
"""
5763
super().__init__(Substation, batch_size, num_workers, **kwargs)
@@ -86,7 +92,7 @@ def setup(self, stage: str) -> None:
8692
download=True,
8793
checksum=False,
8894
)
89-
95+
9096
generator = torch.Generator().manual_seed(0)
9197
total_len = len(dataset)
9298
val_len = int(total_len * self.val_split_pct)

torchgeo/datasets/substation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(
6262
use_timepoints: Whether to use multiple timepoints for each image.
6363
mask_2d: Whether to use a 2D mask.
6464
timepoint_aggregation: How to aggregate multiple timepoints.
65+
num_of_timepoints: Number of timepoints to use for each image.
6566
download: Whether to download the dataset if it is not found.
6667
checksum: Whether to verify the dataset after downloading.
6768
"""
@@ -126,7 +127,6 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
126127
mask = torch.from_numpy(mask).long()
127128
mask = mask.unsqueeze(dim=0)
128129

129-
130130
if self.mask_2d:
131131
mask_0 = 1.0 - mask
132132
mask = torch.concat([mask_0, mask], dim=0)

0 commit comments

Comments
 (0)