Commit 574f946

Exposes API for processing pretraining data (#672)
This commit enables the data processing code to create pretraining-style datasets. The training loop is also updated to ingest pretraining-style datasets, where documents are chunked by some `block_size` and the chunks are then treated as independent, fully-unmasked samples.
1 parent 638a753 commit 574f946
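
The block-based sampling described in the commit message can be pictured with a small, illustrative sketch (this is not the code added by the commit): tokenized documents are concatenated and cut into fixed-size blocks, and because every token in a block is a training target, labels simply mirror the input IDs.

```python
# Illustrative sketch only -- not the training-loop code added by this commit.
# Assumes processed records arrive as plain lists of token IDs.
from typing import Iterable, Iterator


def blocks_from_documents(
    documents: Iterable[list[int]], block_size: int
) -> Iterator[dict]:
    """Concatenate tokenized documents and yield fixed-size, fully-unmasked blocks."""
    buffer: list[int] = []
    for input_ids in documents:
        buffer.extend(input_ids)
        while len(buffer) >= block_size:
            chunk, buffer = buffer[:block_size], buffer[block_size:]
            # Every token is predicted, so labels == input_ids for each block.
            yield {"input_ids": chunk, "labels": list(chunk)}


# Example: two short "documents" chunked into blocks of 4 tokens.
print(list(blocks_from_documents([[1, 2, 3, 4, 5], [6, 7, 8]], block_size=4)))
```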

File tree

12 files changed (+1546, -30 lines changed)

README.md

Lines changed: 42 additions & 1 deletion
@@ -25,6 +25,7 @@ The library now supports reasoning traces through the `reasoning_content` field
 - [Using the library](#using-the-library)
 - [Data format](#data-format)
 - [Reasoning content support](#reasoning-content-support-1)
+- [Continual pretraining mode](#continual-pretraining-mode)
 - [Documentation](#documentation)
 - [Learning about the training arguments](#learning-about-training-arguments)
 - [`TrainingArgs`](#trainingargs)
@@ -122,6 +123,46 @@ The library now supports an optional `reasoning_content` field in addition to th
 }
 ```
 
+## Continual pretraining mode
+
+In addition to instruction tuning, the library can run document-style continual pretraining on raw text corpora.
+Enable this by supplying a block size when invoking `main_ds.py`:
+
+```bash
+torchrun main_ds.py \
+    --model_name_or_path mistralai/Mistral-7B-v0.1 \
+    --data_path /data/documents.jsonl \
+    --ckpt_output_dir ./checkpoints \
+    --effective_batch_size 128 \
+    --max_batch_len 60000 \
+    --block-size 8192 \
+    --document-column-name text  # optional, defaults to "document"
+```
+
+- `--block-size` (required) toggles continual pretraining and controls how many tokens are packed into each block.
+- `--document-column-name` (optional) specifies which JSONL field contains the raw document text.
+
+The same options are available programmatically via `TrainingArgs.pretraining_config`:
+
+```python
+from instructlab.training import TrainingArgs, PretrainingConfig
+
+train_args = TrainingArgs(
+    model_name_or_path="mistralai/Mistral-7B-v0.1",
+    data_path="documents.jsonl",
+    ckpt_output_dir="./checkpoints",
+    max_seq_len=4096,
+    max_batch_len=40000,
+    effective_batch_size=128,
+    pretraining_config=PretrainingConfig(
+        block_size=2048,
+        document_column_name="text",  # optional
+    ),
+)
+```
+
+When a pretraining config is provided, `process_documents_for_pretraining()` is invoked under the hood to tokenize raw documents before training.
+
 **Standard message structure:**
 
 ```json
@@ -139,7 +180,7 @@ The library now supports an optional `reasoning_content` field in addition to th
 }
 ```
 
-#### Important Notes
+### Important Notes
 
 1. **Automatic reasoning content processing**: If `reasoning_content` exists in a message, it will always be processed and unmasked as long as the message role is targeted for unmasking. This ensures that reasoning traces are properly included in the training data.

src/instructlab/training/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
     "FSDPOptions",
     "ShardingStrategies",
     "DistributedBackend",
+    "PretrainingConfig",
 )
 
 # First Party
@@ -23,6 +24,7 @@
     DistributedBackend,
     FSDPOptions,
     LoraOptions,
+    PretrainingConfig,
     QuantizeDataType,
     ShardingStrategies,
     TorchrunArgs,

src/instructlab/training/accelerator.py

Lines changed: 4 additions & 1 deletion
@@ -63,14 +63,17 @@ def __init__(
         self.lr_scheduler = None
         if self.distributed_framework == DistributedBackend.DEEPSPEED:
             # Standard
+            cpu_offload_optimizer_ratio = (
+                self.deepspeed_cpu_offload_optimizer_ratio or 0.0
+            )
             accel_args = {
                 "deepspeed_plugin": self.get_ds_plugin(
                     world_size=torch.distributed.get_world_size(),
                     samples_per_gpu=samples_per_gpu,
                     grad_accum=grad_accum,
                     opts=DeepSpeedOptions(
                         cpu_offload_optimizer=deepspeed_cpu_offload_optimizer,
-                        cpu_offload_optimizer_ratio=self.deepspeed_cpu_offload_optimizer_ratio,
+                        cpu_offload_optimizer_ratio=cpu_offload_optimizer_ratio,
                         cpu_offload_optimizer_pin_memory=self.deepspeed_cpu_offload_optimizer_pin_memory,
                         save_samples=save_samples,
                     ),

src/instructlab/training/config.py

Lines changed: 22 additions & 0 deletions
@@ -67,6 +67,20 @@ class DataProcessArgs(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
 
+class PretrainingConfig(BaseModel):
+    """
+    Configuration for pretraining mode.
+    """
+
+    block_size: int = Field(
+        description="Size of each block in tokens for pretraining datasets."
+    )
+    document_column_name: str = Field(
+        default="document",
+        description="Name of the column containing raw documents for pretraining.",
+    )
+
+
 # public API
 class TorchrunArgs(BaseModel):
     """
@@ -266,6 +280,14 @@ class TrainingArgs(BaseModel):
     # "last_epoch". This works alongside the '--checkpoint_at_epoch' flag.
     keep_last_checkpoint_only: Optional[bool] = False
 
+    pretraining_config: Optional[PretrainingConfig] = Field(
+        default=None,
+        description=(
+            "Pretraining configuration. When provided, enables block-based sampling "
+            "for raw document pretraining datasets."
+        ),
+    )
+
     # TODO(osilkin):
     # we are only exposing this here because `run_training` today is implicitly coupled
     # with `process_data`. Since we don't have a specific field for data processing arguments,

src/instructlab/training/data_process.py

Lines changed: 109 additions & 2 deletions
@@ -412,7 +412,10 @@ def process_messages_into_input_ids_with_chat_template(args: DataProcessArgs):
     logger.info("Tokenizing the dataset with %s tokenizer...", args.model_path)
     data_with_input_ids = data.map(
         lambda x: {
-            "input_ids": tokenizer.apply_chat_template(x["messages"], tokenize=True),
+            # newer versions of transformers have `return_dict=True` by default
+            "input_ids": tokenizer.apply_chat_template(
+                x["messages"], tokenize=True, return_dict=False
+            ),
             "unmask": bool(x["unmask"]) if "unmask" in x else False,
         },
         num_proc=NUM_PROC,
@@ -687,7 +690,8 @@ def unmask_messages(
         if regions:
             message_regions_map[idx] = regions
 
-    input_ids = tokenizer.apply_chat_template(msgs_with_unmasking)
+    # newer versions of transformers have `return_dict=True` by default
+    input_ids = tokenizer.apply_chat_template(msgs_with_unmasking, return_dict=False)
 
     # Get token IDs for all unmask tokens
     unmask_begin_token_id = tokenizer.encode(
@@ -1133,6 +1137,109 @@ def process_messages_into_input_ids(
     save_dataset(final_dataset, data_output_path, num_cpu_procs)
 
 
+def process_documents_for_pretraining(
+    data_path: str,
+    data_output_path: str,
+    model_path: str,
+    num_cpu_procs: int,
+    document_column_name: str = "document",
+) -> None:
+    """
+    Process raw documents for pretraining by tokenizing without chunking.
+
+    Outputs one JSONL record per document with only input_ids (no labels).
+    Blocking/chunking happens later during training.
+
+    Pattern: Each document → [BOS][tokens][EOS]
+
+    Args:
+        data_path: Path to input JSONL with {"document": "text"} format
+        data_output_path: Directory for processed data output
+        model_path: Path to model/tokenizer
+        num_cpu_procs: Number of parallel processes
+        document_column_name: Name of the column containing the documents
+    """
+    ensure_can_write_to_directory(data_output_path)
+
+    # Load and validate dataset
+    try:
+        data = load_dataset("json", data_files=data_path, split="train")
+    except Exception as e:
+        raise ValueError(
+            "Malformed or missing data, please ensure your dataset is correctly formatted"
+        ) from e
+
+    if data.num_rows == 0:
+        raise ValueError("The provided dataset is empty")
+
+    if document_column_name not in data.column_names:
+        raise ValueError(
+            f"Pretraining data must have '{document_column_name}' field. Found: {data.column_names}"
+        )
+
+    logger.info("Loading tokenizer from %s", model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    if tokenizer.eos_token_id is None:
+        raise ValueError("Tokenizer must have an EOS token defined for pretraining")
+
+    logger.info("Tokenizing %d documents for pretraining...", data.num_rows)
+
+    # Tokenize each document: encode() adds BOS, then append EOS
+    def tokenize_document(sample):
+        input_ids = tokenizer.encode(
+            sample[document_column_name], add_special_tokens=True
+        )
+
+        # ensures eos token is present without double-adding it.
+        if input_ids[-1] != tokenizer.eos_token_id:
+            input_ids.append(tokenizer.eos_token_id)
+
+        return {
+            "input_ids": input_ids,
+            "len": len(input_ids),
+        }
+
+    # Filter out empty documents before tokenization
+    def filter_empty_documents(batch):
+        return [bool(doc) for doc in batch[document_column_name]]
+
+    filtered_data = data.filter(
+        filter_empty_documents,
+        batched=True,
+        num_proc=num_cpu_procs,
+        desc="Filtering empty documents",
+    )
+
+    dropped_count = data.num_rows - filtered_data.num_rows
+    if dropped_count > 0:
+        logger.info(f"Dropped {dropped_count:,} empty documents")
+    tokenized_data = filtered_data.map(
+        tokenize_document,
+        num_proc=num_cpu_procs,
+        desc="Tokenizing documents",
+        remove_columns=filtered_data.column_names,
+    )
+
+    # Calculate statistics
+    total_tokens = sum(tokenized_data["len"])
+    avg_tokens = total_tokens / len(tokenized_data)
+    logger.info(f"Processed {len(tokenized_data):,} documents")
+    logger.info(f"Total tokens: {total_tokens:,}")
+    logger.info(f"Average tokens per document: {avg_tokens:.1f}")
+
+    # Save to JSONL (one record per document)
+    os.makedirs(data_output_path, exist_ok=True)
+    output_file = Path(data_output_path) / "data.jsonl"
+
+    tokenized_data.to_json(
+        output_file, num_proc=num_cpu_procs, lines=True, orient="records"
+    )
+
+    logger.info(f"Saved tokenized documents to {output_file}")
+    logger.info("Note: Blocking into fixed-size chunks will happen during training")
+
+
 def ensure_can_write_to_directory(output_dir: str) -> None:
     """
     Ensure that we can write to the output directory.
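
For reference, a minimal sketch of calling the new function directly. The import path is inferred from this file's location (`src/instructlab/training/data_process.py`), and the paths, model name, and process count below are placeholder values; per the docstring above, the tokenized output is written to `data.jsonl` inside the output directory.

```python
# Placeholder arguments; import path assumed from the module layout shown in this commit.
from instructlab.training.data_process import process_documents_for_pretraining

process_documents_for_pretraining(
    data_path="documents.jsonl",             # JSONL with one raw-text field per record
    data_output_path="./processed",          # tokenized records land in ./processed/data.jsonl
    model_path="mistralai/Mistral-7B-v0.1",  # model/tokenizer path
    num_cpu_procs=8,
    document_column_name="text",             # defaults to "document"
)
```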
