52 changes: 43 additions & 9 deletions scripts/grpo_demo_llama3_qwen2.py
@@ -46,7 +46,6 @@
from tunix.models.qwen2 import model as qwen2_lib
from tunix.models.qwen2 import params as qwen2_params
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo import grpo_learner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from tunix.sft import profiler
@@ -243,11 +242,36 @@
required=False,
help="Name of dataset, required when data_source is tfds",
)

parser.add_argument(
"--agentic-mode",
action="store_true",  # type=bool would treat any non-empty value (even "False") as True
default=False,
required=False,
help="Whether to use agentic mode.",
)
parser.add_argument(
"--agentic-max-concurrency",
type=int,
default=4,
required=False,
help="Max concurrency for agentic mode.",
)
parser.add_argument(
"--agentic-off-policy-steps",
type=int,
default=0,
required=False,
help="Number of off-policy steps for agentic mode.",
)

# Parse arguments
args = parser.parse_args()

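# The learner implementation is selected after parsing, so the experimental
# agentic learner is only imported when --agentic-mode is set.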
if args.agentic_mode:
from tunix.rl.experimental import agentic_grpo_learner as grpo_learner # pylint: disable=g-import-not-at-top
else:
from tunix.rl import grpo_learner # pylint: disable=g-import-not-at-top


def validata_args():
if args.data_source == "tfds":
@@ -1100,12 +1124,22 @@ def get_rollout_config(engine: str) -> base_rollout.RolloutConfig:
rollout_config=get_rollout_config(args.rollout_engine),
)

grpo_config = grpo_learner.GRPOConfig(
num_generations=NUM_GENERATIONS,
num_iterations=NUM_ITERATIONS,
beta=BETA,
epsilon=EPSILON,
)
if args.agentic_mode:
grpo_config = grpo_learner.GRPOConfig(
num_generations=NUM_GENERATIONS,
num_iterations=NUM_ITERATIONS,
beta=BETA,
epsilon=EPSILON,
max_concurrency=args.agentic_max_concurrency,
off_policy_steps=args.agentic_off_policy_steps,
)
else:
grpo_config = grpo_learner.GRPOConfig(
num_generations=NUM_GENERATIONS,
num_iterations=NUM_ITERATIONS,
beta=BETA,
epsilon=EPSILON,
)


# RL cluster
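One way to avoid repeating the shared hyperparameters across the two branches is to collect them once and unpack them into either constructor. A minimal sketch, assuming `GRPOConfig` in both the standard and agentic learners accepts the keyword arguments shown in this diff:

```python
# Shared GRPO hyperparameters used by both the standard and agentic learners.
common_kwargs = dict(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

if args.agentic_mode:
    # The agentic path adds concurrency and off-policy limits on top of the shared settings.
    grpo_config = grpo_learner.GRPOConfig(
        **common_kwargs,
        max_concurrency=args.agentic_max_concurrency,
        off_policy_steps=args.agentic_off_policy_steps,
    )
else:
    grpo_config = grpo_learner.GRPOConfig(**common_kwargs)
```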
@@ -1161,7 +1195,7 @@ def get_rollout_config(engine: str) -> base_rollout.RolloutConfig:

show_hbm_usage("Right before training")
with training_mesh:
grpo_trainer.train(train_dataset, eval_ds=val_dataset)
grpo_trainer.train(train_dataset, val_dataset)

# Load checkpoint first.
