1818import os
1919from pprint import pprint
2020import re
21+
2122import sys
2223
2324from datetime import datetime
5152from MaxText import rl_utils
5253
5354# ## Evaluate
54- #
55- #
56- # Before we train the model, let's evaluate the model on the test set so we can
57- # see the improvement post training.
58- #
5955# We evaluate it in two ways:
6056#
6157# **Quantitative**
7672
7773
7874def generate_responses (
79- mt_config ,
75+ tmvp_config ,
8076 prompts ,
8177 rl_cluster ,
8278 num_passes = 1 ,
83- temperature = 0.7 ,
84- top_k = 50 ,
85- top_p = 0.95 ,
8679):
8780 """
8881 Generate responses for a batch of prompts across multiple passes.
@@ -104,15 +97,15 @@ def generate_responses(
10497 responses = rl_cluster .rollout .generate (
10598 prompts ,
10699 rollout_config = RolloutConfig (
107- max_tokens_to_generate = mt_config .max_target_length ,
108- temperature = mt_config .eval_temperature ,
109- top_k = mt_config .eval_top_k ,
110- top_p = mt_config .eval_top_p ,
100+ max_tokens_to_generate = tmvp_config .max_target_length ,
101+ temperature = tmvp_config .eval_temperature ,
102+ top_k = tmvp_config .eval_top_k ,
103+ top_p = tmvp_config .eval_top_p ,
111104 ),
112105 )
113106 responses = responses .text
114107
115- if mt_config .debug :
108+ if tmvp_config .debug :
116109 print (f"Pass { p + 1 } /{ num_passes } , responses: { responses } " )
117110
118111 for idx , response in enumerate (responses ):
@@ -121,7 +114,7 @@ def generate_responses(
121114 return multiple_call_responses
122115
123116
124- def score_responses (mt_config , question , responses , answer ):
117+ def score_responses (tmvp_config , question , responses , answer ):
125118 """
126119 Score a set of responses for a single question.
127120
@@ -133,10 +126,10 @@ def score_responses(mt_config, question, responses, answer):
133126 Returns:
134127 Tuple of (is_correct, is_partially_correct, has_correct_format)
135128 """
136- match_format = rl_utils .get_match_format_regex (mt_config )
137- match_numbers = rl_utils .get_match_numbers_regex (mt_config )
129+ match_format = rl_utils .get_match_format_regex (tmvp_config )
130+ match_numbers = rl_utils .get_match_numbers_regex (tmvp_config )
138131
139- if DEBUG :
132+ if tmvp_config . debug :
140133 print ("========================================" )
141134 print (f"Evaluation Question: { question } " )
142135 print (f"Evaluation Answer: { answer } " )
@@ -151,7 +144,7 @@ def score_responses(mt_config, question, responses, answer):
151144 # Extract numerical response
152145 extracted_response = guess .group (1 ) if (guess := match_numbers .search (response )) is not None else "-1000000"
153146
154- if DEBUG :
147+ if tmvp_config . debug :
155148 print (f"Evaluation extracted_response: { extracted_response } " )
156149
157150 # Check exact correctness
@@ -164,7 +157,7 @@ def score_responses(mt_config, question, responses, answer):
164157 if 0.9 <= ratio <= 1.1 :
165158 is_partially_correct = True
166159 except Exception as e :
167- if DEBUG :
160+ if tmvp_config . debug :
168161 print (f"Evaluation Exception: { e } " )
169162 print ("SKIPPED" )
170163
@@ -180,12 +173,9 @@ def score_responses(mt_config, question, responses, answer):
180173
181174
182175def evaluate (
183- mt_config ,
176+ tmvp_config ,
184177 dataset ,
185178 rl_cluster ,
186- temperature = 0.7 ,
187- top_k = 50 ,
188- top_p = 0.95 ,
189179 num_passes = 1 ,
190180 corr_lst = False ,
191181 make_lst = False ,
@@ -194,11 +184,9 @@ def evaluate(
194184 Computes accuracy and percentage of outputs matching the format.
195185
196186 Args:
187+ tmvp_config: Configuration object
197188 dataset: The evaluation dataset
198- rl_cluster: Model cluster for generation
199- temperature: Sampling temperature
200- top_k: Top-k sampling parameter
201- top_p: Top-p sampling parameter
189+ rl_cluster: Model cluster for generation.
202190 num_passes: Number of generation passes
203191 corr_lst: If True, only include correct responses in the list
204192 make_lst: If True, return a list of (question, answer, responses)
@@ -219,19 +207,16 @@ def evaluate(
219207
220208 # Generate responses for all prompts in the batch
221209 multiple_call_responses = generate_responses (
222- mt_config = mt_config ,
210+ tmvp_config = tmvp_config ,
223211 prompts = prompts ,
224212 rl_cluster = rl_cluster ,
225213 num_passes = num_passes ,
226- temperature = temperature ,
227- top_k = top_k ,
228- top_p = top_p ,
229214 )
230215
231216 # Score each question-answer pair
232217 for question , responses , answer in zip (questions , multiple_call_responses , answers ):
233218 is_correct , is_partially_correct , has_correct_format = score_responses (
234- mt_config = mt_config ,
219+ tmvp_config = tmvp_config ,
235220 question = question ,
236221 responses = responses ,
237222 answer = answer ,
0 commit comments