Commit e3da99e

create rl_utils
1 parent 3673f2d commit e3da99e

File tree

3 files changed: +577 -194 lines changed

src/MaxText/evaluate_rl.py

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=bare-except, consider-using-generator

import functools
import os
from pprint import pprint
import re
import sys

from datetime import datetime
from flax import nnx
from flax.linen import partitioning as nn_partitioning
import grain
import humanize


import jax
from jax.sharding import Mesh
import optax
from orbax import checkpoint as ocp
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
from tunix.sft import metrics_logger


from transformers import AutoTokenizer

from flax import linen as nn
import numpy as np
from etils import epath

from tunix.rl.rollout.base_rollout import RolloutConfig

from MaxText.globals import MAXTEXT_ASSETS_ROOT
from MaxText import rl_utils

# ## Evaluate
#
# Before we train the model, let's evaluate it on the test set so we can
# see the improvement post training.
#
# We evaluate it in two ways:
#
# **Quantitative**
#
# * **Answer Accuracy**: percentage of samples for which the model predicts the
#   correct final numerical answer.
# * **Answer (Partial) Accuracy**: percentage of samples for which the model
#   predicts a final numerical answer such that the `model answer / answer`
#   ratio lies between 0.9 and 1.1.
# * **Format Accuracy**: percentage of samples for which the model outputs the
#   correct format, i.e., reasoning between the reasoning special tokens and the
#   final answer between the `<start_answer>` and `<end_answer>` tokens.
#
# **Qualitative**
#
# We'll also print outputs for a few given questions so that we can compare the
# generated output later.
#
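

# The sketch below is illustrative only and is not called anywhere in this
# module: it shows what a well-formatted response looks like and how the three
# accuracy checks behave. The `<start_answer>`/`<end_answer>` regex here is an
# assumption for demonstration; the real patterns come from
# rl_utils.get_match_format_regex(mt_config) and
# rl_utils.get_match_numbers_regex(mt_config), which depend on the config.
def _format_scoring_example():
  """Illustrative example of the three accuracy checks (not used in evaluation)."""
  example_answer_re = re.compile(r"<start_answer>\s*([-0-9.]+)\s*<end_answer>")  # assumed pattern
  response = "The bill is 4 * 18 = 72 dollars. <start_answer> 70 <end_answer>"
  answer = "72"

  guess = example_answer_re.search(response).group(1)  # "70"
  is_correct = float(guess) == float(answer)  # False: 70 != 72
  is_partially_correct = 0.9 <= float(guess) / float(answer) <= 1.1  # True: 70/72 ~ 0.97
  has_correct_format = example_answer_re.search(response) is not None  # True
  return is_correct, is_partially_correct, has_correct_format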


def generate_responses(
    mt_config,
    prompts,
    rl_cluster,
    num_passes=1,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
):
  """
  Generate responses for a batch of prompts across multiple passes.

  Args:
    mt_config: MaxText config with generation and debug settings
    prompts: List of prompts to generate responses for
    rl_cluster: Model cluster for generation
    num_passes: Number of generation passes
    temperature: Sampling temperature
    top_k: Top-k sampling parameter
    top_p: Top-p sampling parameter

  Returns:
    List of lists containing responses for each prompt across passes
  """
  multiple_call_responses = [[] for _ in range(len(prompts))]

  for p in range(num_passes):
    # Sampling settings are currently taken from the mt_config eval_* fields,
    # not from the temperature/top_k/top_p arguments above.
    responses = rl_cluster.rollout.generate(
        prompts,
        rollout_config=RolloutConfig(
            max_tokens_to_generate=mt_config.max_target_length,
            temperature=mt_config.eval_temperature,
            top_k=mt_config.eval_top_k,
            top_p=mt_config.eval_top_p,
        ),
    )
    responses = responses.text

    if mt_config.debug:
      print(f"Pass {p+1}/{num_passes}, responses: {responses}")

    for idx, response in enumerate(responses):
      multiple_call_responses[idx].append(response)

  return multiple_call_responses
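
# Shape note (illustrative): the value returned above is indexed as
# multiple_call_responses[prompt_idx][pass_idx]; e.g. with two prompts and
# num_passes=3 it looks like
#   [[prompt0_pass0, prompt0_pass1, prompt0_pass2],
#    [prompt1_pass0, prompt1_pass1, prompt1_pass2]]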


def score_responses(mt_config, question, responses, answer):
  """
  Score a set of responses for a single question.

  Args:
    mt_config: MaxText config with debug settings
    question: The evaluation question
    responses: List of generated responses for this question
    answer: The correct answer

  Returns:
    Tuple of (is_correct, is_partially_correct, has_correct_format)
  """
  match_format = rl_utils.get_match_format_regex(mt_config)
  match_numbers = rl_utils.get_match_numbers_regex(mt_config)

  if mt_config.debug:
    print("========================================")
    print(f"Evaluation Question: {question}")
    print(f"Evaluation Answer: {answer}")
    print(f"Evaluation Responses: {responses}")
    print("========================================")

  is_correct = False
  is_partially_correct = False
  has_correct_format = False

  for response in responses:
    # Extract numerical response
    extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000"

    if mt_config.debug:
      print(f"Evaluation extracted_response: {extracted_response}")

    # Check exact correctness
    try:
      if float(extracted_response.strip()) == float(answer.strip()):
        is_correct = True

      # Check partial correctness (within 10%)
      ratio = float(extracted_response.strip()) / float(answer.strip())
      if 0.9 <= ratio <= 1.1:
        is_partially_correct = True
    except Exception as e:
      if mt_config.debug:
        print(f"Evaluation Exception: {e}")
        print("SKIPPED")

    # Check format correctness
    if match_format.search(response) is not None:
      has_correct_format = True

    # Early exit if all criteria are met
    if is_correct and is_partially_correct and has_correct_format:
      break

  return is_correct, is_partially_correct, has_correct_format


def evaluate(
    mt_config,
    dataset,
    rl_cluster,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """
  Computes accuracy and percentage of outputs matching the format.

  Args:
    mt_config: MaxText config with generation and debug settings
    dataset: The evaluation dataset
    rl_cluster: Model cluster for generation
    temperature: Sampling temperature
    top_k: Top-k sampling parameter
    top_p: Top-p sampling parameter
    num_passes: Number of generation passes
    corr_lst: If True, only include correct responses in the list
    make_lst: If True, return a list of (question, answer, responses)

  Returns:
    Tuple of statistics and optionally the response list
  """
  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]
    prompts = batch["prompts"]

    # Generate responses for all prompts in the batch
    multiple_call_responses = generate_responses(
        mt_config=mt_config,
        prompts=prompts,
        rl_cluster=rl_cluster,
        num_passes=num_passes,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    # Score each question-answer pair
    for question, responses, answer in zip(questions, multiple_call_responses, answers):
      is_correct, is_partially_correct, has_correct_format = score_responses(
          mt_config=mt_config,
          question=question,
          responses=responses,
          answer=answer,
      )

      # Update counters
      if is_correct:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, responses))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, responses))

      if is_partially_correct:
        partially_corr += 1

      if has_correct_format:
        corr_format += 1

      total += 1

      # Print progress every 10 items
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  # Prepare return values
  to_return = (
      corr,
      total,
      corr / total * 100,
      partially_corr / total * 100,
      corr_format / total * 100,
  )

  if make_lst:
    return to_return, response_lst
  return to_return
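

# A minimal usage sketch (assumption, not part of this module's call path): how
# `evaluate` might be invoked once an mt_config, an evaluation dataset yielding
# {"question", "answer", "prompts"} batches, and an initialized rl_cluster
# already exist. None of those objects are constructed here.
def _example_evaluate_usage(mt_config, eval_dataset, rl_cluster):
  """Illustrative wrapper showing how to unpack evaluate()'s return values."""
  (corr, total, accuracy, partial_accuracy, format_accuracy), samples = evaluate(
      mt_config,
      eval_dataset,
      rl_cluster,
      num_passes=1,
      make_lst=True,
  )
  print(f"correct={corr}/{total} "
        f"accuracy={accuracy:.1f}% partial={partial_accuracy:.1f}% format={format_accuracy:.1f}%")
  # `samples` holds (question, answer, responses) tuples; with corr_lst=False
  # (the default) it contains the incorrectly answered questions.
  return samples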
