diff --git a/scripts/grpo_demo_llama3_qwen2.py b/scripts/grpo_demo_llama3_qwen2.py index a94f3e726..58544b059 100644 --- a/scripts/grpo_demo_llama3_qwen2.py +++ b/scripts/grpo_demo_llama3_qwen2.py @@ -46,7 +46,6 @@ from tunix.models.qwen2 import model as qwen2_lib from tunix.models.qwen2 import params as qwen2_params from tunix.rl import rl_cluster as rl_cluster_lib -from tunix.rl.grpo import grpo_learner from tunix.rl.rollout import base_rollout from tunix.sft import metrics_logger from tunix.sft import profiler @@ -243,11 +242,36 @@ required=False, help="Name of dataset, required when data_source is tfds", ) - +parser.add_argument( + "--agentic-mode", + type=lambda s: s.strip().lower() in ("1", "true", "yes", "on"), + default=False, + required=False, + help="Whether to use agentic mode (true/false).", +) +parser.add_argument( + "--agentic-max-concurrency", + type=int, + default=4, + required=False, + help="Max concurrency for agentic mode.", +) +parser.add_argument( + "--agentic-off-policy-steps", + type=int, + default=0, + required=False, + help="Number of off-policy steps for agentic mode.", +) # Parse arguments args = parser.parse_args() +if args.agentic_mode: + from tunix.rl.experimental import agentic_grpo_learner as grpo_learner # pylint: disable=g-import-not-at-top +else: + from tunix.rl import grpo_learner # pylint: disable=g-import-not-at-top + def validata_args(): if args.data_source == "tfds": @@ -1100,12 +1124,22 @@ def get_rollout_config(engine: str) -> base_rollout.RolloutConfig: rollout_config=get_rollout_config(args.rollout_engine), ) -grpo_config = grpo_learner.GRPOConfig( - num_generations=NUM_GENERATIONS, - num_iterations=NUM_ITERATIONS, - beta=BETA, - epsilon=EPSILON, -) +if args.agentic_mode: + grpo_config = grpo_learner.GRPOConfig( + num_generations=NUM_GENERATIONS, + num_iterations=NUM_ITERATIONS, + beta=BETA, + epsilon=EPSILON, + max_concurrency=args.agentic_max_concurrency, + off_policy_steps=args.agentic_off_policy_steps, + ) +else: + grpo_config = grpo_learner.GRPOConfig( 
num_generations=NUM_GENERATIONS, + num_iterations=NUM_ITERATIONS, + beta=BETA, + epsilon=EPSILON, + ) # RL cluster @@ -1161,7 +1195,7 @@ def get_rollout_config(engine: str) -> base_rollout.RolloutConfig: show_hbm_usage("Right before training") with training_mesh: - grpo_trainer.train(train_dataset, eval_ds=val_dataset) + grpo_trainer.train(train_dataset, val_dataset) # Load checkpoint first.