Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added method chaining support to `LightningModule.freeze()` and `LightningModule.unfreeze()` by returning `self` ([#21469](https://github.com/Lightning-AI/pytorch-lightning/pull/21469))


### Deprecated
Expand Down
16 changes: 11 additions & 5 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,33 +1390,39 @@ def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
"""
optimizer.zero_grad()

def freeze(self) -> None:
def freeze(self) -> Self:
r"""Freeze all params for inference.

Example::
.. code-block:: python

model = MyLightningModule(...)
model.freeze()

Returns:
:class:`LightningModule` with all parameters frozen.

"""
for param in self.parameters():
param.requires_grad = False

self.eval()
return self.eval()

def unfreeze(self) -> None:
def unfreeze(self) -> Self:
"""Unfreeze all parameters for training.

.. code-block:: python

model = MyLightningModule(...)
model.unfreeze()

Returns:
:class:`LightningModule` self with all parameters unfrozen.

"""
for param in self.parameters():
param.requires_grad = True

self.train()
return self.train()

def _verify_is_manual_optimization(self, fn_name: str) -> None:
if self.automatic_optimization:
Expand Down
6 changes: 4 additions & 2 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,12 +386,14 @@ def test_model_checkpoint_only_weights(tmp_path):

def test_model_freeze_unfreeze():
model = BoringModel()
model.freeze()
freeze_ret = model.freeze()
assert freeze_ret is model
assert not model.training
for param in model.parameters():
assert not param.requires_grad

model.unfreeze()
unfreeze_ret = model.unfreeze()
assert unfreeze_ret is model
assert model.training
for param in model.parameters():
assert param.requires_grad
Expand Down
Loading