-
Notifications
You must be signed in to change notification settings - Fork 367
Expand file tree
/
Copy pathgrpo.py
More file actions
67 lines (52 loc) · 2.44 KB
/
grpo.py
File metadata and controls
67 lines (52 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
"""
GRPO (Group Relative Policy Optimization) Trainer.

Supports four training modes:
- none (legacy): Periodic checkpoint saves + vLLM restarts
- shared_vllm: Single-copy mode with CUDA IPC weight sharing
- lora_only: LoRA adapter training with HTTP hot-swap (SLOW - needs --enforce-eager)
- lora_restart: LoRA training with vLLM restarts (FAST - CUDA graphs enabled)

Usage:
    # Legacy mode (manages vLLM internally)
    python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct

    # Shared vLLM mode (requires external vLLM with VLLM_ENABLE_SHARED_WEIGHTS=1)
    python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
        --weight-bridge-mode shared_vllm

    # LoRA mode with HTTP hot-swap (SLOW - 13 TPS due to --enforce-eager)
    python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
        --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32

    # LoRA mode with vLLM restarts (FAST - 170 TPS with CUDA graphs)
    python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\
        --weight-bridge-mode lora_restart --lora-r 16 --lora-alpha 32 \\
        --vllm-restart-interval 3
"""
from .cli import config_from_args, parse_args
from .trainers import train_legacy, train_lora, train_lora_restart, train_shared_vllm
def main():
    """Parse CLI arguments, print a run summary, and start the chosen trainer.

    The trainer is selected by ``config.weight_bridge_mode``. Any value other
    than ``"shared_vllm"``, ``"lora_only"`` or ``"lora_restart"`` falls back
    to the legacy trainer.
    """
    config = config_from_args(parse_args())

    # Startup banner summarising the run configuration.
    banner_lines = (
        "\n" + "=" * 60,
        "GRPO TRAINER",
        "=" * 60,
        f"Model: {config.model_name}",
        f"Mode: {config.weight_bridge_mode}",
        f"Training steps: {config.training_steps}",
        f"GRPO: clip_eps={config.clip_eps}",
        f"{'='*60}\n",
    )
    for banner_line in banner_lines:
        print(banner_line)

    trainers_by_mode = {
        # Single-copy mode: attach to vLLM's weights, update in-place.
        "shared_vllm": train_shared_vllm,
        # LoRA mode: freeze base model, train adapters only (HTTP hot-swap).
        # WARNING: this is SLOW (~13 TPS) because it requires --enforce-eager.
        "lora_only": train_lora,
        # LoRA mode with vLLM restarts (FAST - uses CUDA graphs); restarts
        # vLLM every vllm_restart_interval steps with the new adapter.
        "lora_restart": train_lora_restart,
    }
    # Legacy mode (periodic checkpoint saves + vLLM restarts) is the fallback
    # for every mode not listed above.
    run_trainer = trainers_by_mode.get(config.weight_bridge_mode, train_legacy)
    run_trainer(config)
# Run the trainer only when executed as a script / via ``python -m``,
# not when this module is imported.
if __name__ == "__main__":
    main()