diff --git a/.github/workflows/tpu-nightly-regression.yml b/.github/workflows/tpu-nightly-regression.yml index 68522b7b..a3aa8066 100644 --- a/.github/workflows/tpu-nightly-regression.yml +++ b/.github/workflows/tpu-nightly-regression.yml @@ -83,66 +83,6 @@ jobs: print(f'SUCCESS: Found {len(devices)} TPU device(s)') " - - name: Run regression scripts - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - id: regression_tests - run: | - # Download GSM8K dataset - mkdir -p /tmp/grpo_test/rl/grpo/data - - FAILED=0 - echo "📦 Executing: examples/deepscaler/math_eval_nb.py..." - python examples/deepscaler/math_eval_nb.py || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vanilla rollout engine in colocated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vanilla rollout engine in 2 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --cluster-setup=disaggregated-2-way || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vanilla rollout engine in 3 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --cluster-setup=disaggregated-3-way || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm rollout engine in colocated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm rollout engine in 2 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --cluster-setup=disaggregated-2-way || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm rollout engine in 3 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --cluster-setup=disaggregated-3-way || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm server mode in 2 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --rollout-server-mode=True --cluster-setup=disaggregated-2-way || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm server mode in 3 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --rollout-server-mode=True --cluster-setup=disaggregated-3-way || FAILED=1 - - # SGLang Tests - unset JAX_PLATFORMS - pip list | egrep 'jax|flax|libtpu' - cd .. - git clone https://github.com/sgl-project/sglang-jax.git && cd sglang-jax/python && pip install -e . && cd ../.. - pip install jax==0.8.1 flax==0.12.0 libtpu==0.0.24 - pip list | egrep 'jax|flax|libtpu' - cd tunix - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in colocated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in 2 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax --cluster-setup=disaggregated-2-way || FAILED=1 - - echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in 3 way disaggregated mode ..." - python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax --cluster-setup=disaggregated-3-way || FAILED=1 - - - if [ "$FAILED" -ne 0 ]; then - echo "One or more scripts failed!" - exit 1 - fi - - name: Run SFT shell scripts env: HF_TOKEN: ${{ secrets.HF_TOKEN }} @@ -218,5 +158,65 @@ jobs: fi echo "🎉 All RL scripts completed successfully." + - name: Run regression scripts + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + id: regression_tests + run: | + # Download GSM8K dataset + mkdir -p /tmp/grpo_test/rl/grpo/data + + FAILED=0 + echo "📦 Executing: examples/deepscaler/math_eval_nb.py..." + python examples/deepscaler/math_eval_nb.py || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vanilla rollout engine in colocated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vanilla rollout engine in 2 way disaggregated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --cluster-setup=disaggregated-2-way || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vanilla rollout engine in 3 way disaggregated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --cluster-setup=disaggregated-3-way || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm rollout engine in colocated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm rollout engine in 2 way disaggregated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --cluster-setup=disaggregated-2-way || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm rollout engine in 3 way disaggregated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --cluster-setup=disaggregated-3-way || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm server mode in 2 way disaggregated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --rollout-server-mode=True --cluster-setup=disaggregated-2-way || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with vllm server mode in 3 way disaggregated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=vllm --rollout-server-mode=True --cluster-setup=disaggregated-3-way || FAILED=1 + + # SGLang Tests + unset JAX_PLATFORMS + pip list | egrep 'jax|flax|libtpu' + cd .. + git clone https://github.com/sgl-project/sglang-jax.git && cd sglang-jax/python && pip install -e . && cd ../.. + pip install jax==0.8.1 flax==0.12.0 libtpu==0.0.24 + pip list | egrep 'jax|flax|libtpu' + cd tunix + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in colocated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in 2 way disaggregated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax --cluster-setup=disaggregated-2-way || FAILED=1 + + echo "📦 Executing: scripts/grpo_demo_llama3_qwen2.py with sglang_jax in 3 way disaggregated mode ..." + python scripts/grpo_demo_llama3_qwen2.py --root-dir=/tmp/grpo_test --num-batches=20 --rollout-engine=sglang_jax --cluster-setup=disaggregated-3-way || FAILED=1 + + + if [ "$FAILED" -ne 0 ]; then + echo "One or more scripts failed!" + exit 1 + fi + \ No newline at end of file