25 changes: 17 additions & 8 deletions src/lightning/fabric/fabric.py
@@ -871,21 +871,26 @@ def load(
path: Union[str, Path],
state: Optional[dict[str, Union[nn.Module, Optimizer, Any]]] = None,
strict: bool = True,
*,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)
"""Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.).

How and which processes load gets determined by the `strategy`.
This method must be called on all processes!

Args:
path: A path to where the file is located.
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
If no state is given, then the checkpoint will be returned in full.
strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
state: A dictionary of objects whose state will be restored in-place from the checkpoint. If ``None``,
the full checkpoint will be loaded and returned.
strict: Whether to enforce that the keys in ``state`` match the keys in the checkpoint.
weights_only: If ``True``, only model weights (tensors and other primitive types) are loaded; checkpoints
that contain arbitrary pickled objects require ``weights_only=False``, which should only be used for trusted files.
If ``None``, the default behavior of the underlying strategy is used.

Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is
given, the full checkpoint will be returned.
The remaining items that were not restored into the given ``state`` dictionary. If ``state`` is ``None``,
the full checkpoint is returned.

Example::

@@ -895,11 +900,15 @@ def load(
# Load into existing objects
state = {"model": model, "optimizer": optimizer}
remainder = fabric.load("checkpoint.pth", state)
epoch = remainder.get("epoch", 0)

"""
unwrapped_state = _unwrap_objects(state)
remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
remainder = self._strategy.load_checkpoint(
path=path,
state=unwrapped_state,
strict=strict,
weights_only=weights_only,
)
self.barrier()
if state is not None:
# We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
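For orientation, below is a minimal end-to-end sketch (not part of the diff) of how the new keyword-only argument can be used on a single CPU process; the checkpoint file name, the toy model, and the stored ``epoch`` key are illustrative.

import torch
from torch import nn

from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save a checkpoint that mixes module/optimizer state with extra metadata.
fabric.save("ckpt.pt", {"model": model, "optimizer": optimizer, "epoch": 3})

# Restore the objects in-place; entries that were not consumed come back in `remainder`.
# weights_only=None keeps the strategy's default, while weights_only=False allows
# trusted checkpoints that contain arbitrary pickled objects to be loaded.
state = {"model": model, "optimizer": optimizer}
remainder = fabric.load("ckpt.pt", state, weights_only=False)
assert remainder.get("epoch") == 3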
18 changes: 17 additions & 1 deletion tests/tests_fabric/test_fabric.py
Expand Up @@ -1110,7 +1110,7 @@ def test_load_wrapped_objects(setup, tmp_path):

expected_remainder = {"extra": "data"}

def mocked_load_checkpoint(path, state, strict):
def mocked_load_checkpoint(path, state, strict, **kwargs):
assert not isinstance(state["model"], _FabricModule)
assert not isinstance(state["optimizer"], _FabricOptimizer)
state.update({"int": 5, "dict": {"x": 1}})
@@ -1347,3 +1347,19 @@ def mock_signature(obj):
callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2")
finally:
lightning.fabric.fabric.inspect.signature = original_signature


def test_fabric_load_accepts_weights_only_false(tmp_path):
    """``Fabric.load`` accepts ``weights_only=False`` and returns the raw checkpoint when no ``state`` is given."""
import torch

from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")

ckpt = {"foo": 123}
path = tmp_path / "ckpt.pt"
torch.save(ckpt, path)

remainder = fabric.load(path, weights_only=False)

assert remainder["foo"] == 123