src/lightning/fabric/CHANGELOG.md (2 additions, 0 deletions)
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `DistributedSamplerWrapper` not forwarding `set_epoch` to the underlying sampler ([#21454](https://github.com/Lightning-AI/pytorch-lightning/pull/21454))

- Fixed DDP notebook CUDA fork check to allow passive initialization when CUDA is not actively used ([#21402](https://github.com/Lightning-AI/pytorch-lightning/pull/21402))

### Removed
src/lightning/fabric/utilities/distributed.py (8 additions, 0 deletions)
@@ -372,6 +372,14 @@ def __iter__(self) -> Iterator:
        self.dataset.reset()
        return (self.dataset[index] for index in super().__iter__())

    @override
    def set_epoch(self, epoch: int) -> None:
        super().set_epoch(epoch)
        # Forward set_epoch to the original sampler if it supports it
        original_sampler = self.dataset._sampler
        if hasattr(original_sampler, "set_epoch") and callable(original_sampler.set_epoch):
            original_sampler.set_epoch(epoch)


def _suggested_max_num_threads(num_processes: int = 1) -> int:
    if num_processes < 1:
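For context while reviewing, here is a minimal sketch (not part of this diff) of the behavior the new override enables. `_ShuffleSampler` is a hypothetical epoch-aware sampler used only for illustration; the `num_replicas=2, rank=0` arguments mirror the tests below and do not require an initialized process group.

```python
import torch
from torch.utils.data import Sampler

from lightning.fabric.utilities.distributed import DistributedSamplerWrapper


class _ShuffleSampler(Sampler):
    """Hypothetical sampler whose permutation is seeded by the epoch passed to set_epoch."""

    def __init__(self, data_source):
        self.data_source = data_source
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        generator = torch.Generator().manual_seed(self.epoch)
        return iter(torch.randperm(len(self.data_source), generator=generator).tolist())


sampler = _ShuffleSampler(list(range(8)))
wrapper = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0)
wrapper.set_epoch(3)
assert sampler.epoch == 3  # previously only the DistributedSampler bookkeeping saw the new epoch
```

The `hasattr`/`callable` guard in the override keeps samplers that never define `set_epoch`, or that define it as a non-callable attribute, working unchanged; the tests below cover both cases explicitly.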
tests/tests_fabric/utilities/test_distributed.py (71 additions, 0 deletions)
@@ -15,6 +15,7 @@
from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import (
    DistributedSamplerWrapper,
    _destroy_dist_connection,
    _gather_all_tensors,
    _get_default_process_group_backend_for_device,
@@ -274,3 +275,73 @@ def test_is_dtensor(monkeypatch):

    monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
    assert not _is_dtensor(Mock(spec=DTensor))


class _CustomSampler(torch.utils.data.Sampler):
    """A custom sampler for testing DistributedSamplerWrapper."""

    def __init__(self, data_source, non_callable_set_epoch: bool = False):
        self.data_source = data_source
        if non_callable_set_epoch:
            self.set_epoch = "not a method"  # attribute exists but is not callable

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        return iter(range(len(self.data_source)))


class _CustomSamplerWithSetEpoch(_CustomSampler):
    """A custom sampler that tracks set_epoch calls for testing."""

    def __init__(self, data_source):
        super().__init__(data_source)
        self.epoch = 0
        self.set_epoch_call_count = 0

    def set_epoch(self, epoch):
        self.epoch = epoch
        self.set_epoch_call_count += 1


def test_distributed_sampler_wrapper_set_epoch():
    """Test that DistributedSamplerWrapper correctly handles set_epoch for various sampler types.

    Reproduces issue #21454: When a sampler is wrapped by DistributedSamplerWrapper, calling set_epoch on the wrapper
    should forward the call to the underlying sampler if it supports the method.

    """
    data_source = list(range(100))

    # Case 1: Sampler WITH set_epoch method - should forward the call
    sampler_with_set_epoch = _CustomSamplerWithSetEpoch(data_source)
    wrapper = DistributedSamplerWrapper(sampler_with_set_epoch, num_replicas=2, rank=0)

    assert sampler_with_set_epoch.epoch == 0
    assert sampler_with_set_epoch.set_epoch_call_count == 0

    wrapper.set_epoch(5)
    assert wrapper.epoch == 5
    assert sampler_with_set_epoch.epoch == 5, "set_epoch was not forwarded to the underlying sampler"
    assert sampler_with_set_epoch.set_epoch_call_count == 1

    wrapper.set_epoch(10)
    assert wrapper.epoch == 10
    assert sampler_with_set_epoch.epoch == 10
    assert sampler_with_set_epoch.set_epoch_call_count == 2

    # Case 2: Sampler WITHOUT set_epoch method - should not fail
    sampler_without_set_epoch = _CustomSampler(data_source)
    wrapper = DistributedSamplerWrapper(sampler_without_set_epoch, num_replicas=2, rank=0)

    wrapper.set_epoch(5)  # Should not raise
    assert wrapper.epoch == 5

    # Case 3: Sampler with non-callable set_epoch attribute - should not fail or call it
    sampler_non_callable = _CustomSampler(data_source, non_callable_set_epoch=True)
    wrapper = DistributedSamplerWrapper(sampler_non_callable, num_replicas=2, rank=0)

    wrapper.set_epoch(5)  # Should not raise
    assert wrapper.epoch == 5
    assert sampler_non_callable.set_epoch == "not a method"  # Should remain unchanged
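As a usage illustration (assumed, not taken from this diff), a per-epoch loop continuing from the `_ShuffleSampler` sketch after the `distributed.py` hunk above; the data, batch size, and the single-process `num_replicas=2, rank=0` setup are placeholders for what Fabric would normally configure when it re-instantiates samplers for distributed training.

```python
from torch.utils.data import DataLoader

data = list(range(16))
sampler = _ShuffleSampler(data)  # hypothetical epoch-aware sampler from the sketch above
loader = DataLoader(
    data,
    sampler=DistributedSamplerWrapper(sampler, num_replicas=2, rank=0),
    batch_size=4,
)

for epoch in range(2):
    # With this fix, the call below also reseeds _ShuffleSampler, so each epoch
    # yields a new, epoch-consistent permutation instead of repeating epoch 0.
    loader.sampler.set_epoch(epoch)
    for batch in loader:
        pass  # training step would go here
```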