feat: [post training] support save hf safetensor format checkpoint #845

Merged · 12 commits · Feb 26, 2025
6,434 changes: 6,434 additions & 0 deletions docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb

Large diffs are not rendered by default.

@@ -4,15 +4,25 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List

import torch
from safetensors.torch import save_file
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.training.checkpointing._utils import (
ADAPTER_CONFIG_FNAME,
ADAPTER_MODEL_FNAME,
REPO_ID_FNAME,
SUFFIXES_TO_NOT_COPY,
ModelType,
copy_files,
safe_torch_load,
)
from torchtune.utils._logging import get_logger

logger = get_logger("DEBUG")
@@ -75,9 +85,24 @@ def save_checkpoint(
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str = "meta",
) -> str:
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
if checkpoint_format == "meta":
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface":
# Note: for Hugging Face format checkpoints, we currently only support saving adapter weights
self._save_hf_format_checkpoint(model_file_path, state_dict)
else:
raise ValueError(f"Unsupported checkpoint format: {format}")
return str(model_file_path)

def _save_meta_format_checkpoint(
self,
model_file_path: Path,
state_dict: Dict[str, Any],
adapter_only: bool = False,
) -> None:
model_file_path.mkdir(parents=True, exist_ok=True)

# copy the related files for inference
@@ -140,6 +165,76 @@ def save_checkpoint(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)

print("model_file_path", str(model_file_path))
def _save_hf_format_checkpoint(
self,
model_file_path: Path,
state_dict: Dict[str, Any],
) -> None:
# the config.json file contains model params needed for state dict conversion
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())

# repo_id is needed when saving an adapter config so it is compatible with HF.
# This json file is produced and saved in the download step.
# contents are {"repo_id": "some_model/some_model_version"}
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
self.repo_id = None
if repo_id_path.exists():
with open(repo_id_path, "r") as json_file:
data = json.load(json_file)
self.repo_id = data.get("repo_id")

if training.ADAPTER_KEY in state_dict:
# TODO: we also save the adapter weights "as is" (torchtune format) because there is no
# convert_weights.peft_to_tune function to convert the PEFT weights back for resuming.
# The .pt format is not strictly needed, but it is an easy way to distinguish the two
# adapters. Ideally we would save only one.
output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME).with_suffix(".pt")
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(state_dict[training.ADAPTER_KEY], output_path)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)

state_dict[training.ADAPTER_KEY] = convert_weights.tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=config["num_attention_heads"],
num_kv_heads=config["num_key_value_heads"],
dim=config["hidden_size"],
head_dim=config.get("head_dim", None),
)
output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_MODEL_FNAME)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path = output_path.with_suffix(".safetensors")
save_file(
state_dict[training.ADAPTER_KEY],
output_path,
metadata={"format": "pt"},
)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)
else:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)

if training.ADAPTER_CONFIG in state_dict:
state_dict[training.ADAPTER_CONFIG] = convert_weights.tune_to_peft_adapter_config(
adapter_config=state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)

output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_CONFIG_FNAME).with_suffix(".json")
with open(output_path, "w") as f:
json.dump(state_dict[training.ADAPTER_CONFIG], f)
logger.info(
f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
)

# Copy everything in the checkpoint dir except the model weights and the weight mapping
# to the output dir, so it's easy to run inference with this epoch's checkpoint
copy_files(
self._checkpoint_dir.parent,
model_file_path,
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
)
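
Since _save_hf_format_checkpoint writes adapter_model.safetensors and adapter_config.json into an adapter/ subdirectory of the returned checkpoint path, the result should be loadable with the standard PEFT API. A minimal sketch, assuming peft and transformers are installed; the base model id and checkpoint path are placeholders for whatever save_checkpoint() actually returns:

# Illustrative only -- not part of this PR; model id and path are placeholders.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")  # assumed base repo_id
# Point PEFT at the adapter/ subdirectory created by _save_hf_format_checkpoint.
model = PeftModel.from_pretrained(base, "/path/to/Llama3.2-3B-Instruct-sft-0/adapter")
model.eval()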
@@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional
from typing import Literal, Optional

from pydantic import BaseModel


class TorchtunePostTrainingConfig(BaseModel):
torch_seed: Optional[int] = None
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
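
For reference, the new checkpoint_format field is validated by pydantic as a Literal, so any value other than "meta" or "huggingface" is rejected when the provider config is parsed. A minimal sketch, with the class body copied from the diff above so it runs standalone:

# Illustrative only -- mirrors the config class shown above.
from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class TorchtunePostTrainingConfig(BaseModel):
    torch_seed: Optional[int] = None
    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"


print(TorchtunePostTrainingConfig().checkpoint_format)               # "meta" (default)
print(TorchtunePostTrainingConfig(checkpoint_format="huggingface"))  # accepted
try:
    TorchtunePostTrainingConfig(checkpoint_format="gguf")            # not in the Literal
except ValidationError as err:
    print(err)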
@@ -117,6 +117,7 @@ def model_checkpoint_dir(model) -> str:
self.checkpoint_dir = model_checkpoint_dir(model)

self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._checkpoint_format = config.checkpoint_format

self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
@@ -419,6 +420,7 @@ async def save_checkpoint(self, epoch: int) -> str:
return self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
checkpoint_format=self._checkpoint_format,
)

async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
@@ -460,7 +462,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0

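Taken together with the checkpointer change, each epoch's artifacts now land under one per-epoch directory: training metrics under log/ (the DiskLogger change above) and, when checkpoint_format is "huggingface", the PEFT-compatible adapter files under adapter/. A rough sketch of the layout, assuming the recipe's output dir resolves to ~/.llama/checkpoints (the value of DEFAULT_CHECKPOINT_DIR is an assumption here) and the algorithm suffix is "sft":

# Illustrative layout only; the root dir, model id, and epoch are placeholders.
from pathlib import Path

output_dir = Path.home() / ".llama" / "checkpoints"    # assumed value of DEFAULT_CHECKPOINT_DIR
epoch_dir = output_dir / "Llama3.2-3B-Instruct-sft-0"  # returned by save_checkpoint() for epoch 0
log_dir = epoch_dir / "log"                            # DiskLogger metrics (moved here by this PR)
adapter_dir = epoch_dir / "adapter"                    # adapter_model.safetensors + adapter_config.json
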
@@ -6,6 +6,7 @@ distribution_spec:
providers:
inference:
- inline::meta-reference
- remote::ollama
eval:
- inline::meta-reference
scoring:
@@ -15,7 +16,6 @@ distribution_spec:
- inline::torchtune
datasetio:
- inline::localfs
- remote::huggingface
telemetry:
- inline::meta-reference
agents:
11 changes: 7 additions & 4 deletions llama_stack/templates/experimental-post-training/run.yaml
@@ -21,6 +21,10 @@ providers:
max_seq_len: 4096
checkpoint_dir: null
create_distributed_process_group: False
- provider_id: ollama
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
@@ -34,9 +38,6 @@ providers:
config:
openai_api_key: ${env.OPENAI_API_KEY:}
datasetio:
- provider_id: huggingface-0
provider_type: remote::huggingface
config: {}
- provider_id: localfs
provider_type: inline::localfs
config: {}
@@ -47,7 +48,9 @@ providers:
post_training:
- provider_id: torchtune-post-training
provider_type: inline::torchtune
config: {}
config: {
checkpoint_format: huggingface
}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference