Skip to content

Commit

Permalink
Fix for classification & regression tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jul 14, 2024
1 parent 3ced3de commit 96f6b48
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 6 deletions.
6 changes: 4 additions & 2 deletions torchgeo/datasets/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,19 @@ def __getitem__(self, index: int) -> dict[str, Any]:
Returns:
data, labels, field ids, and metadata at that index
"""
sample = {
features = {
'relative_time': torch.tensor(self.features.iat[index, 2]),
'ocean': torch.tensor(self.features.iat[index, 3]),
'label': torch.tensor(self.labels.iat[index, 1]),
}

image_id = self.labels.iat[index, 0]
sample['image'] = self._load_image(image_id)
sample = {'image': self._load_image(image_id)}
sample['label'] = features['label']

if self.transforms is not None:
sample = self.transforms(sample)
sample.update({x: features[x] for x in features if x != 'label'})

return sample

Expand Down
5 changes: 4 additions & 1 deletion torchgeo/datasets/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,10 @@ def plot(
show_predictions = 'prediction' in sample

if show_mask:
mask = sample['mask'].numpy()
mask = sample['mask']
if mask.ndim == 3 and mask.shape[0] == 1:
mask = mask.squeeze(0)
mask = mask.numpy()
ncols += 1

if show_predictions:
Expand Down
3 changes: 2 additions & 1 deletion torchgeo/datasets/quakeset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
label = torch.tensor(self.data[index]['label'])
magnitude = torch.tensor(self.data[index]['magnitude'])

sample = {'image': image, 'label': label, 'magnitude': magnitude}
sample = {'image': image, 'label': label}

if self.transforms is not None:
sample = self.transforms(sample)
sample['magnitude'] = magnitude

return sample

Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datasets/skippd.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,12 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
data and label at that index
"""
sample: dict[str, str | Tensor] = {'image': self._load_image(index)}
sample.update(self._load_features(index))
features = self._load_features(index)
sample['label'] = features['label']

if self.transforms is not None:
sample = self.transforms(sample)
sample.update({x: features[x] for x in features if x != 'label'})

return sample

Expand Down
9 changes: 8 additions & 1 deletion torchgeo/datasets/sustainbench_crop_yield.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,17 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
data and label at that index
"""
sample: dict[str, Tensor] = {'image': self.images[index]}
sample.update(self.features[index])
sample['label'] = self.features[index]['label']

if self.transforms is not None:
sample = self.transforms(sample)
sample.update(
{
x: self.features[index][x]
for x in self.features[index]
if x != 'label'
}
)

return sample

Expand Down

0 comments on commit 96f6b48

Please sign in to comment.