35 changes: 10 additions & 25 deletions src/lightning/fabric/fabric.py
@@ -871,39 +871,24 @@ def load(
path: Union[str, Path],
state: Optional[dict[str, Union[nn.Module, Optimizer, Any]]] = None,
strict: bool = True,
*,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)
"""Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.) How and which
processes load gets determined by the `strategy`.

How and which processes load gets determined by the `strategy`.
This method must be called on all processes!

Args:
path: A path to where the file is located.
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
If no state is given, then the checkpoint will be returned in full.
strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
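weights_only: Passed through to the strategy's checkpoint loader (ultimately ``torch.load``). If ``True``,
loading is restricted to tensors and other primitive types; set it to ``False`` only for checkpoints
from a trusted source. If ``None``, the default is determined by the strategy's checkpoint loader.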

Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is
given, the full checkpoint will be returned.

Example::

# Load full checkpoint
checkpoint = fabric.load("checkpoint.pth")

# Load into existing objects
state = {"model": model, "optimizer": optimizer}
remainder = fabric.load("checkpoint.pth", state)
epoch = remainder.get("epoch", 0)

"""
unwrapped_state = _unwrap_objects(state)
remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
remainder = self._strategy.load_checkpoint(
path=path,
state=unwrapped_state,
strict=strict,
weights_only=weights_only,
)
self.barrier()
if state is not None:
# We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
# (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
for k in list(unwrapped_state.keys()):
obj, _ = _unwrap_compiled(state[k])
if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
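A usage sketch of the new keyword-only argument (a minimal single-process CPU example, not taken from the PR; the checkpoint path and model are illustrative):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")
model = torch.nn.Linear(4, 4)
state = {"model": model, "epoch": 3}

fabric.save("checkpoint.pth", state)

# weights_only=None (the default) defers to the strategy's checkpoint loader;
# passing False opts out of safe loading for checkpoints from a trusted source.
remainder = fabric.load("checkpoint.pth", state, weights_only=False)

Because the parameter is keyword-only (note the bare `*,` in the signature), existing positional calls such as `fabric.load(path, state, strict)` keep working unchanged.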
18 changes: 17 additions & 1 deletion tests/tests_fabric/test_fabric.py
@@ -1110,7 +1110,7 @@ def test_load_wrapped_objects(setup, tmp_path):

expected_remainder = {"extra": "data"}

def mocked_load_checkpoint(path, state, strict):
def mocked_load_checkpoint(path, state, strict, **kwargs):
assert not isinstance(state["model"], _FabricModule)
assert not isinstance(state["optimizer"], _FabricOptimizer)
state.update({"int": 5, "dict": {"x": 1}})
@@ -1347,3 +1347,19 @@ def mock_signature(obj):
callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2")
finally:
lightning.fabric.fabric.inspect.signature = original_signature


def test_fabric_load_accepts_weights_only_false(tmp_path):
import torch

from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")

ckpt = {"foo": 123}
path = tmp_path / "ckpt.pt"
torch.save(ckpt, path)

remainder = fabric.load(path, weights_only=False)

assert remainder["foo"] == 123