Skip to content

Commit

Permalink
Fix chesapeake dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Jul 18, 2024
1 parent 11b06e0 commit da3a851
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
4 changes: 1 addition & 3 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
sample[key] = sample[key].float()
sample[key] = self.aug(sample[key])
sample[key] = sample[key].to(dtype)
# Kornia adds batch dimension
sample[key] = rearrange(sample[key], '() c h w -> c h w')
return sample


Expand Down Expand Up @@ -93,7 +91,7 @@ def __init__(
# This is a rough estimate of how large of a patch we will need to sample in
# EPSG:3857 in order to guarantee a large enough patch in the local CRS.
self.original_patch_size = patch_size * 3
kwargs['transforms'] = _Transform(K.CenterCrop(patch_size))
kwargs['transforms'] = _Transform(K.CenterCrop(patch_size, keepdim=True))

super().__init__(
ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
sample['mask'] = np.concatenate(sample['mask'], axis=0)

sample['image'] = torch.from_numpy(sample['image']).float()
sample['mask'] = torch.from_numpy(sample['mask']).long()
sample['mask'] = torch.from_numpy(sample['mask']).long().squeeze(0)

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down

0 comments on commit da3a851

Please sign in to comment.