2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `DistributedSamplerWrapper` not forwarding `set_epoch` to the underlying sampler ([#21454](https://github.com/Lightning-AI/pytorch-lightning/pull/21454))

- Fixed DDP notebook CUDA fork check to allow passive initialization when CUDA is not actively used ([#21402](https://github.com/Lightning-AI/pytorch-lightning/pull/21402))

### Removed
8 changes: 8 additions & 0 deletions src/lightning/fabric/utilities/distributed.py
@@ -372,6 +372,14 @@ def __iter__(self) -> Iterator:
        self.dataset.reset()
        return (self.dataset[index] for index in super().__iter__())

    @override
    def set_epoch(self, epoch: int) -> None:
        super().set_epoch(epoch)
        # Forward set_epoch to the original sampler if it supports it
        original_sampler = self.dataset._sampler
        if hasattr(original_sampler, "set_epoch") and callable(original_sampler.set_epoch):
            original_sampler.set_epoch(epoch)


def _suggested_max_num_threads(num_processes: int = 1) -> int:
    if num_processes < 1:
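The `hasattr`/`callable` guard keeps the forwarding safe for samplers that expose a non-callable `set_epoch` attribute (exercised in Case 3 of the tests below). A minimal sketch of the behavior this fix enables; the `_EpochShuffledSampler` name and its seeding scheme are illustrative, not part of the PR:

```python
import torch
from torch.utils.data import Sampler

from lightning.fabric.utilities.distributed import DistributedSamplerWrapper


class _EpochShuffledSampler(Sampler):
    """Illustrative sampler that derives its shuffle order from the current epoch."""

    def __init__(self, data_source):
        self.data_source = data_source
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)  # a new permutation each epoch
        return iter(torch.randperm(len(self.data_source), generator=g).tolist())


sampler = _EpochShuffledSampler(list(range(8)))
wrapper = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0)

wrapper.set_epoch(1)
assert sampler.epoch == 1  # before this fix, the inner sampler's epoch stayed at 0
```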
71 changes: 71 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
@@ -15,6 +15,7 @@
from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import (
    DistributedSamplerWrapper,
    _destroy_dist_connection,
    _gather_all_tensors,
    _get_default_process_group_backend_for_device,
@@ -274,3 +275,73 @@ def test_is_dtensor(monkeypatch):

    monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
    assert not _is_dtensor(Mock(spec=DTensor))


class _CustomSampler(torch.utils.data.Sampler):
    """A custom sampler for testing DistributedSamplerWrapper."""

    def __init__(self, data_source, non_callable_set_epoch: bool = False):
        self.data_source = data_source
        if non_callable_set_epoch:
            self.set_epoch = "not a method"  # attribute exists but is not callable

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        return iter(range(len(self.data_source)))


class _CustomSamplerWithSetEpoch(_CustomSampler):
    """A custom sampler that tracks set_epoch calls for testing."""

    def __init__(self, data_source):
        super().__init__(data_source)
        self.epoch = 0
        self.set_epoch_call_count = 0

    def set_epoch(self, epoch):
        self.epoch = epoch
        self.set_epoch_call_count += 1


def test_distributed_sampler_wrapper_set_epoch():
    """Test that DistributedSamplerWrapper correctly handles set_epoch for various sampler types.

    Reproduces issue #21454: When a sampler is wrapped by DistributedSamplerWrapper, calling set_epoch on the wrapper
    should forward the call to the underlying sampler if it supports the method.

    """
    data_source = list(range(100))

    # Case 1: Sampler WITH set_epoch method - should forward the call
    sampler_with_set_epoch = _CustomSamplerWithSetEpoch(data_source)
    wrapper = DistributedSamplerWrapper(sampler_with_set_epoch, num_replicas=2, rank=0)

    assert sampler_with_set_epoch.epoch == 0
    assert sampler_with_set_epoch.set_epoch_call_count == 0

    wrapper.set_epoch(5)
    assert wrapper.epoch == 5
    assert sampler_with_set_epoch.epoch == 5, "set_epoch was not forwarded to the underlying sampler"
    assert sampler_with_set_epoch.set_epoch_call_count == 1

    wrapper.set_epoch(10)
    assert wrapper.epoch == 10
    assert sampler_with_set_epoch.epoch == 10
    assert sampler_with_set_epoch.set_epoch_call_count == 2

    # Case 2: Sampler WITHOUT set_epoch method - should not fail
    sampler_without_set_epoch = _CustomSampler(data_source)
    wrapper = DistributedSamplerWrapper(sampler_without_set_epoch, num_replicas=2, rank=0)

    wrapper.set_epoch(5)  # Should not raise
    assert wrapper.epoch == 5

    # Case 3: Sampler with non-callable set_epoch attribute - should not fail or call it
    sampler_non_callable = _CustomSampler(data_source, non_callable_set_epoch=True)
    wrapper = DistributedSamplerWrapper(sampler_non_callable, num_replicas=2, rank=0)

    wrapper.set_epoch(5)  # Should not raise
    assert wrapper.epoch == 5
    assert sampler_non_callable.set_epoch == "not a method"  # Should remain unchanged
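As a usage sketch, reusing the illustrative `_EpochShuffledSampler` from the earlier snippet (single process with explicit `num_replicas`/`rank`, so no process group is required), the epoch now propagates through the wrapper into the inner sampler and the iteration order changes across epochs:

```python
from torch.utils.data import DataLoader

dataset = list(range(8))
sampler = _EpochShuffledSampler(dataset)
wrapper = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0)
loader = DataLoader(dataset, sampler=wrapper)

for epoch in range(3):
    wrapper.set_epoch(epoch)  # now also reaches the inner sampler
    print(epoch, [batch.item() for batch in loader])
```

In a typical Lightning run, `set_epoch` is called on the loader's sampler at each epoch boundary, so custom samplers wrapped for distributed training should pick up this fix without user-side changes.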