@@ -93,8 +93,7 @@ def extract_hash_answer(text: str) -> str | None:
   return text.split("####")[1].strip()


-def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1,
-                      num_batches: int = 4, seed: int = 42):
+def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, num_batches: int = 4, seed: int = 42):
   """Load and process GSM8K dataset for GRPO training."""
   if grain is None:
     raise ImportError("grain is required for dataset processing. Please install it.")
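
For orientation, a quick sketch of calling the collapsed signature; the data directory and the return-value handling here are hypothetical, not taken from the surrounding code:

```python
# Hypothetical call site for the one-line signature above; the path and
# sizes are illustrative defaults, not values read from the config.
train_ds = get_gsm8k_dataset("/tmp/gsm8k", split="train", batch_size=1, num_batches=4, seed=42)
```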
@@ -159,9 +158,9 @@ def setup_device_allocation(config, use_pathways: bool = False):
   devices = jax.devices()

   # Get device allocation parameters from config
-  trainer_devices_fraction = getattr(config, 'trainer_devices_fraction', 0.5)
-  sampler_devices_fraction = getattr(config, 'sampler_devices_fraction', 0.5)
-  chips_per_vm = getattr(config, 'chips_per_vm', 4)
+  trainer_devices_fraction = getattr(config, "trainer_devices_fraction", 0.5)
+  sampler_devices_fraction = getattr(config, "sampler_devices_fraction", 0.5)
+  chips_per_vm = getattr(config, "chips_per_vm", 4)

   num_vms = len(devices) // chips_per_vm

@@ -175,7 +174,7 @@ def setup_device_allocation(config, use_pathways: bool = False):
     num_trainer_devices = int(num_devices * trainer_devices_fraction)
     num_sampler_devices = int(num_devices * sampler_devices_fraction)
     trainer_devices = devices[:num_trainer_devices]
-    sampler_devices = devices[num_devices - num_sampler_devices:]
+    sampler_devices = devices[num_devices - num_sampler_devices :]
   else:
     # Not using Pathways OR single host - use all devices for both
     if use_pathways:
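
A worked example of the head/tail split this hunk reformats, assuming 16 devices and the default 0.5 fractions (plain integers stand in for jax.devices()):

```python
# Minimal sketch of the trainer/sampler split; the device count is illustrative.
devices = list(range(16))                     # stand-in for jax.devices()
num_devices = len(devices)
num_trainer_devices = int(num_devices * 0.5)  # trainer_devices_fraction default
num_sampler_devices = int(num_devices * 0.5)  # sampler_devices_fraction default
trainer_devices = devices[:num_trainer_devices]                 # devices 0..7
sampler_devices = devices[num_devices - num_sampler_devices :]  # devices 8..15
assert set(trainer_devices).isdisjoint(sampler_devices)
```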
@@ -194,8 +193,7 @@ def match_format_exactly(prompts, completions, **kwargs):
   """Reward exact format matching."""
   scores = []
   match_format = re.compile(
-      rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?"
-      rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
+      rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
       flags=re.MULTILINE | re.DOTALL,
   )

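The joined pattern is byte-for-byte the same regex as the wrapped original. For illustration, a self-contained sketch of what it accepts, with hypothetical marker values; the real REASONING_START/SOLUTION_END constants are defined elsewhere in this file:

```python
import re

# Hypothetical marker strings for illustration only; the actual constants
# live elsewhere in the module.
REASONING_START, REASONING_END = "<reasoning>", "</reasoning>"
SOLUTION_START, SOLUTION_END = "<solution>", "</solution>"

match_format = re.compile(
    rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

completion = "<reasoning>2 + 2 = 4</reasoning><solution>4</solution>"
match = match_format.search(completion)
assert match is not None and match.group(1) == "4"
```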
@@ -223,8 +221,7 @@ def match_format_approximately(prompts, completions, **kwargs):
 def check_answer(prompts, completions, answer, **kwargs):
   """Reward correct answers."""
   match_format = re.compile(
-      rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?"
-      rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
+      rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
       flags=re.MULTILINE | re.DOTALL,
   )

@@ -288,7 +285,7 @@ def grpo_train(config):
   print("=" * 80)

   # Setup device allocation
-  use_pathways = getattr(config, 'use_pathways_reshard', False)
+  use_pathways = getattr(config, "use_pathways_reshard", False)
   trainer_devices, sampler_devices, num_vms = setup_device_allocation(config, use_pathways)

   print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler")
@@ -305,15 +302,15 @@ def grpo_train(config):
       train_data_dir,
       split="train",
       batch_size=config.per_device_batch_size,
-      num_batches=getattr(config, 'num_batches', 4)
+      num_batches=getattr(config, "num_batches", 4),
   )

   # Load test dataset for evaluation (currently not used in training loop)
   get_gsm8k_dataset(
       test_data_dir,
       split="test",
       batch_size=config.per_device_batch_size,
-      num_batches=getattr(config, 'num_test_batches', 5)
+      num_batches=getattr(config, "num_test_batches", 5),
   )

   # Load reference model
@@ -335,8 +332,8 @@ def grpo_train(config):
     rollout_mesh = policy_mesh

   # Setup optimizer
-  learning_rate = getattr(config, 'learning_rate', 3e-6)
-  max_steps = getattr(config, 'steps', 100)
+  learning_rate = getattr(config, "learning_rate", 3e-6)
+  max_steps = getattr(config, "steps", 100)
   warmup_steps = int(0.1 * max_steps)

   optimizer = optax.adamw(
@@ -353,7 +350,7 @@ def grpo_train(config):
   )

   # Add gradient clipping if specified
-  max_grad_norm = getattr(config, 'max_grad_norm', 0.1)
+  max_grad_norm = getattr(config, "max_grad_norm", 0.1)
   if max_grad_norm is not None:
     optimizer = optax.chain(
         optax.clip_by_global_norm(max_norm=max_grad_norm),
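
For context, the pattern here is clip-then-update via optax.chain. A minimal sketch under assumed hyperparameters; the adamw arguments and the warmup schedule are stand-ins, since the actual ones sit in lines this diff elides:

```python
import optax

# Sketch of the gradient-clipping composition; the learning rate, step count
# and schedule are illustrative assumptions, not the script's actual values.
learning_rate, max_steps = 3e-6, 100
warmup_steps = int(0.1 * max_steps)
schedule = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
optimizer = optax.adamw(learning_rate=schedule)
max_grad_norm = 0.1
if max_grad_norm is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=max_grad_norm),  # clip the global norm first...
      optimizer,                                          # ...then apply the AdamW update
  )
```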
@@ -365,18 +362,14 @@ def grpo_train(config):
   os.makedirs(ckpt_dir, exist_ok=True)

   checkpointing_options = ocp.CheckpointManagerOptions(
-      save_interval_steps=getattr(config, 'checkpoint_period', 50),
-      max_to_keep=4
+      save_interval_steps=getattr(config, "checkpoint_period", 50), max_to_keep=4
   )

   # Setup metrics logging
   log_dir = f"{config.base_output_directory}/logs"
   os.makedirs(log_dir, exist_ok=True)

-  metrics_logging_options = metrics_logger.MetricsLoggerOptions(
-      log_dir=log_dir,
-      flush_every_n_steps=20
-  )
+  metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20)

   # Setup RL cluster config
   cluster_config = rl_cluster_lib.ClusterConfig(
@@ -389,21 +382,22 @@ def grpo_train(config):
       offload_to_cpu=False,
       training_config=rl_cluster_lib.RLTrainingConfig(
           actor_optimizer=optimizer,
-          eval_every_n_steps=getattr(config, 'eval_interval', 10),
+          eval_every_n_steps=getattr(config, "eval_interval", 10),
           max_steps=max_steps,
           metrics_logging_options=metrics_logging_options,
           profiler_options=None,
           checkpoint_root_directory=ckpt_dir,
           checkpointing_options=checkpointing_options,
       ),
       rollout_config=base_rollout.RolloutConfig(
-          max_tokens_to_generate=getattr(config, 'max_target_length', 768),
-          max_prompt_length=getattr(config, 'max_prefill_predict_length', 256),
-          kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) +
-          getattr(config, 'max_target_length', 768) + 256,
-          temperature=getattr(config, 'decode_sampling_temperature', 0.9),
-          top_p=getattr(config, 'decode_sampling_top_p', 1.0),
-          top_k=getattr(config, 'decode_sampling_top_k', 50),
+          max_tokens_to_generate=getattr(config, "max_target_length", 768),
+          max_prompt_length=getattr(config, "max_prefill_predict_length", 256),
+          kv_cache_size=getattr(config, "max_prefill_predict_length", 256)
+          + getattr(config, "max_target_length", 768)
+          + 256,
+          temperature=getattr(config, "decode_sampling_temperature", 0.9),
+          top_p=getattr(config, "decode_sampling_top_p", 1.0),
+          top_k=getattr(config, "decode_sampling_top_k", 50),
       ),
       rollout_vllm_model_version="meta-llama/Meta-Llama-3.1-8B-Instruct",
       rollout_vllm_hbm_utilization=0.2,
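
With the defaults shown, the reflowed kv_cache_size works out to 256 + 768 + 256 = 1280 tokens: prompt budget plus generation budget plus what reads as a fixed 256-token margin. A quick arithmetic check:

```python
# Worked arithmetic for the default rollout budgets in this hunk.
max_prompt_length = 256       # max_prefill_predict_length default
max_tokens_to_generate = 768  # max_target_length default
kv_cache_size = max_prompt_length + max_tokens_to_generate + 256
assert kv_cache_size == 1280
```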
@@ -412,10 +406,10 @@ def grpo_train(config):

   # Setup GRPO config
   grpo_config = GrpoConfig(
-      num_generations=getattr(config, 'num_generations', 2),
+      num_generations=getattr(config, "num_generations", 2),
       num_iterations=1,
-      beta=getattr(config, 'grpo_beta', 0.08),
-      epsilon=getattr(config, 'grpo_epsilon', 0.2),
+      beta=getattr(config, "grpo_beta", 0.08),
+      epsilon=getattr(config, "grpo_epsilon", 0.2),
   )

   # Create RL cluster
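
For orientation, beta and epsilon map onto the standard GRPO objective: beta weights the KL penalty against the reference model loaded earlier, and epsilon is the PPO-style clipping radius on the importance ratio r with group-normalized advantage A-hat:

```latex
\mathcal{L} = -\,\mathbb{E}\Big[\min\big(r\hat{A},\ \mathrm{clip}(r,\,1-\epsilon,\,1+\epsilon)\,\hat{A}\big)\Big] + \beta\, D_{\mathrm{KL}}\big(\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big)
```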