Skip to content

Commit 3c140c1

Browse files
committed
refactored
1 parent e3da99e commit 3c140c1

File tree

4 files changed

+199
-210
lines changed

4 files changed

+199
-210
lines changed

src/MaxText/configs/rl.yml

Lines changed: 33 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -61,33 +61,30 @@ loss_algo: 'grpo' # grpo or gspo-token
6161
# Model-Specific Overrides (examples)
6262
# For Llama3.1-8B:
6363
# model_name: llama3.1-8b
64-
# tokenizer_path: meta-llama/Llama-3.1-8B-Instruct
64+
# HF tokenizer_path: meta-llama/Llama-3.1-8B-Instruct
6565
#
6666
# For Llama3.1-70B with Pathways:
6767
# model_name: llama3.1-70b
68-
# tokenizer_path: meta-llama/Llama-3.1-70B-Instruct
68+
# HF tokenizer_path: meta-llama/Llama-3.1-70B-Instruct
6969

70-
async_checkpointing: 'false'
71-
checkpoint_period: 5
72-
skip_jax_distributed_system: True
70+
# ====== MaxText configs ======
7371
weight_dtype: 'bfloat16'
7472
attention: 'dot_product'
7573
remat_policy: 'custom'
7674
decoder_layer_input: 'offload'
7775
query_proj: 'offload'
7876
key_proj: 'offload'
7977
value_proj: 'offload'
80-
# for vLLM
81-
hf_model_name: 'meta-llama/Llama-3.1-70B-Instruct'
78+
8279

8380
# ====== Training ======
8481
batch_size: 1
8582
# Increase `batch_size` and `MAX_STEPS` for better results.
86-
# num_batches = 3738
87-
num_batches = 4 # 200
83+
# num_batches: 3738
84+
num_batches: 4 # 200
8885
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
8986
# increased to a maximum of 330 (if the batch size is 4).
90-
num_test_batches = 5 # 200
87+
num_test_batches: 5 # 200
9188
train_fraction: 1.0
9289

9390
eval_interval: 10 # has no effect when `train_fraction` is 1.0.
@@ -96,6 +93,20 @@ num_epochs: 1 # can potentially train for more epochs
9693

9794
gradient_clipping_threshold: 0.1
9895

96+
# ====== Evaluation ======
97+
generation_configs:
98+
greedy:
99+
temperature: 0.01
100+
top_k: 1
101+
top_p: 1.0
102+
standard:
103+
temperature: 0.7
104+
top_k: 50
105+
top_p: 0.95
106+
liberal:
107+
temperature: 0.85
108+
top_k: 2000
109+
top_p: 1.0
99110

100111
# greedy
101112
eval_temperature: 0.01
@@ -108,13 +119,9 @@ eval_top_p: 1.0
108119

109120

110121
# ====== Inference ======
111-
# Important to keep a high-ish temperature for varied, diverse responses during
112-
# training.
113-
decode_sampling_temperature: 0.9
114-
decode_sampling_top_k: 50
115-
decode_sampling_nucleus_p: 1.0
116-
117122
# for vLLM
123+
hf_model_name: None
124+
118125
# === Generation during GRPO training ===
119126
# Max lengths for prompt and completion
120127
max_prefill_predict_length: 256
@@ -123,13 +130,15 @@ kv_cache_buffer: 256
123130
hbm_utilization_vllm: 0.72
124131
swap_space_vllm_gb: 2
125132
# Generation Configuration During Training
133+
# Important to keep a high-ish temperature for varied, diverse responses during
134+
# training.
126135
decode_sampling_temperature: 0.9
127-
decode_sampling_top_p: 1.0
128136
decode_sampling_top_k: 50
137+
decode_sampling_nucleus_p: 1.0
129138

130139
# ====== Checkpoint Configuration ======
131140
enable_checkpointing: True
132-
async_checkpointing: True
141+
async_checkpointing: False
133142
checkpoint_period: 50
134143
max_num_checkpoints_to_keep: 10
135144

@@ -162,40 +171,9 @@ template: |
162171
<start_of_turn>model
163172
164173
165-
# TODO: fix this
166-
# Dataset Configuration
167-
dataset_type: hf # Huggingface input pipeline
168-
hf_path: 'gsm8k'
169-
hf_data_split: 'main'
170-
hf_data_files: 'train'
171-
172-
173-
# Pathways Inference Configuration
174-
# For multi-host/multi-slice setups
175-
use_pathways_reshard: False
176-
inference_devices_per_replica: 4
177-
inference_replicas: 1
178-
179-
# Tokenizer Settings
180-
add_bos: False
181-
add_eos: False
182-
return_log_prob: True
183-
184-
# Performance and Memory
185-
weight_dtype: bfloat16
186-
dtype: bfloat16
187-
188-
# Splash Attention Block Sizes
189-
# Tuned for GRPO workloads
190-
sa_block_q: 128
191-
sa_block_kv: 128
192-
sa_block_kv_compute: 128
193-
sa_block_q_dkv: 128
194-
sa_block_kv_dkv: 128
195-
sa_block_kv_dkv_compute: 128
196-
sa_block_q_dq: 128
197-
sa_block_kv_dq: 128
198-
sa_use_fused_bwd_kernel: False
199-
sa_q_layout: "HEAD_DIM_MINOR"
200-
sa_k_layout: "HEAD_DIM_MINOR"
201-
sa_v_layout: "HEAD_DIM_MINOR"
174+
# # TODO(@mazumdera): fix this
175+
# # Dataset Configuration
176+
# dataset_type: hf # Huggingface input pipeline
177+
# hf_path: 'gsm8k'
178+
# hf_data_split: 'main'
179+
# hf_data_files: 'train'

