Add support for distributed checkpointing of HF safetensors with DCP #2851

Open · wants to merge 10 commits into base: main
19 changes: 13 additions & 6 deletions torchtune/training/checkpointing/_checkpoint_client.py
@@ -25,7 +25,10 @@
get_merged_lora_ckpt,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer
from torchtune.training.checkpointing._checkpointer import (
DistributedCheckpointer,
FullModelHFCheckpointer,
)
from torchtune.training.memory import OptimizerInBackwardWrapper

log = utils.get_logger("DEBUG")
@@ -232,8 +235,12 @@ def _save_checkpoint_sync(
"""
intermediate_checkpoint = epoch + 1 < training_progress.total_epochs
checkpointer = self._get_checkpointer()
is_not_distributed_checkpointer = not isinstance(
is_distributed_checkpointer = isinstance(
checkpointer, DistributedCheckpointer
) or (
isinstance(checkpointer, FullModelHFCheckpointer)
and checkpointer._enable_dcp
and checkpointer._intermediate_hf_dir_dcp is not None
)

# final dict passed onto the checkpointer
@@ -248,7 +255,7 @@
model_state_dict = {}
optim_state_dict = {}

if is_not_distributed_checkpointer and not single_device:
if not is_distributed_checkpointer and not single_device:
# this logic is needed because staging an async checkpoint needs cpu
# which is also used here to save a sync checkpoint that causes issues when
# occurring concurrently. We should wait for async checkpoint to clear
@@ -274,7 +281,7 @@
log.info(
f"Getting full model state dict took {time.perf_counter() - cp_start:.2f} secs"
)
elif is_not_distributed_checkpointer:
elif not is_distributed_checkpointer:
model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
else:
model_state_dict = model.state_dict()
@@ -284,7 +291,7 @@
log.info("Getting optimizer state dict...")
optim_start = time.perf_counter()

if is_not_distributed_checkpointer:
if not is_distributed_checkpointer:
# This check can be removed once we fully migrate over to ``OptimizerInBackward``
if isinstance(optimizer, OptimizerInBackwardWrapper):
for param, opt in optimizer.optim_map.items():
@@ -348,7 +355,7 @@ def _save_checkpoint_helper():

# Now that we have the model and optim state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if is_not_distributed_checkpointer and not single_device:
if not is_distributed_checkpointer and not single_device:
if self._is_rank_zero:
_save_checkpoint_helper()

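With this change, the sync-save path treats an HF checkpointer that has enable_dcp=True and an intermediate_hf_dir_dcp configured the same way it treats a DistributedCheckpointer. A minimal sketch of the check, assuming only the attributes shown in the diff (the helper name _uses_distributed_save is hypothetical):

```python
from torchtune.training.checkpointing._checkpointer import (
    DistributedCheckpointer,
    FullModelHFCheckpointer,
)


def _uses_distributed_save(checkpointer) -> bool:
    """True when every rank writes its own shards, so no rank-0 gather is needed."""
    return isinstance(checkpointer, DistributedCheckpointer) or (
        isinstance(checkpointer, FullModelHFCheckpointer)
        and checkpointer._enable_dcp
        and checkpointer._intermediate_hf_dir_dcp is not None
    )
```

When this returns True, _save_checkpoint_sync skips the rank-0 CPU gather and hands model.state_dict() (with its local, possibly DTensor, shards) directly to the checkpointer.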
29 changes: 27 additions & 2 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -403,6 +403,9 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
the recipe state from a previous run. Default is False
enable_dcp (bool): If True, the checkpointer will load the checkpoint file using DCP checkpointing APIs.
This is currently an experimental feature.
intermediate_hf_dir_dcp (Optional[str]): Only used when enable_dcp is True. If set, checkpoints are saved
without rank-0 gathering: each rank writes its shard of safetensors files to this path, and the shards are
then consolidated into the output_dir.

Raises:
ValueError: If the checkpoint_dir and output_dir are not on the same filesystem
@@ -420,6 +423,7 @@ def __init__(
safe_serialization: bool = True,
should_load_recipe_state: bool = False,
enable_dcp: bool = False,
intermediate_hf_dir_dcp: Optional[str] = None,
) -> None:
self._should_load_recipe_state = should_load_recipe_state
if resume_from_checkpoint:
@@ -432,6 +436,7 @@
self._checkpoint_dir = checkpoint_dir
self._model_type = ModelType[model_type]
self._enable_dcp = enable_dcp
self._intermediate_hf_dir_dcp = intermediate_hf_dir_dcp
self._fs, _ = url_to_fs(self._checkpoint_dir)
self._output_dir = output_dir
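For illustration, a hedged sketch of constructing the checkpointer with the new argument; all paths, file names, and the model type below are placeholders, and in practice torchtune recipes build this object from the YAML checkpointer config:

```python
from torchtune.training.checkpointing._checkpointer import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/data/llama3_8b",  # placeholder HF checkpoint directory
    checkpoint_files=[
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ],
    model_type="LLAMA3",
    output_dir="/data/llama3_8b/finetuned",
    enable_dcp=True,
    # each rank writes its safetensors shards here before they are
    # consolidated into output_dir
    intermediate_hf_dir_dcp="/data/llama3_8b/finetuned/dcp_intermediate",
)
```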

@@ -815,14 +820,34 @@ def save_checkpoint(
for fqn, filename in self._weight_map.items():
index = int(filename.split("-")[1])
fqn_to_file_index_mapping[fqn] = index

dist = bool(self._intermediate_hf_dir_dcp)
save_path = (
os.path.join(self._intermediate_hf_dir_dcp, f"epoch_{epoch}")
if self._intermediate_hf_dir_dcp
else os.path.join(self._output_dir, f"epoch_{epoch}")
)
consolidated_output_path = (
os.path.join(self._output_dir, f"epoch_{epoch}")
if self._intermediate_hf_dir_dcp
else None
)
if consolidated_output_path:
self._fs.mkdirs(self._intermediate_hf_dir_dcp, exist_ok=True)
self._fs.mkdirs(consolidated_output_path, exist_ok=True)

storage_writer = HuggingFaceStorageWriter(
path=os.path.join(self._output_dir, f"epoch_{epoch}"),
path=save_path,
fqn_to_index_mapping=fqn_to_file_index_mapping,
save_sharded=dist,
thread_count=10,
consolidated_output_path=consolidated_output_path,
thread_count_consolidation=10,
)
save(
state_dict=state_dict[training.MODEL_KEY],
storage_writer=storage_writer,
no_dist=True,
no_dist=not dist,
)
else:
# split the state_dict into separate dicts, one for each output checkpoint file
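Putting the writer settings together, a self-contained sketch of the sharded save path this branch takes; it assumes the HuggingFaceStorageWriter arguments shown in the diff, uses toy tensors and placeholder paths, and the exact import location of HuggingFaceStorageWriter may differ across PyTorch versions:

```python
import torch
from torch.distributed.checkpoint import HuggingFaceStorageWriter, save

# Toy state dict and an fqn -> safetensors file-index mapping (placeholders).
state_dict = {"model.embed_tokens.weight": torch.randn(8, 8)}
fqn_to_index = {"model.embed_tokens.weight": 1}

storage_writer = HuggingFaceStorageWriter(
    path="/tmp/dcp_intermediate/epoch_0",  # per-rank shards are written here
    fqn_to_index_mapping=fqn_to_index,
    save_sharded=True,  # every rank writes its own shards instead of rank 0 only
    thread_count=10,
    consolidated_output_path="/tmp/output/epoch_0",  # merged safetensors land here
    thread_count_consolidation=10,
)

# With no_dist=False the save is a collective across the process group, matching
# the sharded path above; a single-process run without torch.distributed
# initialized would pass no_dist=True instead.
save(state_dict=state_dict, storage_writer=storage_writer, no_dist=False)
```

The intermediate path corresponds to intermediate_hf_dir_dcp and the consolidated output path to the epoch directory under output_dir, mirroring the save_path and consolidated_output_path computed in the diff above.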