diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index d1537307d319f..34a5d838fd33a 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
+- Exposed `weights_only` argument for loading checkpoints in `Fabric.load()` and `Fabric.load_raw()` ([#21470](https://github.com/Lightning-AI/pytorch-lightning/pull/21470))

 ### Changed

diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 43df846e8333c..937b79cc9ded2 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -871,8 +871,10 @@ def load(
         path: Union[str, Path],
         state: Optional[dict[str, Union[nn.Module, Optimizer, Any]]] = None,
         strict: bool = True,
+        *,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
-        """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)
+        """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.).

         How and which processes load gets determined by the `strategy`. This method must be called on all processes!

@@ -881,7 +883,12 @@ def load(
             path: A path to where the file is located.
             state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
                 If no state is given, then the checkpoint will be returned in full.
-            strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
+            strict: Whether to enforce that the keys in ``state`` match the keys in the checkpoint.
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
+                an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we
+                recommend using ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`_.

         Returns:
             The remaining items that were not restored into the given state dictionary. If no state dictionary is
@@ -899,7 +906,12 @@ def load(

         """
         unwrapped_state = _unwrap_objects(state)
-        remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
+        remainder = self._strategy.load_checkpoint(
+            path=path,
+            state=unwrapped_state,
+            strict=strict,
+            weights_only=weights_only,
+        )
         self.barrier()
         if state is not None:
             # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
@@ -911,7 +923,14 @@ def load(
                 state[k] = unwrapped_state[k]
         return remainder

-    def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], strict: bool = True) -> None:
+    def load_raw(
+        self,
+        path: Union[str, Path],
+        obj: Union[nn.Module, Optimizer],
+        strict: bool = True,
+        *,
+        weights_only: Optional[bool] = None,
+    ) -> None:
         """Load the state of a module or optimizer from a single state-dict file.

         Use this for loading a raw PyTorch model checkpoint created without Fabric.
@@ -923,10 +942,15 @@ def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], str
             obj: A :class:`~torch.nn.Module` or :class:`~torch.optim.Optimizer` instance.
             strict: Whether to enforce that the keys in the module's state-dict match the keys in the checkpoint.
                 Does not apply to optimizers.
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
+                an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we
+                recommend using ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`_.

         """
         obj = _unwrap_objects(obj)
-        self._strategy.load_checkpoint(path=path, state=obj, strict=strict)
+        self._strategy.load_checkpoint(path=path, state=obj, strict=strict, weights_only=weights_only)

     def launch(self, function: Callable[["Fabric"], Any] = _do_nothing, *args: Any, **kwargs: Any) -> Any:
         """Launch and initialize all the processes needed for distributed execution.
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index b368f626c3b11..e3f1389e23064 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -324,6 +324,11 @@ def load_checkpoint(
                 - A :class:`~torch.optim.Optimizer` instance, if the checkpoint file contains a raw optimizer state.
             strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
+                an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we
+                recommend using ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`_.

         Returns:
             The remaining items that were not restored into the given state dictionary. If no state dictionary is
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 8bfc5002de6a4..41f7c32ede71c 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -1110,7 +1110,7 @@ def test_load_wrapped_objects(setup, tmp_path):

     expected_remainder = {"extra": "data"}

-    def mocked_load_checkpoint(path, state, strict):
+    def mocked_load_checkpoint(path, state, strict, **kwargs):
         assert not isinstance(state["model"], _FabricModule)
         assert not isinstance(state["optimizer"], _FabricOptimizer)
         state.update({"int": 5, "dict": {"x": 1}})
@@ -1145,11 +1145,11 @@ def test_load_raw():
     wrapped_model, wrapped_optimizer = fabric.setup(model, optimizer)

     fabric.load_raw(path="path0", obj=model)
-    fabric.strategy.load_checkpoint.assert_called_with(path="path0", state=model, strict=True)
+    fabric.strategy.load_checkpoint.assert_called_with(path="path0", state=model, strict=True, weights_only=None)
     fabric.load_raw(path="path1", obj=wrapped_model, strict=False)
-    fabric.strategy.load_checkpoint.assert_called_with(path="path1", state=model, strict=False)
+    fabric.strategy.load_checkpoint.assert_called_with(path="path1", state=model, strict=False, weights_only=None)
     fabric.load_raw(path="path2", obj=wrapped_optimizer)
-    fabric.strategy.load_checkpoint.assert_called_with(path="path2", state=optimizer, strict=True)
+    fabric.strategy.load_checkpoint.assert_called_with(path="path2", state=optimizer, strict=True, weights_only=None)


 def test_barrier():
@@ -1347,3 +1347,44 @@ def mock_signature(obj):
         callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2")
     finally:
         lightning.fabric.fabric.inspect.signature = original_signature
+
+
+def test_fabric_load_accepts_weights_only_false(tmp_path):
+    import torch
+
+    from lightning.fabric import Fabric
+
+    fabric = Fabric(accelerator="cpu")
+
+    ckpt = {"foo": 123}
+    path = tmp_path / "ckpt.pt"
+    torch.save(ckpt, path)
+
+    remainder = fabric.load(path, weights_only=False)
+
+    assert remainder["foo"] == 123
+
+
+@pytest.mark.parametrize("weights_only", [None, False, True])
+def test_fabric_load_forwards_weights_only_to_strategy(weights_only):
+    """Test that `Fabric.load()` correctly forwards the weights_only argument to the strategy."""
+    fabric = Fabric(accelerator="cpu")
+    fabric.strategy.load_checkpoint = Mock(return_value={})
+
+    fabric.load("path.pt", weights_only=weights_only)
+    fabric.strategy.load_checkpoint.assert_called_with(
+        path="path.pt", state=None, strict=True, weights_only=weights_only
+    )
+
+
+@pytest.mark.parametrize("weights_only", [None, False, True])
+def test_fabric_load_raw_forwards_weights_only_to_strategy(weights_only):
+    """Test that `Fabric.load_raw()` correctly forwards the weights_only argument to the strategy."""
+    fabric = Fabric(accelerator="cpu")
+    fabric.strategy.load_checkpoint = Mock()
+
+    model = torch.nn.Linear(2, 2)
+    fabric.load_raw("path.pt", model, weights_only=weights_only)
+    fabric.strategy.load_checkpoint.assert_called_with(
+        path="path.pt", state=model, strict=True, weights_only=weights_only
+    )
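
A minimal usage sketch of the new `weights_only` argument (illustrative only: the file names and toy model are placeholders, and it assumes a single-process CPU run):

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1)

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    # Save a Fabric checkpoint, then restore it in-place. weights_only=True
    # restricts unpickling to plain tensors and primitive types, which is the
    # recommended setting for files from untrusted sources.
    state = {"model": model, "optimizer": optimizer, "step": 10}
    fabric.save("checkpoint.ckpt", state)
    remainder = fabric.load("checkpoint.ckpt", state, weights_only=True)

    # A trusted checkpoint that pickles arbitrary objects (e.g. an nn.Module)
    # still needs weights_only=False:
    # fabric.load("trusted.ckpt", state, weights_only=False)

    # A raw PyTorch state-dict file created without Fabric goes through load_raw():
    torch.save(model.state_dict(), "raw.pt")
    fabric.load_raw("raw.pt", model, weights_only=True)

Leaving `weights_only` at its `None` default defers the decision to the strategy's checkpoint loading, which is what the updated assertions above expect for existing call sites.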