Skip to content

Conversation

@deependujha
Copy link
Collaborator

@deependujha deependujha commented Dec 24, 2025

What does this PR do?

Fixes #21447
fixes #21131

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

This PR adds an explicit opt-out mechanism for custom samplers to prevent Lightning from applying automatic shuffling during training.

Custom samplers can now set disable_auto_shuffle = True to indicate that they fully control iteration order and should not be wrapped or replaced by Lightning (for example, in DDP).

from torch.utils.data import Dataset, Sampler, DataLoader
import lightning as L
import torch.nn as nn
import torch

class IntegerDataset(Dataset):
    def __init__(self):
        self.data = [i for i in range(100)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"data": torch.tensor([self.data[idx]])}

class InOrderSampler(Sampler):
    def __init__(self, dataset):
        self.dataset = dataset
        self.disable_auto_shuffle = True  # <-------- opt out of auto shuffle

    def __iter__(self):   
       yield from range(len(self.dataset))

    def __len__(self):
        return len(self.dataset)
    
dataset = IntegerDataset()
sampler = InOrderSampler(dataset)
dataloader = DataLoader(dataset=dataset, batch_size=3, sampler=sampler)

class MyModule(L.LightningModule):

    def __init__(self):
        super().__init__()

        self.layer = nn.Linear(10, 10)

    def training_step(self, batch, batch_idx):
        print(batch)

        input = torch.randn(10, 10, device = self.layer.weight.device)
        output = self.layer(input)
        loss = nn.functional.mse_loss(output, input)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

if __name__ == "__main__":
    model = MyModule()

    trainer = L.Trainer(
        max_epochs=1,
        use_distributed_sampler=True,
    )
    trainer.fit(
        model=model,
        train_dataloaders=dataloader,
    )

Why is this needed?

During training, Lightning currently assumes that data should be shuffled and may override the sampler configuration. However, it is not possible to reliably infer whether a custom sampler expects shuffling or fully controls the iteration order, especially since custom samplers may derive from PyTorch’s built-in sampler classes.

Relying on type or behavior inspection is fragile and can lead to unintended reordering. Introducing an explicit opt-out property allows custom samplers to clearly signal that they manage ordering themselves, avoiding implicit assumptions and preserving intended behavior.

What changed?

  • Lightning respects disable_auto_shuffle = True on custom samplers during training
  • Added documentation describing the new opt-out behavior

Backward compatibility

This change is fully backward compatible. Existing samplers are unaffected unless they explicitly set disable_auto_shuffle = True.

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21449.org.readthedocs.build/en/21449/

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces an opt-out mechanism for custom samplers to prevent Lightning from automatically applying shuffling during training. Custom samplers can now set disable_auto_shuffle = True as an instance attribute to indicate they fully control iteration order and should not be modified by Lightning's distributed sampling logic.

  • Adds support for disable_auto_shuffle attribute on custom samplers to preserve their intended iteration order
  • Updates the CHANGELOG to document this new feature
  • Adds documentation with an example showing how to use this feature

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File Description
src/lightning/pytorch/trainer/connectors/data_connector.py Modifies _process_dataloader to check for disable_auto_shuffle attribute on samplers and respects it when determining whether to apply shuffling
src/lightning/pytorch/CHANGELOG.md Adds entry documenting the new feature under the "Added" section
docs/source-pytorch/common/trainer.rst Adds new subsection explaining the feature with a code example showing how to create a custom sampler that opts out of automatic shuffling
Comments suppressed due to low confidence (1)

src/lightning/pytorch/CHANGELOG.md:30

  • The CHANGELOG has a formatting issue. There are two "### Fixed" sections (lines 14 and 28), which creates an inconsistent structure. The CHANGELOG entries should be organized properly with each section appearing only once in the unreleased changes. Additionally, line 17 has a lone hyphen that should be removed. The structure should follow: Added, Fixed, Deprecated, Removed sections in that order without duplication.
### Fixed

- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
-

### Deprecated

- Deprecated `to_torchscript` method due to deprecation of TorchScript in PyTorch ([#21397](https://github.com/Lightning-AI/pytorch-lightning/pull/21397))

### Removed

---
- Removed support for Python 3.9 due to end-of-life status ([#21398](https://github.com/Lightning-AI/pytorch-lightning/pull/21398))

### Fixed

- Sanitize profiler filenames when saving to avoid crashes due to invalid characters ([#21395](https://github.com/Lightning-AI/pytorch-lightning/pull/21395))

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@codecov
Copy link

codecov bot commented Dec 24, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 79%. Comparing base (027455b) to head (e3070d7).
⚠️ Report is 10 commits behind head on master.
✅ All tests successful. No failed tests found.

❗ There is a different number of reports uploaded between BASE (027455b) and HEAD (e3070d7). Click for more details.

HEAD has 4687 uploads less than BASE
Flag BASE (027455b) HEAD (e3070d7)
python3.10 108 3
cpu 1075 30
lightning 540 15
pytest 538 0
lightning_fabric 268 0
python3.12 321 9
python3.12.7 322 9
python3.11 216 6
python 108 3
pytorch_lightning 267 15
pytorch2.6 53 3
pytest-full 537 30
pytorch2.1 108 6
pytorch2.9 52 3
pytorch2.3 54 3
pytorch2.8 54 3
pytorch2.2.2 54 3
pytorch2.5.1 54 3
pytorch2.4.1 54 3
pytorch2.7 54 3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21449     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         270      267      -3     
  Lines       24059    24006     -53     
=========================================
- Hits        20862    18954   -1908     
- Misses       3197     5052   +1855     

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

docs Documentation related has conflicts pl Generic label for PyTorch Lightning package

Projects

None yet

1 participant