We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8950739 commit ce3d1d8Copy full SHA for ce3d1d8
returnn/torch/engine.py
@@ -878,6 +878,7 @@ def _load_model(self):
878
checkpoint_state = torch.load(filename, map_location=self._device)
879
if epoch is None:
880
epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
881
+ step = checkpoint_state.get("step", 1)
882
print(f" epoch {epoch}, global train step {step}", file=log.v4)
883
# The checkpoint was saved when the step was already increased (but not the epoch yet).
884
# Restore the last step.
0 commit comments