Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Exposed `weights_only` argument for loading checkpoints in `Fabric.load()` and `Fabric.load_raw()` ([#21470](https://github.com/Lightning-AI/pytorch-lightning/pull/21470))

### Changed

Expand Down
34 changes: 29 additions & 5 deletions src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,8 +871,10 @@ def load(
path: Union[str, Path],
state: Optional[dict[str, Union[nn.Module, Optimizer, Any]]] = None,
strict: bool = True,
*,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)
"""Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.).

How and which processes load gets determined by the `strategy`.
This method must be called on all processes!
Expand All @@ -881,7 +883,12 @@ def load(
path: A path to where the file is located.
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
If no state is given, then the checkpoint will be returned in full.
strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
strict: Whether to enforce that the keys in ``state`` match the keys in the checkpoint.
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
recommend using ``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is
Expand All @@ -899,7 +906,12 @@ def load(

"""
unwrapped_state = _unwrap_objects(state)
remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
remainder = self._strategy.load_checkpoint(
path=path,
state=unwrapped_state,
strict=strict,
weights_only=weights_only,
)
self.barrier()
if state is not None:
# We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
Expand All @@ -911,7 +923,14 @@ def load(
state[k] = unwrapped_state[k]
return remainder

def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], strict: bool = True) -> None:
def load_raw(
self,
path: Union[str, Path],
obj: Union[nn.Module, Optimizer],
strict: bool = True,
*,
weights_only: Optional[bool] = None,
) -> None:
"""Load the state of a module or optimizer from a single state-dict file.

Use this for loading a raw PyTorch model checkpoint created without Fabric.
Expand All @@ -923,10 +942,15 @@ def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], str
obj: A :class:`~torch.nn.Module` or :class:`~torch.optim.Optimizer` instance.
strict: Whether to enforce that the keys in the module's state-dict match the keys in the checkpoint.
Does not apply to optimizers.
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
recommend using ``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

"""
obj = _unwrap_objects(obj)
self._strategy.load_checkpoint(path=path, state=obj, strict=strict)
self._strategy.load_checkpoint(path=path, state=obj, strict=strict, weights_only=weights_only)

def launch(self, function: Callable[["Fabric"], Any] = _do_nothing, *args: Any, **kwargs: Any) -> Any:
"""Launch and initialize all the processes needed for distributed execution.
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ def load_checkpoint(
- A :class:`~torch.optim.Optimizer` instance, if the checkpoint file contains a raw optimizer state.

strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
recommend using ``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is
Expand Down
49 changes: 45 additions & 4 deletions tests/tests_fabric/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def test_load_wrapped_objects(setup, tmp_path):

expected_remainder = {"extra": "data"}

def mocked_load_checkpoint(path, state, strict):
def mocked_load_checkpoint(path, state, strict, **kwargs):
assert not isinstance(state["model"], _FabricModule)
assert not isinstance(state["optimizer"], _FabricOptimizer)
state.update({"int": 5, "dict": {"x": 1}})
Expand Down Expand Up @@ -1145,11 +1145,11 @@ def test_load_raw():
wrapped_model, wrapped_optimizer = fabric.setup(model, optimizer)

fabric.load_raw(path="path0", obj=model)
fabric.strategy.load_checkpoint.assert_called_with(path="path0", state=model, strict=True)
fabric.strategy.load_checkpoint.assert_called_with(path="path0", state=model, strict=True, weights_only=None)
fabric.load_raw(path="path1", obj=wrapped_model, strict=False)
fabric.strategy.load_checkpoint.assert_called_with(path="path1", state=model, strict=False)
fabric.strategy.load_checkpoint.assert_called_with(path="path1", state=model, strict=False, weights_only=None)
fabric.load_raw(path="path2", obj=wrapped_optimizer)
fabric.strategy.load_checkpoint.assert_called_with(path="path2", state=optimizer, strict=True)
fabric.strategy.load_checkpoint.assert_called_with(path="path2", state=optimizer, strict=True, weights_only=None)


def test_barrier():
Expand Down Expand Up @@ -1347,3 +1347,44 @@ def mock_signature(obj):
callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2")
finally:
lightning.fabric.fabric.inspect.signature = original_signature


def test_fabric_load_accepts_weights_only_false(tmp_path):
import torch

from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")

ckpt = {"foo": 123}
path = tmp_path / "ckpt.pt"
torch.save(ckpt, path)

remainder = fabric.load(path, weights_only=False)

assert remainder["foo"] == 123


@pytest.mark.parametrize("weights_only", [None, False, True])
def test_fabric_load_forwards_weights_only_to_strategy(weights_only):
"""Test that `Fabric.load()` correctly forwards the weights_only argument to the strategy."""
fabric = Fabric(accelerator="cpu")
fabric.strategy.load_checkpoint = Mock(return_value={})

fabric.load("path.pt", weights_only=weights_only)
fabric.strategy.load_checkpoint.assert_called_with(
path="path.pt", state=None, strict=True, weights_only=weights_only
)


@pytest.mark.parametrize("weights_only", [None, False, True])
def test_fabric_load_raw_forwards_weights_only_to_strategy(weights_only):
"""Test that `Fabric.load_raw()` correctly forwards the weights_only argument to the strategy."""
fabric = Fabric(accelerator="cpu")
fabric.strategy.load_checkpoint = Mock()

model = torch.nn.Linear(2, 2)
fabric.load_raw("path.pt", model, weights_only=weights_only)
fabric.strategy.load_checkpoint.assert_called_with(
path="path.pt", state=model, strict=True, weights_only=weights_only
)
Loading