Skip to content

Commit

Permalink
norm stats (#2176)
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsleh authored Jul 18, 2024
1 parent 44ce007 commit b44164e
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions torchgeo/datamodules/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ def __init__(
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.EuroSAT`.
"""
super().__init__(EuroSAT, batch_size, num_workers, **kwargs)

bands = kwargs.get('bands', EuroSAT.all_band_names)
self.mean = torch.tensor([MEAN[b] for b in bands])
self.std = torch.tensor([STD[b] for b in bands])
super().__init__(EuroSAT, batch_size, num_workers, **kwargs)


class EuroSATSpatialDataModule(NonGeoDataModule):
Expand All @@ -120,11 +119,10 @@ def __init__(
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.EuroSATSpatial`.
"""
super().__init__(EuroSATSpatial, batch_size, num_workers, **kwargs)

bands = kwargs.get('bands', EuroSAT.all_band_names)
self.mean = torch.tensor([SPATIAL_MEAN[b] for b in bands])
self.std = torch.tensor([SPATIAL_STD[b] for b in bands])
super().__init__(EuroSATSpatial, batch_size, num_workers, **kwargs)


class EuroSAT100DataModule(NonGeoDataModule):
Expand All @@ -146,8 +144,7 @@ def __init__(
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.EuroSAT100`.
"""
super().__init__(EuroSAT100, batch_size, num_workers, **kwargs)

bands = kwargs.get('bands', EuroSAT.all_band_names)
self.mean = torch.tensor([MEAN[b] for b in bands])
self.std = torch.tensor([STD[b] for b in bands])
super().__init__(EuroSAT100, batch_size, num_workers, **kwargs)

0 comments on commit b44164e

Please sign in to comment.