@override
def set_epoch(self, epoch: int) -> None:
    """Set the epoch on this distributed wrapper and propagate it to the wrapped sampler.

    ``DistributedSampler.set_epoch`` only reseeds the wrapper's own shuffling;
    without forwarding, a wrapped sampler that shuffles per epoch (via its own
    ``set_epoch``) would repeat the same order every epoch.

    Args:
        epoch: The current epoch number, used to seed epoch-dependent shuffling.
    """
    super().set_epoch(epoch)
    # Forward to the original sampler when it exposes a callable ``set_epoch``.
    # getattr with a default collapses the hasattr+callable double check; a
    # non-callable attribute of the same name (seen in tests) is ignored.
    inner = self.dataset._sampler
    forward = getattr(inner, "set_epoch", None)
    if callable(forward):
        forward(epoch)
_destroy_dist_connection, _gather_all_tensors, _get_default_process_group_backend_for_device, @@ -274,3 +275,73 @@ def test_is_dtensor(monkeypatch): monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False) assert not _is_dtensor(Mock(spec=DTensor)) + + +class _CustomSampler(torch.utils.data.Sampler): + """A custom sampler for testing DistributedSamplerWrapper.""" + + def __init__(self, data_source, non_callable_set_epoch: bool = False): + self.data_source = data_source + if non_callable_set_epoch: + self.set_epoch = "not a method" # attribute exists but is not callable + + def __len__(self): + return len(self.data_source) + + def __iter__(self): + return iter(range(len(self.data_source))) + + +class _CustomSamplerWithSetEpoch(_CustomSampler): + """A custom sampler that tracks set_epoch calls for testing.""" + + def __init__(self, data_source): + super().__init__(data_source) + self.epoch = 0 + self.set_epoch_call_count = 0 + + def set_epoch(self, epoch): + self.epoch = epoch + self.set_epoch_call_count += 1 + + +def test_distributed_sampler_wrapper_set_epoch(): + """Test that DistributedSamplerWrapper correctly handles set_epoch for various sampler types. + + Reproduces issue #21454: When a sampler is wrapped by DistributedSamplerWrapper, calling set_epoch on the wrapper + should forward the call to the underlying sampler if it supports the method. 
+ + """ + data_source = list(range(100)) + + # Case 1: Sampler WITH set_epoch method - should forward the call + sampler_with_set_epoch = _CustomSamplerWithSetEpoch(data_source) + wrapper = DistributedSamplerWrapper(sampler_with_set_epoch, num_replicas=2, rank=0) + + assert sampler_with_set_epoch.epoch == 0 + assert sampler_with_set_epoch.set_epoch_call_count == 0 + + wrapper.set_epoch(5) + assert wrapper.epoch == 5 + assert sampler_with_set_epoch.epoch == 5, "set_epoch was not forwarded to the underlying sampler" + assert sampler_with_set_epoch.set_epoch_call_count == 1 + + wrapper.set_epoch(10) + assert wrapper.epoch == 10 + assert sampler_with_set_epoch.epoch == 10 + assert sampler_with_set_epoch.set_epoch_call_count == 2 + + # Case 2: Sampler WITHOUT set_epoch method - should not fail + sampler_without_set_epoch = _CustomSampler(data_source) + wrapper = DistributedSamplerWrapper(sampler_without_set_epoch, num_replicas=2, rank=0) + + wrapper.set_epoch(5) # Should not raise + assert wrapper.epoch == 5 + + # Case 3: Sampler with non-callable set_epoch attribute - should not fail or call it + sampler_non_callable = _CustomSampler(data_source, non_callable_set_epoch=True) + wrapper = DistributedSamplerWrapper(sampler_non_callable, num_replicas=2, rank=0) + + wrapper.set_epoch(5) # Should not raise + assert wrapper.epoch == 5 + assert sampler_non_callable.set_epoch == "not a method" # Should remain unchanged