Skip to content

Commit ce3d1d8

Browse files
cleanup
1 parent 8950739 commit ce3d1d8

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

returnn/torch/engine.py

+1
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,7 @@ def _load_model(self):
878878
checkpoint_state = torch.load(filename, map_location=self._device)
879879
if epoch is None:
880880
epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
881+
step = checkpoint_state.get("step", 1)
881882
print(f" epoch {epoch}, global train step {step}", file=log.v4)
882883
# The checkpoint was saved when the step was already increased (but not the epoch yet).
883884
# Restore the last step.

0 commit comments

Comments
 (0)