
Commit 3689b58

finished training with hf models
1 parent b291ad6 commit 3689b58

15 files changed: +1875 −0 lines

torchtitan/experiments/__init__.py

Lines changed: 5 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Lines changed: 23 additions & 0 deletions
# Training LLAMA with HF weights

This directory contains scripts and configs for training LLAMA with HF weights using TorchTitan.

## Usage

### Install extra dependencies

```bash
pip install -r extra_requirements.txt
```

### Test loading HF weights

```bash
pytest test_loading_hf_weights.py
```

### Run training

```bash
LOG_RANK=7 bash run_train.sh
```
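
For context on what the `test_loading_hf_weights.py` step above exercises, the sketch below shows a generic way to pull HF Llama weights with `transformers`. It is illustrative only: the checkpoint name is an assumption, and this snippet is not one of the files in this commit.

```python
# Illustrative only: load an HF Llama checkpoint with transformers.
# "meta-llama/Llama-3.2-1B" is an assumed (gated) model name, not something this commit pins.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```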
Lines changed: 10 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Llama 3 is licensed under the LLAMA 3 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import torchtitan.experiments.train_llama_hf.model  # noqa: F401
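
The `# noqa: F401` import above exists for its side effects: importing `torchtitan.experiments.train_llama_hf.model` gives that module a chance to register the HF-backed model at import time. A minimal sketch of that import-for-registration pattern follows, with hypothetical names (`MODEL_REGISTRY`, `register_model`) standing in for whatever registry the experiment actually uses; it is not TorchTitan's API.

```python
# Hypothetical sketch of the import-for-side-effects registration pattern;
# the registry and decorator names below are illustrative, not TorchTitan's API.
from typing import Callable, Dict

MODEL_REGISTRY: Dict[str, Callable] = {}


def register_model(name: str) -> Callable:
    """Record a model builder under `name` as a side effect of importing its module."""

    def _wrap(builder: Callable) -> Callable:
        MODEL_REGISTRY[name] = builder
        return builder

    return _wrap


@register_model("llama3_hf")  # runs when this module is imported
def build_llama3_hf():
    ...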
Lines changed: 176 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, Optional

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

from torchtitan.components.dataloader import ParallelAwareDataloader

from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str):
    """Load C4 dataset with default configuration."""
    return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
    """Process C4 dataset sample text."""
    return sample["text"]


@dataclass
class DatasetConfig:
    path: str
    loader: Callable
    text_processor: Callable


# Add your dataset here - more information at docs/datasets.md
DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=_load_c4_dataset,
        text_processor=_process_c4_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
}


def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and path."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        tokenizer: PreTrainedTokenizerBase,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text)
                self._all_tokens.extend(sample_tokens)
                self._sample_idx += 1

                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    # Add position IDs (0 to seq_len-1)
                    position_ids = torch.arange(len(input), dtype=torch.long)
                    yield input, label, position_ids

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_tokens = state_dict["token_buffer"]

    def state_dict(self):
        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
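
As a quick sanity check on the dataset code above, the following sketch drives `HuggingFaceDataset` directly against the bundled `c4_test` assets on a single rank and pulls one sample. The tokenizer checkpoint name is an assumption; this snippet is illustrative and not part of the commit.

```python
# Illustrative usage of HuggingFaceDataset (single rank, no parallelism).
# The tokenizer name is assumed; any HF Llama-style tokenizer should work.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
ds = HuggingFaceDataset(
    dataset_name="c4_test",
    dataset_path="tests/assets/c4_test",
    tokenizer=tokenizer,
    seq_len=2048,
    dp_rank=0,
    dp_world_size=1,
    infinite=False,
)

# Each yielded element is a 1-D LongTensor of length seq_len;
# labels are the inputs shifted by one token.
input_ids, labels, position_ids = next(iter(ds))
assert input_ids.shape == labels.shape == position_ids.shape == (2048,)
```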
Lines changed: 2 additions & 0 deletions
transformers>=4.49.0
sentencepiece>=0.2.0
