@@ -244,14 +244,14 @@ def _get_torch_device(self, use_gpu: bool) -> torch.device:
244244 device = torch .device ("cuda" )
245245 return device
246246
247- def _update_from_checkpoint (self , path : Path ) -> dict :
247+ def _update_from_checkpoint (self , checkpoint : dict ) -> None :
248248 """Load the checkpoint and update the internal state
249249 from it.
250250 Pop the values from the checkpoint so that we can ensure that it is empty at the
251251 end, i.e. all the values have been used.
252252
253253 Args:
254- path (pathlib.Path): path where the checkpoint is saved
254+ checkpoint (dict): the checkpoint is saved
255255
256256 Returns:
257257 dict: checkpoint
@@ -260,13 +260,11 @@ def _update_from_checkpoint(self, path: Path) -> dict:
260260
261261 .. code-block:: python
262262
263- def _update_from_checkpoint(self, path: Path ) -> dict :
264- checkpoint = super()._update_from_checkpoint(path=path )
263+ def _update_from_checkpoint(self, checkpoint: dict ) -> None :
264+ super()._update_from_checkpoint(checkpoint=checkpoint )
265265 self._strategy_specific_variable = checkpoint.pop("strategy_specific_variable")
266- return checkpoint
266+ return
267267 """
268- assert path .is_file (), f'Cannot load the model - does not exist { list (path .parent .glob ("*" ))} '
269- checkpoint = torch .load (path , map_location = self ._device )
270268 self ._model .load_state_dict (checkpoint .pop ("model_state_dict" ))
271269
272270 if self ._optimizer is not None :
@@ -285,8 +283,6 @@ def _update_from_checkpoint(self, path: Path) -> dict:
285283 else :
286284 torch .cuda .set_rng_state (checkpoint .pop ("torch_rng_state" ).to ("cpu" ))
287285
288- return checkpoint
289-
290286 def load_local_state (self , path : Path ) -> "TorchAlgo" :
291287 """Load the stateful arguments of this class.
292288 Child classes do not need to override that function.
@@ -297,7 +293,9 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
297293 Returns:
298294 TorchAlgo: The class with the loaded elements.
299295 """
300- checkpoint = self ._update_from_checkpoint (path = path )
296+ assert path .is_file (), f'Cannot load the model - does not exist { list (path .parent .glob ("*" ))} '
297+ checkpoint = torch .load (path , map_location = self ._device )
298+ self ._update_from_checkpoint (checkpoint = checkpoint )
301299 assert len (checkpoint ) == 0 , f"Not all values from the checkpoint have been used: { checkpoint .keys ()} "
302300 return self
303301
0 commit comments