Skip to content

Commit cdb467e

Browse files
author
jeandut
committed
modifying _uppdate_from_checkpoint signature
Signed-off-by: jeandut <jean.du-terrail@owkin.com>
1 parent 325a43f commit cdb467e

2 files changed

Lines changed: 11 additions & 13 deletions

File tree

substrafl/algorithms/pytorch/torch_base_algo.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -244,14 +244,14 @@ def _get_torch_device(self, use_gpu: bool) -> torch.device:
244244
device = torch.device("cuda")
245245
return device
246246

247-
def _update_from_checkpoint(self, path: Path) -> dict:
247+
def _update_from_checkpoint(self, checkpoint: dict) -> None:
248248
"""Load the checkpoint and update the internal state
249249
from it.
250250
Pop the values from the checkpoint so that we can ensure that it is empty at the
251251
end, i.e. all the values have been used.
252252
253253
Args:
254-
path (pathlib.Path): path where the checkpoint is saved
254+
checkpoint (dict): the checkpoint is saved
255255
256256
Returns:
257257
dict: checkpoint
@@ -260,13 +260,11 @@ def _update_from_checkpoint(self, path: Path) -> dict:
260260
261261
.. code-block:: python
262262
263-
def _update_from_checkpoint(self, path: Path) -> dict:
264-
checkpoint = super()._update_from_checkpoint(path=path)
263+
def _update_from_checkpoint(self, checkpoint: dict) -> None:
264+
super()._update_from_checkpoint(checkpoint=checkpoint)
265265
self._strategy_specific_variable = checkpoint.pop("strategy_specific_variable")
266-
return checkpoint
266+
return
267267
"""
268-
assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
269-
checkpoint = torch.load(path, map_location=self._device)
270268
self._model.load_state_dict(checkpoint.pop("model_state_dict"))
271269

272270
if self._optimizer is not None:
@@ -285,8 +283,6 @@ def _update_from_checkpoint(self, path: Path) -> dict:
285283
else:
286284
torch.cuda.set_rng_state(checkpoint.pop("torch_rng_state").to("cpu"))
287285

288-
return checkpoint
289-
290286
def load_local_state(self, path: Path) -> "TorchAlgo":
291287
"""Load the stateful arguments of this class.
292288
Child classes do not need to override that function.
@@ -297,7 +293,9 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
297293
Returns:
298294
TorchAlgo: The class with the loaded elements.
299295
"""
300-
checkpoint = self._update_from_checkpoint(path=path)
296+
assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
297+
checkpoint = torch.load(path, map_location=self._device)
298+
self._update_from_checkpoint(checkpoint=checkpoint)
301299
assert len(checkpoint) == 0, f"Not all values from the checkpoint have been used: {checkpoint.keys()}"
302300
return self
303301

substrafl/algorithms/pytorch/torch_fed_pca_algo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def _get_state_to_save(self) -> dict:
336336
)
337337
return checkpoint
338338

339-
def _update_from_checkpoint(self, path: Path) -> dict:
339+
def _update_from_checkpoint(self, checkpoint: dict) -> None:
340340
"""Load the checkpoint and update the internal state from it.
341341
342342
Pop the values from the checkpoint so that we can ensure that it is empty at the
@@ -350,7 +350,7 @@ def _update_from_checkpoint(self, path: Path) -> dict:
350350
Returns:
351351
dict: checkpoint
352352
"""
353-
checkpoint = super()._update_from_checkpoint(path)
353+
super()._update_from_checkpoint(checkpoint)
354354
self.local_mean = checkpoint.pop("mean")
355355
self.local_covmat = checkpoint.pop("covariance_matrix")
356-
return checkpoint
356+
return

0 commit comments

Comments
 (0)