-
Notifications
You must be signed in to change notification settings - Fork 367
Expand file tree
/
Copy pathconfig.py
More file actions
222 lines (205 loc) · 7.9 KB
/
config.py
File metadata and controls
222 lines (205 loc) · 7.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
Training configuration for GRPO trainer.
This module contains the TrainingConfig class which defines all training
parameters, model settings, and operational modes.
"""
from typing import List, Literal, Optional
import torch
from pydantic import BaseModel, Field
class TrainingConfig(BaseModel):
"""
Training configuration for GRPO trainer.
Supports three training modes:
- 'none' (legacy): Periodic checkpoint saves + vLLM restarts
- 'shared_vllm': Attach to vLLM's shared memory tensors, update in-place
- 'lora_only': Freeze base model, train LoRA adapters only
"""
# === Model Configuration ===
model_name: str = Field(..., description="Name of the base model to train")
# === Training Hyperparameters ===
lr: float = Field(1e-5, description="Learning rate for the optimizer")
training_steps: int = Field(10, description="Number of training steps")
batch_size: int = Field(2, description="Batch size for training")
seq_len: int = Field(2048, description="Sequence length for training")
gradient_accumulation_steps: int = Field(
32, description="Number of gradient accumulation steps"
)
warmup_steps: int = Field(
0,
description=(
"Number of initial optimizer steps for linear LR warmup. "
"0 disables warmup."
),
)
optimizer: Literal["adamw", "adamw_8bit", "adafactor"] = Field(
"adamw_8bit",
description="Optimizer to use: 'adamw' (full precision), "
"'adamw_8bit' (8-bit states, requires bitsandbytes), "
"'adafactor' (Adafactor optimizer)",
)
adafactor_scale_parameter: bool = Field(
False,
description=(
"Whether to enable Adafactor scale_parameter behavior when using "
"optimizer='adafactor'."
),
)
adafactor_relative_step: bool = Field(
False,
description=(
"Whether to enable Adafactor relative_step behavior when using "
"optimizer='adafactor'."
),
)
# === GRPO/PPO Hyperparameters ===
clip_eps: float = Field(
0.2,
description=(
"PPO-style clipping epsilon. "
"Clips the importance sampling ratio to [1-eps, 1+eps]. "
"Prevents large policy updates that could destabilize training."
),
)
# === Device & Storage ===
device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
)
save_path: str = Field(
"trained_model_checkpoints", description="Base path to save model checkpoints"
)
checkpoint_interval: int = Field(
3,
description=(
"Save checkpoint every N training steps. "
"Set to 0 to only save final checkpoint."
),
)
# === vLLM Server Configuration ===
vllm_restart_interval: int = Field(
3, description="Restart vLLM every N training steps (legacy mode)"
)
vllm_port: int = Field(9001, description="Port for the vLLM server")
vllm_gpu: Optional[int] = Field(
None,
description=(
"GPU ID for vLLM server (lora_restart/legacy modes). "
"If None, uses same GPU as trainer. Set different for parallelism."
),
)
vllm_gpu_memory_utilization: float = Field(
0.45, description="GPU memory utilization for vLLM server (0.0-1.0)"
)
max_model_len: int = Field(
4096, description="Maximum context length for vLLM server"
)
dtype: str = Field(
"bfloat16", description="Data type for model weights (bfloat16, float16, auto)"
)
# === Weights & Biases Configuration ===
use_wandb: bool = Field(
False, description="Whether to use Weights & Biases for logging"
)
wandb_project: Optional[str] = Field(None, description="Wandb project name")
wandb_group: Optional[str] = Field(None, description="Wandb group name")
# === Training Mode Configuration ===
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = (
Field(
"none",
description=(
"How to synchronize weights with inference server. "
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
"'none': legacy mode, restart vLLM with new checkpoint files."
),
)
)
train_layer_indices: Optional[List[int]] = Field(
None,
description=(
"Optional list of transformer layer indices to train in shared/legacy "
"full-model modes. If None, all layers are trainable."
),
)
# === Distributed Training Configuration ===
trainer_rank: int = Field(
0, description="Rank of this trainer in the distributed group"
)
world_size: int = Field(1, description="Total processes in the distributed group")
init_method: str = Field(
"env://",
description=(
"PyTorch distributed init method URL. "
"Use 'env://' to read MASTER_ADDR/MASTER_PORT from environment, "
"or 'tcp://host:port' for explicit rendezvous."
),
)
num_inference_nodes: int = Field(
0,
description=(
"Number of inference nodes (vLLM servers) to coordinate with. "
"0 means single-node local mode."
),
)
# === LoRA Configuration ===
lora_r: int = Field(16, description="LoRA rank (dimension of low-rank matrices)")
lora_alpha: int = Field(32, description="LoRA alpha (scaling factor)")
lora_dropout: float = Field(0.05, description="Dropout probability for LoRA layers")
lora_target_modules: Optional[List[str]] = Field(
None,
description=(
"List of module names to apply LoRA to. "
"If None, defaults to ['q_proj', 'v_proj'] for most models."
),
)
lora_layer_indices: Optional[List[int]] = Field(
None,
description=(
"Optional list of transformer layer indices to apply LoRA to. "
"If None, applies LoRA to all matching layers."
),
)
# === Single-Copy Mode Configuration ===
single_copy: bool = Field(
False,
description=(
"Enable TRUE single-copy mode via CUDA IPC. "
"The trainer attaches to vLLM's model tensors directly, "
"meaning only ONE copy of the model exists in GPU memory. "
"Requires trainer and vLLM to be on the SAME GPU(s). "
"vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
),
)
vllm_config_path: Optional[str] = Field(
None,
description=(
"Explicit path to vllm_bridge_config.json. "
"If not provided, auto-detects from LOGDIR environment variable, "
"current directory, or /tmp/atropos_bridge. "
"This file is created by vLLM when VLLM_ENABLE_SHARED_WEIGHTS=1 "
"and contains CUDA IPC handles for single-copy mode."
),
)
# === Debug & Benchmark Flags ===
debug_loading: bool = Field(
False,
description=(
"Enable verbose debug output during model loading and IPC attachment. "
"Useful for diagnosing single-copy mode issues."
),
)
benchmark: bool = Field(
False,
description=(
"Enable benchmark timing output showing step time, sync time, "
"data fetch time, and GPU memory usage per step."
),
)
# === Atropos API Configuration ===
atropos_url: str = Field(
"http://localhost:8000",
description=(
"URL of the Atropos API server (environment server). "
"Default is http://localhost:8000. Change for concurrent tests."
),
)