Skip to content

Commit 33e7070

Browse files
committed
More fixes
1 parent 0df2fc9 commit 33e7070

12 files changed

+93
-44
lines changed

tests/datasets/test_geo.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
NonGeoClassificationDataset,
2727
NonGeoDataset,
2828
RasterDataset,
29+
Sample,
2930
Sentinel2,
3031
UnionDataset,
3132
VectorDataset,
@@ -46,7 +47,7 @@ def __init__(
4647
self.res = res
4748
self.paths = paths or []
4849

49-
def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
50+
def __getitem__(self, query: BoundingBox) -> Sample:
5051
hits = self.index.intersection(tuple(query), objects=True)
5152
hit = next(iter(hits))
5253
bounds = BoundingBox(*hit.bounds)
@@ -77,7 +78,7 @@ class CustomSentinelDataset(Sentinel2):
7778

7879

7980
class CustomNonGeoDataset(NonGeoDataset):
80-
def __getitem__(self, index: int) -> dict[str, int]:
81+
def __getitem__(self, index: int) -> Sample:
8182
return {'index': index}
8283

8384
def __len__(self) -> int:

tests/datasets/test_splits.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
from collections.abc import Sequence
55
from math import floor, isclose
6-
from typing import Any
76

87
import pytest
98
from rasterio.crs import CRS
109

1110
from torchgeo.datasets import (
1211
BoundingBox,
1312
GeoDataset,
13+
Sample,
1414
random_bbox_assignment,
1515
random_bbox_splitting,
1616
random_grid_cell_assignment,
@@ -49,7 +49,7 @@ def __init__(
4949
self._crs = crs
5050
self.res = res
5151

52-
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
52+
def __getitem__(self, query: BoundingBox) -> Sample:
5353
hits = self.index.intersection(tuple(query), objects=True)
5454
hit = next(iter(hits))
5555
return {'content': hit.object}

tests/datasets/test_utils.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616
from rasterio.crs import CRS
1717

18-
from torchgeo.datasets import BoundingBox, DependencyNotFoundError
18+
from torchgeo.datasets import BoundingBox, DependencyNotFoundError, Sample
1919
from torchgeo.datasets.utils import (
2020
Executable,
2121
array_to_tensor,
@@ -381,13 +381,13 @@ def test_disambiguate_timestamp(
381381

382382
class TestCollateFunctionsMatchingKeys:
383383
@pytest.fixture(scope='class')
384-
def samples(self) -> list[dict[str, Any]]:
384+
def samples(self) -> list[Sample]:
385385
return [
386386
{'image': torch.tensor([1, 2, 0]), 'crs': CRS.from_epsg(2000)},
387387
{'image': torch.tensor([0, 0, 3]), 'crs': CRS.from_epsg(2001)},
388388
]
389389

390-
def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
390+
def test_stack_unbind_samples(self, samples: list[Sample]) -> None:
391391
sample = stack_samples(samples)
392392
assert sample['image'].size() == torch.Size([2, 3])
393393
assert torch.allclose(sample['image'], torch.tensor([[1, 2, 0], [0, 0, 3]]))
@@ -398,13 +398,13 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
398398
assert torch.allclose(samples[i]['image'], new_samples[i]['image'])
399399
assert samples[i]['crs'] == new_samples[i]['crs']
400400

401-
def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
401+
def test_concat_samples(self, samples: list[Sample]) -> None:
402402
sample = concat_samples(samples)
403403
assert sample['image'].size() == torch.Size([6])
404404
assert torch.allclose(sample['image'], torch.tensor([1, 2, 0, 0, 0, 3]))
405405
assert sample['crs'] == CRS.from_epsg(2000)
406406

407-
def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
407+
def test_merge_samples(self, samples: list[Sample]) -> None:
408408
sample = merge_samples(samples)
409409
assert sample['image'].size() == torch.Size([3])
410410
assert torch.allclose(sample['image'], torch.tensor([1, 2, 3]))
@@ -413,13 +413,13 @@ def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
413413

414414
class TestCollateFunctionsDifferingKeys:
415415
@pytest.fixture(scope='class')
416-
def samples(self) -> list[dict[str, Any]]:
416+
def samples(self) -> list[Sample]:
417417
return [
418418
{'image': torch.tensor([1, 2, 0]), 'crs1': CRS.from_epsg(2000)},
419419
{'mask': torch.tensor([0, 0, 3]), 'crs2': CRS.from_epsg(2001)},
420420
]
421421

422-
def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
422+
def test_stack_unbind_samples(self, samples: list[Sample]) -> None:
423423
sample = stack_samples(samples)
424424
assert sample['image'].size() == torch.Size([1, 3])
425425
assert sample['mask'].size() == torch.Size([1, 3])
@@ -434,7 +434,7 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
434434
assert torch.allclose(samples[1]['mask'], new_samples[0]['mask'])
435435
assert samples[1]['crs2'] == new_samples[0]['crs2']
436436

437-
def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
437+
def test_concat_samples(self, samples: list[Sample]) -> None:
438438
sample = concat_samples(samples)
439439
assert sample['image'].size() == torch.Size([3])
440440
assert sample['mask'].size() == torch.Size([3])
@@ -443,7 +443,7 @@ def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
443443
assert sample['crs1'] == CRS.from_epsg(2000)
444444
assert sample['crs2'] == CRS.from_epsg(2001)
445445

446-
def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
446+
def test_merge_samples(self, samples: list[Sample]) -> None:
447447
sample = merge_samples(samples)
448448
assert sample['image'].size() == torch.Size([3])
449449
assert sample['mask'].size() == torch.Size([3])

tests/samplers/test_batch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from rasterio.crs import CRS
1111
from torch.utils.data import DataLoader
1212

13-
from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
13+
from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples
1414
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units
1515

1616

@@ -32,7 +32,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
3232
self._crs = crs
3333
self.res = res
3434

35-
def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
35+
def __getitem__(self, query: BoundingBox) -> Sample:
3636
return {'index': query}
3737

3838

tests/samplers/test_single.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from rasterio.crs import CRS
1111
from torch.utils.data import DataLoader
1212

13-
from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
13+
from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples
1414
from torchgeo.samplers import (
1515
GeoSampler,
1616
GridGeoSampler,
@@ -39,7 +39,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
3939
self._crs = crs
4040
self.res = res
4141

42-
def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
42+
def __getitem__(self, query: BoundingBox) -> Sample:
4343
return {'index': query}
4444

4545

torchgeo/datasets/enviroatlas.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import sys
88
from collections.abc import Callable, Sequence
9-
from typing import ClassVar, cast
9+
from typing import Any, ClassVar, cast
1010

1111
import fiona
1212
import matplotlib.pyplot as plt
@@ -347,8 +347,8 @@ def __getitem__(self, query: BoundingBox) -> Sample:
347347
"""
348348
hits = self.index.intersection(tuple(query), objects=True)
349349
filepaths = cast(list[dict[str, str]], [hit.object for hit in hits])
350-
351-
sample: Sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query}
350+
images: list[np.typing.NDArray[Any]] = []
351+
masks: list[np.typing.NDArray[Any]] = []
352352

353353
if len(filepaths) == 0:
354354
raise IndexError(
@@ -389,23 +389,27 @@ def __getitem__(self, query: BoundingBox) -> Sample:
389389
'waterbodies',
390390
'water',
391391
]:
392-
sample['image'].append(data)
392+
images.append(data)
393393
elif layer in ['prior', 'prior_no_osm_no_buildings']:
394394
if self.prior_as_input:
395-
sample['image'].append(data)
395+
images.append(data)
396396
else:
397-
sample['mask'].append(data)
397+
masks.append(data)
398398
elif layer in ['lc']:
399399
data = self.raw_enviroatlas_to_idx_map[data]
400-
sample['mask'].append(data)
400+
masks.append(data)
401401
else:
402402
raise IndexError(f'query: {query} spans multiple tiles which is not valid')
403403

404-
sample['image'] = np.concatenate(sample['image'], axis=0)
405-
sample['mask'] = np.concatenate(sample['mask'], axis=0)
404+
image = torch.from_numpy(np.concatenate(images, axis=0))
405+
mask = torch.from_numpy(np.concatenate(masks, axis=0))
406406

407-
sample['image'] = torch.from_numpy(sample['image'])
408-
sample['mask'] = torch.from_numpy(sample['mask'])
407+
sample: Sample = {
408+
'image': image,
409+
'mask': mask,
410+
'crs': self.crs,
411+
'bounds': query,
412+
}
409413

410414
if self.transforms is not None:
411415
sample = self.transforms(sample)

torchgeo/datasets/eurocrops.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import os
88
import pathlib
99
from collections.abc import Callable, Iterable
10-
from typing import Any
1110

1211
import fiona
1312
import matplotlib.pyplot as plt
1413
import numpy as np
1514
from matplotlib.figure import Figure
1615
from rasterio.crs import CRS
16+
from torch import Tensor
1717

1818
from .errors import DatasetNotFoundError
1919
from .geo import VectorDataset
@@ -247,9 +247,7 @@ def plot(
247247

248248
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4))
249249

250-
def apply_cmap(
251-
arr: 'np.typing.NDArray[Any]',
252-
) -> 'np.typing.NDArray[np.float64]':
250+
def apply_cmap(arr: Tensor) -> 'np.typing.NDArray[np.float64]':
253251
# Color 0 as black, while applying default color map for the class indices.
254252
cmap = plt.get_cmap('viridis')
255253
im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map))

torchgeo/datasets/fair1m.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def __getitem__(self, index: int) -> Sample:
283283
label_path = label_path.replace('.tif', '.xml')
284284
voc = parse_pascal_voc(label_path)
285285
boxes, labels = self._load_target(voc['points'], voc['labels'])
286-
sample: Sample = {'image': image, 'boxes': boxes, 'label': labels}
286+
sample = {'image': image, 'boxes': boxes, 'label': labels}
287287

288288
if self.transforms is not None:
289289
sample = self.transforms(sample)

torchgeo/datasets/gid15.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def __getitem__(self, index: int) -> Sample:
139139
mask = self._load_target(files['mask'])
140140
sample: Sample = {'image': image, 'mask': mask}
141141
else:
142-
sample: Sample = {'image': image}
142+
sample = {'image': image}
143143

144144
if self.transforms is not None:
145145
sample = self.transforms(sample)

torchgeo/datasets/skippd.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __len__(self) -> int:
134134

135135
return num_datapoints
136136

137-
def __getitem__(self, index: int) -> dict[str, str | Tensor]:
137+
def __getitem__(self, index: int) -> Sample:
138138
"""Return an index within the dataset.
139139
140140
Args:
@@ -143,7 +143,7 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
143143
Returns:
144144
data and label at that index
145145
"""
146-
sample: dict[str, str | Tensor] = {'image': self._load_image(index)}
146+
sample: Sample = {'image': self._load_image(index)}
147147
sample.update(self._load_features(index))
148148

149149
if self.transforms is not None:
@@ -176,7 +176,7 @@ def _load_image(self, index: int) -> Tensor:
176176
tensor = torch.from_numpy(arr).to(torch.float32)
177177
return tensor
178178

179-
def _load_features(self, index: int) -> dict[str, str | Tensor]:
179+
def _load_features(self, index: int) -> Sample:
180180
"""Load label.
181181
182182
Args:
@@ -194,7 +194,7 @@ def _load_features(self, index: int) -> dict[str, str | Tensor]:
194194
path = os.path.join(self.root, f'times_{self.split}_{self.task}.npy')
195195
datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat)
196196

197-
features: dict[str, str | Tensor] = {
197+
features: Sample = {
198198
'label': torch.tensor(label, dtype=torch.float32),
199199
'date': datestring,
200200
}

torchgeo/datasets/sustainbench_crop_yield.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
self._verify()
9999

100100
self.images = []
101-
self.features = []
101+
self.features: list[Sample] = []
102102

103103
for country in self.countries:
104104
image_file_path = os.path.join(
@@ -122,7 +122,7 @@ def __init__(
122122
year = year_npz_file[idx]
123123
ndvi = ndvi_npz_file[idx]
124124

125-
features = {
125+
features: Sample = {
126126
'label': torch.tensor(target).to(torch.float32),
127127
'year': torch.tensor(int(year)),
128128
'ndvi': torch.from_numpy(ndvi).to(dtype=torch.float32),

torchgeo/datasets/utils.py

+50-4
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,54 @@ class Sample(TypedDict, total=False):
6666
bounds: BoundingBox
6767
crs: CRS
6868

69-
# TODO: remove
69+
# TODO: Additional dataset-specific keys that should be subclasses
70+
images: Tensor
71+
input: Tensor
7072
boxes: Tensor
73+
bboxes: Tensor
74+
masks: Tensor
75+
labels: Tensor
76+
prediction_masks: Tensor
77+
prediction_boxes: Tensor
78+
prediction_labels: Tensor
79+
prediction_label: Tensor
80+
prediction_scores: Tensor
81+
audio: Tensor
82+
points: Tensor
83+
x: Tensor
84+
y: Tensor
85+
relative_time: Tensor
86+
ocean: Tensor
87+
array: Tensor
88+
chm: Tensor
89+
hsi: Tensor
90+
las: Tensor
91+
image1: Tensor
92+
image2: Tensor
93+
crs1: Tensor
94+
crs2: Tensor
95+
magnitude: Tensor
96+
agb: Tensor
97+
key: Tensor
98+
patch: Tensor
99+
geometry: Tensor
100+
properties: Tensor
101+
id: int
102+
centroid_lat: Tensor
103+
centroid_lon: Tensor
104+
content: Tensor
105+
year: Tensor
106+
ndvi: Tensor
107+
filename: str
108+
category: str
109+
field_ids: Tensor
110+
tile_index: Tensor
111+
transform: Tensor
112+
src: Tensor
113+
dst: Tensor
114+
input_size: Tensor
115+
output_size: Tensor
116+
index: BoundingBox
71117

72118

73119
class Batch(Sample):
@@ -456,7 +502,7 @@ def stack_samples(samples: Iterable[Sample]) -> Batch:
456502
457503
.. versionadded:: 0.2
458504
"""
459-
collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
505+
collated: Batch = _list_dict_to_dict_list(samples)
460506
for key, value in collated.items():
461507
if isinstance(value[0], Tensor):
462508
collated[key] = torch.stack(value)
@@ -476,7 +522,7 @@ def concat_samples(samples: Iterable[Sample]) -> Batch:
476522
477523
.. versionadded:: 0.2
478524
"""
479-
collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
525+
collated: Batch = _list_dict_to_dict_list(samples)
480526
for key, value in collated.items():
481527
if isinstance(value[0], Tensor):
482528
collated[key] = torch.cat(value)
@@ -498,7 +544,7 @@ def merge_samples(samples: Iterable[Sample]) -> Batch:
498544
499545
.. versionadded:: 0.2
500546
"""
501-
collated: dict[Any, Any] = {}
547+
collated: Batch = {}
502548
for sample in samples:
503549
for key, value in sample.items():
504550
if key in collated and isinstance(value, Tensor):

0 commit comments

Comments
 (0)