diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml
index 1a2b3783f..84ba79dd1 100644
--- a/.github/workflows/tpu-tests.yml
+++ b/.github/workflows/tpu-tests.yml
@@ -190,12 +190,14 @@ jobs:
           # TODO(lancewang): Re-enable this test once the segfault is fixed.
           # Run GRPO demo script with minimal configuration
-          # python3 scripts/grpo_demo_llama3_qwen2.py \
-          #   --root-dir=/tmp/grpo_test \
-          #   --model-version=Qwen/Qwen2.5-0.5B-Instruct \
-          #   --num-batches=1 \
-          #   --num-test-batches=1 \
-          #   --rollout-engine=vanilla
+          python3 scripts/grpo_demo_llama3_qwen2.py \
+            --root-dir=/tmp/grpo_test \
+            --num-batches=2 \
+            --num-test-batches=1 \
+            --train-global-batch-size=2 \
+            --train-mini-batch-size=2 \
+            --train-micro-batch-size=2 \
+            --rollout-engine=vanilla

       - name: Run vllm tests
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
diff --git a/scripts/grpo_demo_llama3_qwen2.py b/scripts/grpo_demo_llama3_qwen2.py
index a94f3e726..d9e401e8a 100644
--- a/scripts/grpo_demo_llama3_qwen2.py
+++ b/scripts/grpo_demo_llama3_qwen2.py
@@ -40,6 +40,7 @@
 import qwix
 from tqdm.auto import tqdm
 import transformers
+from tunix.cli.utils import data as data_lib
 from tunix.examples.data import math_dataset
 from tunix.models.llama3 import model as llama_lib
 from tunix.models.llama3 import params as llama_params
@@ -54,6 +55,7 @@
 from tunix.tests import test_common as tc
 from tunix.utils import script_utils

+
 if os.getenv("JAX_PLATFORMS", None) == "proxy":
   import pathwaysutils

@@ -573,11 +575,18 @@ def extract_hash_answer(text: str) -> str | None:
 dataset = create_dataset(
     args.data_source,
     args.dataset if args.data_source == "tfds" else LOCAL_TRAIN_DATA_DIR,
-    args.global_batch_size,
-    NUM_BATCHES,
+    tokenizer=model_tokenizer,
     tfds_download=True,
 )

+dataset = data_lib.post_init_dataset(
+    dataset,
+    model_tokenizer,
+    batch_size=args.global_batch_size,
+    num_batches=NUM_BATCHES,
+    max_prompt_length=MAX_PROMPT_LENGTH,
+)
+
 if TRAIN_FRACTION == 1.0:
   train_dataset = dataset.repeat(NUM_EPOCHS)
   val_dataset = None
@@ -590,11 +599,18 @@ def extract_hash_answer(text: str) -> str | None:
 test_dataset = create_dataset(
     args.data_source,
     args.dataset if args.data_source == "tfds" else LOCAL_TRAIN_DATA_DIR,
-    args.global_batch_size,
-    NUM_TEST_BATCHES,
+    tokenizer=model_tokenizer,
     tfds_download=True,
 )

+test_dataset = data_lib.post_init_dataset(
+    test_dataset,
+    model_tokenizer,
+    batch_size=args.global_batch_size,
+    num_batches=NUM_TEST_BATCHES,
+    max_prompt_length=MAX_PROMPT_LENGTH,
+)
+
 print(
     f"train_dataset size: {len(train_dataset)}, val_dataset size:"
     f"{len(val_dataset) if val_dataset is not None else 0},"
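
Note: after this change, the same two-step pattern (create_dataset followed by data_lib.post_init_dataset) is repeated verbatim for the train and test splits. The sketch below shows how a shared helper could factor that out; it is illustrative only and not part of the diff. The helper name build_batched_dataset is hypothetical, and every call signature is taken solely from the call sites above.

# Hypothetical helper (not in this diff) factoring out the repeated
# create_dataset -> post_init_dataset sequence used for both splits.
# All arguments mirror the call sites in the patch above; nothing else
# is assumed about the tunix APIs.
def build_batched_dataset(num_batches: int):
  # Load the raw dataset; batching args no longer go to create_dataset.
  ds = create_dataset(
      args.data_source,
      args.dataset if args.data_source == "tfds" else LOCAL_TRAIN_DATA_DIR,
      tokenizer=model_tokenizer,
      tfds_download=True,
  )
  # Apply batching and prompt-length limits in the post-init step.
  return data_lib.post_init_dataset(
      ds,
      model_tokenizer,
      batch_size=args.global_batch_size,
      num_batches=num_batches,
      max_prompt_length=MAX_PROMPT_LENGTH,
  )

dataset = build_batched_dataset(NUM_BATCHES)
test_dataset = build_batched_dataset(NUM_TEST_BATCHES)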