
Commit c46a93c

mydatascienceA9isha authored and committed
Fix
Signed-off-by: Vladimir Suvorov <[email protected]>
1 parent f1c0eed commit c46a93c

File tree

2 files changed: +29 -35 lines

  src/MaxText/examples/grpo_runner.py
  src/MaxText/experimental/rl/grpo_tunix_trainer.py


src/MaxText/examples/grpo_runner.py

Lines changed: 2 additions & 2 deletions
@@ -298,7 +298,7 @@ def build_config_argv(args):
   # Add optional parameters
   if args.run_name:
     config_argv.append(f"run_name={args.run_name}")
-
+
   if args.hf_access_token:
     config_argv.append(f"hf_access_token={args.hf_access_token}")

@@ -360,4 +360,4 @@ def main():


 if __name__ == "__main__":
-  main()
+  main()

(The removed and re-added main() lines show identical visible content; the change appears to be whitespace-only.)
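For context, build_config_argv assembles plain key=value override strings that the runner forwards to MaxText's config parsing. A minimal sketch of that pattern, with a trimmed-down argument set; the base config name "grpo.yml" and the argparse flags here are illustrative, not the runner's actual interface:

import argparse

def build_config_argv(args):
  # Simplified sketch; the real build_config_argv handles many more flags.
  config_argv = ["grpo.yml"]  # illustrative base config
  if args.run_name:
    config_argv.append(f"run_name={args.run_name}")
  if args.hf_access_token:
    config_argv.append(f"hf_access_token={args.hf_access_token}")
  return config_argv

parser = argparse.ArgumentParser()
parser.add_argument("--run_name", default="demo-run")
parser.add_argument("--hf_access_token", default=None)
print(build_config_argv(parser.parse_args([])))  # ['grpo.yml', 'run_name=demo-run']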

src/MaxText/experimental/rl/grpo_tunix_trainer.py

Lines changed: 27 additions & 33 deletions
@@ -93,8 +93,7 @@ def extract_hash_answer(text: str) -> str | None:
   return text.split("####")[1].strip()


-def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1,
-                      num_batches: int = 4, seed: int = 42):
+def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, num_batches: int = 4, seed: int = 42):
   """Load and process GSM8K dataset for GRPO training."""
   if grain is None:
     raise ImportError("grain is required for dataset processing. Please install it.")
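The "####" split above mirrors the GSM8K answer format, where the gold final answer follows a #### marker. A small sketch of the function as shown in this hunk, with the None branch assumed from the str | None annotation and a made-up sample string:

def extract_hash_answer(text: str) -> str | None:
  """Pull the final answer out of a GSM8K-style solution string."""
  # GSM8K gold answers end with "#### <final answer>".
  if "####" not in text:  # None branch assumed from the annotation
    return None
  return text.split("####")[1].strip()

print(extract_hash_answer("First add 2 and 3 to get 5.\n#### 5"))  # -> "5"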
@@ -159,9 +158,9 @@ def setup_device_allocation(config, use_pathways: bool = False):
   devices = jax.devices()

   # Get device allocation parameters from config
-  trainer_devices_fraction = getattr(config, 'trainer_devices_fraction', 0.5)
-  sampler_devices_fraction = getattr(config, 'sampler_devices_fraction', 0.5)
-  chips_per_vm = getattr(config, 'chips_per_vm', 4)
+  trainer_devices_fraction = getattr(config, "trainer_devices_fraction", 0.5)
+  sampler_devices_fraction = getattr(config, "sampler_devices_fraction", 0.5)
+  chips_per_vm = getattr(config, "chips_per_vm", 4)

   num_vms = len(devices) // chips_per_vm

@@ -175,7 +174,7 @@ def setup_device_allocation(config, use_pathways: bool = False):
     num_trainer_devices = int(num_devices * trainer_devices_fraction)
     num_sampler_devices = int(num_devices * sampler_devices_fraction)
     trainer_devices = devices[:num_trainer_devices]
-    sampler_devices = devices[num_devices - num_sampler_devices:]
+    sampler_devices = devices[num_devices - num_sampler_devices :]
   else:
     # Not using Pathways OR single host - use all devices for both
     if use_pathways:
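The fraction-based allocation above takes a leading slice of the device list for the trainer and a trailing slice for the sampler. A quick sketch with plain integers standing in for jax.devices(), using the 0.5/0.5 fraction defaults shown in the previous hunk:

# Integers stand in for jax.devices() on an 8-chip host.
devices = list(range(8))
trainer_devices_fraction = 0.5  # defaults from the hunk above
sampler_devices_fraction = 0.5

num_devices = len(devices)
num_trainer_devices = int(num_devices * trainer_devices_fraction)
num_sampler_devices = int(num_devices * sampler_devices_fraction)
trainer_devices = devices[:num_trainer_devices]                 # [0, 1, 2, 3]
sampler_devices = devices[num_devices - num_sampler_devices :]  # [4, 5, 6, 7]
print(trainer_devices, sampler_devices)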
@@ -194,8 +193,7 @@ def match_format_exactly(prompts, completions, **kwargs):
   """Reward exact format matching."""
   scores = []
   match_format = re.compile(
-      rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?"
-      rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
+      rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
       flags=re.MULTILINE | re.DOTALL,
   )
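Both reward hunks compile the same pattern: optional leading whitespace, a reasoning span, anything in between, a solution span whose contents are captured, and optional trailing whitespace. A sketch with placeholder tag values; the real REASONING_START/REASONING_END/SOLUTION_START/SOLUTION_END constants are defined elsewhere in grpo_tunix_trainer.py and may differ:

import re

# Placeholder tag values; the real constants are defined elsewhere in the file.
REASONING_START, REASONING_END = "<reasoning>", "</reasoning>"
SOLUTION_START, SOLUTION_END = "<answer>", "</answer>"

match_format = re.compile(
    rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

completion = "<reasoning>2 + 3 = 5</reasoning>\n<answer>5</answer>"
match = match_format.search(completion)
print(match.group(1) if match else None)  # -> "5"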

@@ -223,8 +221,7 @@ def match_format_approximately(prompts, completions, **kwargs):
 def check_answer(prompts, completions, answer, **kwargs):
   """Reward correct answers."""
   match_format = re.compile(
-      rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?"
-      rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
+      rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$",
       flags=re.MULTILINE | re.DOTALL,
   )

@@ -288,7 +285,7 @@ def grpo_train(config):
   print("=" * 80)

   # Setup device allocation
-  use_pathways = getattr(config, 'use_pathways_reshard', False)
+  use_pathways = getattr(config, "use_pathways_reshard", False)
   trainer_devices, sampler_devices, num_vms = setup_device_allocation(config, use_pathways)

   print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler")
@@ -305,15 +302,15 @@ def grpo_train(config):
       train_data_dir,
       split="train",
       batch_size=config.per_device_batch_size,
-      num_batches=getattr(config, 'num_batches', 4)
+      num_batches=getattr(config, "num_batches", 4),
   )

   # Load test dataset for evaluation (currently not used in training loop)
   get_gsm8k_dataset(
       test_data_dir,
       split="test",
       batch_size=config.per_device_batch_size,
-      num_batches=getattr(config, 'num_test_batches', 5)
+      num_batches=getattr(config, "num_test_batches", 5),
   )

   # Load reference model
@@ -335,8 +332,8 @@ def grpo_train(config):
     rollout_mesh = policy_mesh

   # Setup optimizer
-  learning_rate = getattr(config, 'learning_rate', 3e-6)
-  max_steps = getattr(config, 'steps', 100)
+  learning_rate = getattr(config, "learning_rate", 3e-6)
+  max_steps = getattr(config, "steps", 100)
   warmup_steps = int(0.1 * max_steps)

   optimizer = optax.adamw(
@@ -353,7 +350,7 @@ def grpo_train(config):
   )

   # Add gradient clipping if specified
-  max_grad_norm = getattr(config, 'max_grad_norm', 0.1)
+  max_grad_norm = getattr(config, "max_grad_norm", 0.1)
   if max_grad_norm is not None:
     optimizer = optax.chain(
         optax.clip_by_global_norm(max_norm=max_grad_norm),
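Taken together, these optimizer hunks configure AdamW with a warmup of 10% of steps and wrap the update in global-norm clipping. A self-contained sketch using the defaults shown here; the warmup-then-cosine schedule is an assumption, since the body of the optax.adamw(...) call is not part of this diff:

import optax

learning_rate = 3e-6  # default from the hunk above
max_steps = 100       # default "steps"
warmup_steps = int(0.1 * max_steps)
max_grad_norm = 0.1   # default from the hunk above

# Assumed schedule: linear warmup to the peak LR, cosine decay afterwards.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_steps,
    decay_steps=max_steps,
)

optimizer = optax.adamw(learning_rate=schedule)
if max_grad_norm is not None:
  # Clip gradients by global norm before the AdamW update, as in the hunk above.
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=max_grad_norm),
      optimizer,
  )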
@@ -365,18 +362,14 @@ def grpo_train(config):
   os.makedirs(ckpt_dir, exist_ok=True)

   checkpointing_options = ocp.CheckpointManagerOptions(
-      save_interval_steps=getattr(config, 'checkpoint_period', 50),
-      max_to_keep=4
+      save_interval_steps=getattr(config, "checkpoint_period", 50), max_to_keep=4
   )

   # Setup metrics logging
   log_dir = f"{config.base_output_directory}/logs"
   os.makedirs(log_dir, exist_ok=True)

-  metrics_logging_options = metrics_logger.MetricsLoggerOptions(
-      log_dir=log_dir,
-      flush_every_n_steps=20
-  )
+  metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20)

   # Setup RL cluster config
   cluster_config = rl_cluster_lib.ClusterConfig(
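The checkpointing change keeps the same Orbax options, just collapsed onto one line: save every checkpoint_period steps (default 50) and retain at most 4 checkpoints. A minimal sketch of how such options are typically consumed; the CheckpointManager construction and should_save call are assumed usage, not shown in this diff:

import os

import orbax.checkpoint as ocp

ckpt_dir = "/tmp/grpo_demo_ckpts"  # illustrative path
os.makedirs(ckpt_dir, exist_ok=True)

checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=50,  # checkpoint_period default from the hunk
    max_to_keep=4,
)

# Assumed usage: the manager decides when to save and prunes older checkpoints.
manager = ocp.CheckpointManager(ckpt_dir, options=checkpointing_options)
print(manager.should_save(50))  # True at multiples of save_interval_steps (assumed behavior)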
@@ -389,21 +382,22 @@ def grpo_train(config):
       offload_to_cpu=False,
       training_config=rl_cluster_lib.RLTrainingConfig(
           actor_optimizer=optimizer,
-          eval_every_n_steps=getattr(config, 'eval_interval', 10),
+          eval_every_n_steps=getattr(config, "eval_interval", 10),
           max_steps=max_steps,
           metrics_logging_options=metrics_logging_options,
           profiler_options=None,
           checkpoint_root_directory=ckpt_dir,
           checkpointing_options=checkpointing_options,
       ),
       rollout_config=base_rollout.RolloutConfig(
-          max_tokens_to_generate=getattr(config, 'max_target_length', 768),
-          max_prompt_length=getattr(config, 'max_prefill_predict_length', 256),
-          kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) +
-              getattr(config, 'max_target_length', 768) + 256,
-          temperature=getattr(config, 'decode_sampling_temperature', 0.9),
-          top_p=getattr(config, 'decode_sampling_top_p', 1.0),
-          top_k=getattr(config, 'decode_sampling_top_k', 50),
+          max_tokens_to_generate=getattr(config, "max_target_length", 768),
+          max_prompt_length=getattr(config, "max_prefill_predict_length", 256),
+          kv_cache_size=getattr(config, "max_prefill_predict_length", 256)
+          + getattr(config, "max_target_length", 768)
+          + 256,
+          temperature=getattr(config, "decode_sampling_temperature", 0.9),
+          top_p=getattr(config, "decode_sampling_top_p", 1.0),
+          top_k=getattr(config, "decode_sampling_top_k", 50),
       ),
       rollout_vllm_model_version="meta-llama/Meta-Llama-3.1-8B-Instruct",
       rollout_vllm_hbm_utilization=0.2,
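Written out with the defaults shown here, the reformatted kv_cache_size expression reserves room for the prompt, the generated tokens, and a fixed 256-token margin:

max_prompt_length = 256        # max_prefill_predict_length default
max_tokens_to_generate = 768   # max_target_length default
extra_margin = 256             # fixed headroom added in the expression

kv_cache_size = max_prompt_length + max_tokens_to_generate + extra_margin
print(kv_cache_size)  # 256 + 768 + 256 = 1280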
@@ -412,10 +406,10 @@ def grpo_train(config):

   # Setup GRPO config
   grpo_config = GrpoConfig(
-      num_generations=getattr(config, 'num_generations', 2),
+      num_generations=getattr(config, "num_generations", 2),
       num_iterations=1,
-      beta=getattr(config, 'grpo_beta', 0.08),
-      epsilon=getattr(config, 'grpo_epsilon', 0.2),
+      beta=getattr(config, "grpo_beta", 0.08),
+      epsilon=getattr(config, "grpo_epsilon", 0.2),
   )

   # Create RL cluster
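All of these hunks lean on getattr(config, name, default) so the trainer still runs when a knob is missing from the MaxText config. A tiny sketch of that fallback behavior, using a throwaway SimpleNamespace in place of the real config object:

from types import SimpleNamespace

# Throwaway stand-in for the MaxText config object; only grpo_beta is set.
config = SimpleNamespace(grpo_beta=0.04)

num_generations = getattr(config, "num_generations", 2)  # missing -> default 2
beta = getattr(config, "grpo_beta", 0.08)                # present -> 0.04
epsilon = getattr(config, "grpo_epsilon", 0.2)           # missing -> default 0.2
print(num_generations, beta, epsilon)  # 2 0.04 0.2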
