|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -# GRPO Configuration |
16 | | -# This config consolidates common parameters for GRPO training across different model sizes |
| 15 | +# RL Configuration |
| 16 | +# This config consolidates common parameters for RL training across different model sizes |
17 | 17 |
|
18 | 18 | base_config: "base.yml" |
19 | 19 |
|
20 | | -use_grpo: True |
21 | | -train_data_columns: 'prompt' |
| 20 | +# ====== Hardware ===== |
| 21 | +trainer_devices_fraction: 0.5 |
| 22 | +sampler_devices_fraction: 0.5 |
| 23 | +chips_per_vm: 4 # depends on hardware, for v5p this is 4 |
22 | 24 |
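For intuition, a minimal sketch (not the repo's actual mesh setup) of how the 50/50 trainer/sampler split implied by the two fractions above could be derived from the visible JAX devices:

```python
import jax

# Split the visible devices between the trainer and the rollout sampler
# according to trainer_devices_fraction / sampler_devices_fraction (0.5 each).
devices = jax.devices()
num_trainer = int(len(devices) * 0.5)    # trainer_devices_fraction
trainer_devices = devices[:num_trainer]
sampler_devices = devices[num_trainer:]  # the remaining fraction serves the sampler
```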
|
| 25 | +# ====== Debug ====== |
| 26 | +debug: True |
| 27 | + |
| 28 | +# ====== Reproducibility ====== |
| 29 | +data_shuffle_seed: 42 |
| 30 | +loss_algo: 'grpo' # grpo or gspo-token |
| 31 | + |
| 32 | +# ====== Checkpoint saving ====== |
| 33 | +save_interval_steps: 500 |
| 34 | +max_to_keep: 4 |
| 35 | + |
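For illustration, assuming an Orbax-style checkpoint manager (an assumption; this config does not name the checkpointing backend), the two knobs above map onto it roughly as follows:

```python
import orbax.checkpoint as ocp

# Hypothetical wiring of the checkpoint-saving knobs above.
options = ocp.CheckpointManagerOptions(
    save_interval_steps=500,  # save_interval_steps
    max_to_keep=4,            # max_to_keep
)
manager = ocp.CheckpointManager("/tmp/rl_checkpoints", options=options)  # directory is illustrative
```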
| 36 | +# ====== GRPO ====== |
| 37 | +# === Generation during GRPO training === |
| 38 | +max_prompt_length: 256 |
| 39 | +total_generation_steps: 768 |
| 40 | + |
| 41 | +# The number of responses the policy generates for each prompt within a single |
| 42 | +# training step. This corresponds to `G` in Algorithm 1 of the GRPO paper; the |
| 43 | +# "group" in GRPO refers to this set of responses. |
| 44 | +num_generations: 2 |
| 45 | + |
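For reference, a minimal NumPy sketch (not this repo's implementation) of how the `num_generations` responses per prompt are typically turned into group-relative advantages in GRPO, i.e. each response's reward normalized by its group's mean and standard deviation:

```python
import numpy as np

def group_relative_advantages(rewards: np.ndarray, num_generations: int) -> np.ndarray:
    """rewards has shape [num_prompts * num_generations], grouped per prompt."""
    grouped = rewards.reshape(-1, num_generations)     # [num_prompts, G]
    mean = grouped.mean(axis=1, keepdims=True)
    std = grouped.std(axis=1, keepdims=True) + 1e-6    # avoid division by zero
    return ((grouped - mean) / std).reshape(-1)

# One prompt, num_generations=2: the better response gets a positive advantage.
print(group_relative_advantages(np.array([3.0, -1.0]), num_generations=2))
```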
| 46 | +# === other GRPO configs === |
| 47 | +# The number of optimization iterations per batch (μ in Algorithm 1 of the GRPO paper). |
| 48 | +num_iterations: 1 |
| 49 | + |
| 50 | +# The coefficient for the KL divergence penalty (β) in the GRPO loss function. |
| 51 | +# It is important to keep this high enough; otherwise the KL divergence can |
| 52 | +# grow unchecked. |
| 53 | +beta: 0.08 |
| 54 | +# Epsilon value for clipping (ε in the GRPO loss in the paper). As in PPO, |
| 55 | +# this keeps policy updates stable. |
| 56 | +epsilon: 0.2 |
| 57 | + |
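As a reminder of where `beta` (β) and `epsilon` (ε) enter the objective, here is a simplified per-token sketch of the GRPO loss (schematic only; the actual training code may differ in details such as masking and aggregation):

```python
import jax.numpy as jnp

def grpo_loss(logp_new, logp_old, logp_ref, advantages, beta=0.08, epsilon=0.2):
    """All arguments are per-token log-probabilities / advantages of the same shape."""
    ratio = jnp.exp(logp_new - logp_old)                     # importance ratio
    clipped = jnp.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)  # PPO-style clipping (epsilon)
    surrogate = jnp.minimum(ratio * advantages, clipped * advantages)
    # k3-style estimator of the KL divergence to the reference policy, as in the GRPO paper.
    kl = jnp.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1.0
    return -jnp.mean(surrogate - beta * kl)                  # beta weights the KL penalty
```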
| 58 | +# ====== Training ====== |
| 59 | + |
| 60 | +batch_size: 1 |
| 61 | +# Increase `batch_size` and the total number of training steps for better results. |
| 62 | +# num_batches: 3738 |
| 63 | +num_batches: 4 # 200 |
| 64 | +# Keep `num_test_batches` low so that evaluation runs quickly. It can be |
| 65 | +# increased to a maximum of 330 (if the batch size is 4). |
| 66 | +num_test_batches: 5 # 200 |
| 67 | + |
| 68 | +eval_every_n_steps: 10 # has no effect if train_fraction is 1.0 |
| 69 | +num_epochs: 1 # can potentially train for more epochs |
| 70 | + |
| 71 | + |
| 72 | +# === AdamW, warmup, cosine scheduler === |
| 73 | +learning_rate: 3e-6 |
| 74 | +b1: 0.9 |
| 75 | +b2: 0.99 |
| 76 | +weight_decay: 0.1 |
| 77 | +# == Cosine decay with warmup scheduler == |
| 78 | +# Linearly increase the learning rate from 0 to the peak learning rate (3e-6) |
| 79 | +# over the first 10% of training steps, then gradually decay it to 0 with a |
| 80 | +# cosine schedule. |
| 81 | +warmup_steps_fraction: 0.1 |
| 82 | +# == Grad clipping == |
| 83 | +# Gradient clipping to prevent large gradients; found to be important for |
| 84 | +# keeping the KL divergence in check. |
| 85 | +max_grad_norm: 0.1 |
| 86 | + |
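Assuming an optax-based optimizer (a sketch only, with `steps: 100` from further down taken as the total step count; the actual key names used by the training code may differ), the pieces above fit together roughly as:

```python
import optax

total_steps = 100  # `steps` below
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-6,                      # learning_rate
    warmup_steps=int(0.1 * total_steps),  # warmup_steps_fraction
    decay_steps=total_steps,
    end_value=0.0,
)
optimizer = optax.chain(
    optax.clip_by_global_norm(0.1),       # max_grad_norm
    optax.adamw(learning_rate=schedule, b1=0.9, b2=0.99, weight_decay=0.1),
)
```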
| 87 | + |
| 88 | +# ====== Inference ====== |
| 89 | +# A reasonably high temperature is usually important for varied, diverse |
| 90 | +# responses during training. The values below configure near-greedy decoding; |
| 91 | +# higher-temperature presets are listed (commented out) further down. |
| 92 | +temperature: 0.01 |
| 93 | +top_p: 1.0 |
| 94 | +top_k: 1 |
| 95 | + |
| 96 | +# Some randomness ("standard" preset): |
| 97 | +# temperature: 0.7, top_k: 50, top_p: 0.95 |
| 98 | +# More liberal preset: |
| 99 | +# temperature: 0.85, top_k: 2000, top_p: 1.0 |
| 100 | + |
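To make the interplay of the three knobs concrete, here is a toy next-token sampler (illustrative only, not the vLLM implementation): temperature scales the logits, top-k keeps the k most likely tokens, and top-p keeps the smallest nucleus whose probability mass reaches p. With `temperature: 0.01` and `top_k: 1`, decoding is effectively greedy.

```python
import numpy as np

def sample_token(logits, temperature=0.01, top_k=1, top_p=1.0, rng=None):
    rng = rng or np.random.default_rng()
    scaled = np.asarray(logits, dtype=np.float64) / max(temperature, 1e-6)
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    # top-k: zero out everything outside the k most likely tokens.
    if 0 < top_k < len(probs):
        probs[np.argsort(probs)[:-top_k]] = 0.0
    # top-p (nucleus): keep the smallest prefix of the sorted distribution
    # whose cumulative mass reaches top_p.
    order = np.argsort(-probs)
    cumulative = np.cumsum(probs[order]) / probs.sum()
    probs[order[np.searchsorted(cumulative, top_p) + 1:]] = 0.0
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs)
```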
| 101 | +# trainer_devices_fraction and sampler_devices_fraction are set in the |
| 102 | +# Hardware section at the top of this config. |
| 103 | +hbm_utilization_vllm: 0.72 |
| 104 | +swap_space_vllm_gb: 2 |
| 105 | + |
| 106 | + |
| 107 | +# ====== Reward ====== |
| 108 | +reward_exact_format_match: 3.0 |
| 109 | +reward_white_space_format_match: 1.5 |
| 110 | +reward_partial_format_match: 0.5 |
| 111 | +reward_ratio_guess_to_answer_high: 0.5 |
| 112 | +reward_ratio_guess_to_answer_low: 0.25 |
| 113 | +penalty_incorrect_format: -0.5 |
| 114 | +penalty_incorrect_answer: -1.0 |
| 115 | + |
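The constants above suggest a format-plus-correctness shaped reward. The sketch below is hypothetical (the tag names, regexes, and scoring rules are assumptions, not this repo's reward code); it mirrors the config values and ignores the two ratio-based partial-credit constants for brevity:

```python
import re

# Mirrors the config values above (hypothetical usage).
REWARD_EXACT_FORMAT_MATCH = 3.0
REWARD_PARTIAL_FORMAT_MATCH = 0.5
PENALTY_INCORRECT_FORMAT = -0.5
PENALTY_INCORRECT_ANSWER = -1.0

def reward(response: str, gold_answer: str) -> float:
    score = 0.0
    exact = re.fullmatch(r"(?s)\s*<reasoning>.*?</reasoning>\s*<answer>.*?</answer>\s*", response)
    answer = re.search(r"(?s)<answer>(.*?)</answer>", response)
    if exact:
        score += REWARD_EXACT_FORMAT_MATCH
    elif answer:
        score += REWARD_PARTIAL_FORMAT_MATCH
    else:
        score += PENALTY_INCORRECT_FORMAT
    guess = answer.group(1).strip() if answer else ""
    if guess != gold_answer.strip():
        score += PENALTY_INCORRECT_ANSWER
    return score

print(reward("<reasoning>2+2</reasoning><answer>4</answer>", "4"))  # 3.0
```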
| 116 | + |
| 117 | +# TODO: fix this |
23 | 118 | # Dataset Configuration |
24 | 119 | dataset_type: hf # Huggingface input pipeline |
25 | 120 | hf_path: 'gsm8k' |
@@ -56,8 +151,6 @@ decode_sampling_top_k: 50 |
56 | 151 | # Training Loop Configuration |
57 | 152 | steps: 100 |
58 | 153 | per_device_batch_size: 1 |
59 | | -eval_interval: 10 |
60 | | -eval_steps: 5 |
61 | 154 |
|
62 | 155 | # Checkpoint Configuration |
63 | 156 | enable_checkpointing: True |
|