Skip to content

Commit

Permalink
Explicit check for step when loading the state (#2992)
Browse files Browse the repository at this point in the history
* Explicit check

* Nit
  • Loading branch information
muellerzr authored Aug 6, 2024
1 parent 95edc68 commit d982751
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3159,7 +3159,8 @@ def _inner(folder):
map_location,
**load_model_func_kwargs,
)
self.step = override_attributes["step"]
if "step" in override_attributes:
self.step = override_attributes["step"]
custom_checkpoints = [
f for f in os.listdir(input_dir) if re.search(r"^custom_checkpoint_\d+\.pkl$", f) is not None
]
Expand Down
3 changes: 2 additions & 1 deletion src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ def load_accelerator_state(
# Random states
try:
states = torch.load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
override_attributes["step"] = states["step"]
if "step" in states:
override_attributes["step"] = states["step"]
random.setstate(states["random_state"])
np.random.set_state(states["numpy_random_seed"])
torch.set_rng_state(states["torch_manual_seed"])
Expand Down

0 comments on commit d982751

Please sign in to comment.