diff --git a/train.py b/train.py index 9a39d7a2..eb53b465 100644 --- a/train.py +++ b/train.py @@ -126,7 +126,7 @@ def create_train_val_dataloader( def load_resume_state(opt: dict[str, Any]): resume_state_path = None if opt["auto_resume"]: - state_path = Path("experiments") / opt["name"] / "training_states" + state_path = opt["path"]["training_states"] if Path.is_dir(state_path): states = list( scandir(state_path, suffix="state", recursive=False, full_path=False)