diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 2bbeef18..65b18838 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -324,7 +324,9 @@ def save_single_file( ): # Note: metadata kwargs cannot contain any of: # (step, model) - save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp.pth") + pth_path = os.path.join(self.ckp_path[:-12], "pth", "step_" + str(step)) + os.makedirs(pth_path, exist_ok=True) + save_name = os.path.join(pth_path, "consolidated.00.pth") save_time = time.time() with FSDP.state_dict_type( model,