diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 6fc233dd407..7bd4d0ae165 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -39,7 +39,7 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.InriaAerialImageLabeling`. """ - super().__init__(InriaAerialImageLabeling, batch_size, num_workers, **kwargs) + super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs) self.patch_size = _to_tuple(patch_size)