src/MaxText/evaluate_rl.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
from pprint import pprint
2020
import re
21+
2122
import sys
2223

2324
from datetime import datetime
@@ -51,11 +52,6 @@
5152
from MaxText import rl_utils
5253

5354
# ## Evaluate
54-
#
55-
#
56-
# Before we train the model, let's evaluate the model on the test set so we can
57-
# see the improvement post training.
58-
#
5955
# We evaluate it in two ways:
6056
#
6157
# **Quantitative**
@@ -76,13 +72,10 @@
7672

7773

7874
def generate_responses(
79-
mt_config,
75+
tmvp_config,
8076
prompts,
8177
rl_cluster,
8278
num_passes=1,
83-
temperature=0.7,
84-
top_k=50,
85-
top_p=0.95,
8679
):
8780
"""
8881
Generate responses for a batch of prompts across multiple passes.
@@ -104,15 +97,15 @@ def generate_responses(
10497
responses = rl_cluster.rollout.generate(
10598
prompts,
10699
rollout_config=RolloutConfig(
107-
max_tokens_to_generate=mt_config.max_target_length,
108-
temperature=mt_config.eval_temperature,
109-
top_k=mt_config.eval_top_k,
110-
top_p=mt_config.eval_top_p,
100+
max_tokens_to_generate=tmvp_config.max_target_length,
101+
temperature=tmvp_config.eval_temperature,
102+
top_k=tmvp_config.eval_top_k,
103+
top_p=tmvp_config.eval_top_p,
111104
),
112105
)
113106
responses = responses.text
114107

115-
if mt_config.debug:
108+
if tmvp_config.debug:
116109
print(f"Pass {p+1}/{num_passes}, responses: {responses}")
117110

118111
for idx, response in enumerate(responses):
@@ -121,7 +114,7 @@ def generate_responses(
121114
return multiple_call_responses
122115

123116

124-
def score_responses(mt_config, question, responses, answer):
117+
def score_responses(tmvp_config, question, responses, answer):
125118
"""
126119
Score a set of responses for a single question.
127120
@@ -133,10 +126,10 @@ def score_responses(mt_config, question, responses, answer):
133126
Returns:
134127
Tuple of (is_correct, is_partially_correct, has_correct_format)
135128
"""
136-
match_format = rl_utils.get_match_format_regex(mt_config)
137-
match_numbers = rl_utils.get_match_numbers_regex(mt_config)
129+
match_format = rl_utils.get_match_format_regex(tmvp_config)
130+
match_numbers = rl_utils.get_match_numbers_regex(tmvp_config)
138131

139-
if DEBUG:
132+
if tmvp_config.debug:
140133
print("========================================")
141134
print(f"Evaluation Question: {question}")
142135
print(f"Evaluation Answer: {answer}")
@@ -151,7 +144,7 @@ def score_responses(mt_config, question, responses, answer):
151144
# Extract numerical response
152145
extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000"
153146

154-
if DEBUG:
147+
if tmvp_config.debug:
155148
print(f"Evaluation extracted_response: {extracted_response}")
156149

157150
# Check exact correctness
@@ -164,7 +157,7 @@ def score_responses(mt_config, question, responses, answer):
164157
if 0.9 <= ratio <= 1.1:
165158
is_partially_correct = True
166159
except Exception as e:
167-
if DEBUG:
160+
if tmvp_config.debug:
168161
print(f"Evaluation Exception: {e}")
169162
print("SKIPPED")
170163

@@ -180,12 +173,9 @@ def score_responses(mt_config, question, responses, answer):
180173

181174

182175
def evaluate(
183-
mt_config,
176+
tmvp_config,
184177
dataset,
185178
rl_cluster,
186-
temperature=0.7,
187-
top_k=50,
188-
top_p=0.95,
189179
num_passes=1,
190180
corr_lst=False,
191181
make_lst=False,
@@ -194,11 +184,9 @@ def evaluate(
194184
Computes accuracy and percentage of outputs matching the format.
195185
196186
Args:
187+
tmvp_config: Configuration object
197188
dataset: The evaluation dataset
198-
rl_cluster: Model cluster for generation
199-
temperature: Sampling temperature
200-
top_k: Top-k sampling parameter
201-
top_p: Top-p sampling parameter
189+
rl_cluster: Model cluster for generation
202190
num_passes: Number of generation passes
203191
corr_lst: If True, only include correct responses in the list
204192
make_lst: If True, return a list of (question, answer, responses)
@@ -219,19 +207,16 @@ def evaluate(
219207

220208
# Generate responses for all prompts in the batch
221209
multiple_call_responses = generate_responses(
222-
mt_config=mt_config,
210+
tmvp_config=tmvp_config,
223211
prompts=prompts,
224212
rl_cluster=rl_cluster,
225213
num_passes=num_passes,
226-
temperature=temperature,
227-
top_k=top_k,
228-
top_p=top_p,
229214
)
230215

231216
# Score each question-answer pair
232217
for question, responses, answer in zip(questions, multiple_call_responses, answers):
233218
is_correct, is_partially_correct, has_correct_format = score_responses(
234-
mt_config=mt_config,
219+
tmvp_config=tmvp_config,
235220
question=question,
236221
responses=responses,
237222
answer=answer,

0 commit comments

Comments
 (0)