50 changes: 50 additions & 0 deletions .github/workflows/RunTests.yml
@@ -34,6 +34,9 @@ concurrency:
    }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  prelim:
    runs-on: ["self-hosted"]
@@ -103,6 +106,53 @@ jobs:
      container_resource_option: "--privileged"
      is_scheduled_run: ${{ github.event_name == 'schedule' }}

  tpu_e2e_grpo_test:
    needs: tpu_image
    runs-on: linux-x86-ct4p-240-4tpu
    container:
      image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu
      env:
        XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75
        TF_FORCE_GPU_ALLOW_GROWTH: false
        HF_TOKEN: ${{ secrets.HF_TOKEN }}
        MAXTEXT_CHECKPOINT_PATH: gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items
[Review comment from a Collaborator: Use the instruct checkpoint: gs://maxtext-model-checkpoints/llama3.1_8b_instruct/2025-10-16/scanned/0/items (a sketch of the adjusted env block follows this job).]

options: "--privileged"
steps:
- uses: actions/checkout@v4
- name: Install Tunix vLLM Requirements
run: |
bash src/MaxText/examples/install_tunix_vllm_requirement.sh
- name: Run GRPO Llama3.1 8B Demo
run: |
python3 -m pip install -e . --no-dependencies &&
python3 src/MaxText/examples/grpo_llama3_1_8b_demo.py
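
Following the review comment above, the GRPO job would point MAXTEXT_CHECKPOINT_PATH at the instruct checkpoint. A minimal sketch of the adjusted container env block, assuming the rest of the job stays exactly as written in this diff:

      env:
        XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75
        TF_FORCE_GPU_ALLOW_GROWTH: false
        HF_TOKEN: ${{ secrets.HF_TOKEN }}
        # Instruct checkpoint requested by the reviewer; only this value changes.
        MAXTEXT_CHECKPOINT_PATH: gs://maxtext-model-checkpoints/llama3.1_8b_instruct/2025-10-16/scanned/0/items

Since the demo script reads MAXTEXT_CHECKPOINT_PATH directly (see the script changes below), no other workflow edits should be needed.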

  tpu_e2e_sft_test:
    needs: tpu_image
    runs-on: linux-x86-ct4p-240-4tpu
    container:
      image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu
      env:
        XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75
        TF_FORCE_GPU_ALLOW_GROWTH: false
        HF_TOKEN: ${{ secrets.HF_TOKEN }}
        MODEL_CHECKPOINT_PATH: gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items
[Review comment from a Collaborator: Instruct checkpoint: gs://maxtext-model-checkpoints/llama3.1_8b_instruct/2025-10-16/scanned/0/items (a combined sketch follows this job).]

        STEPS: 10
      options: "--privileged"
    steps:
      - uses: actions/checkout@v4
      - name: Install Dependencies
        run: |
          python3 -m pip install -e . --no-dependencies
      - name: Install Tunix vLLM Requirements
        run: |
          bash src/MaxText/examples/install_tunix_vllm_requirement.sh
[Review comment from a Collaborator: We don't need vllm or tpu-commons to run SFT (a combined sketch follows this job).]

      - name: Run SFT Llama3.1 8B Demo
        run: |
          python3 src/MaxText/examples/sft_llama3_demo.py \
            --skip_checkpoint_download \
            --model_checkpoint_path=${MODEL_CHECKPOINT_PATH}
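
Taking both review comments above together (use the instruct checkpoint, and drop the vLLM install because SFT needs neither vllm nor tpu-commons), the SFT job could be trimmed to the sketch below. The checkpoint path is the one the reviewer gives; everything else is unchanged from the diff.

  tpu_e2e_sft_test:
    needs: tpu_image
    runs-on: linux-x86-ct4p-240-4tpu
    container:
      image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu
      env:
        XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75
        TF_FORCE_GPU_ALLOW_GROWTH: false
        HF_TOKEN: ${{ secrets.HF_TOKEN }}
        # Instruct checkpoint suggested by the reviewer.
        MODEL_CHECKPOINT_PATH: gs://maxtext-model-checkpoints/llama3.1_8b_instruct/2025-10-16/scanned/0/items
        STEPS: 10
      options: "--privileged"
    steps:
      - uses: actions/checkout@v4
      - name: Install Dependencies
        run: |
          python3 -m pip install -e . --no-dependencies
      # The "Install Tunix vLLM Requirements" step is dropped: SFT does not need vllm or tpu-commons.
      - name: Run SFT Llama3.1 8B Demo
        run: |
          python3 src/MaxText/examples/sft_llama3_demo.py \
            --skip_checkpoint_download \
            --model_checkpoint_path=${MODEL_CHECKPOINT_PATH}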

  gpu_unit_tests:
    needs: gpu_image
    uses: ./.github/workflows/run_tests_internal.yml
69 changes: 57 additions & 12 deletions src/MaxText/examples/grpo_llama3_1_8b_demo.py
@@ -85,13 +85,58 @@
# nest_asyncio.apply() # To fix "This event loop is already running" error in Colab
# Run `pip install nest_asyncio` if not already installed.

jax.devices()

DEBUG = False # set to True to run in debug mode, for more print statements

HOME = os.path.join(os.path.expanduser("~"), "")
print(f"Home directory (from Python): {HOME}")

# Determine MaxText repo root - use environment variable if set, otherwise derive from script location
if "MAXTEXT_REPO_ROOT" in os.environ:
  MAXTEXT_ROOT = os.environ["MAXTEXT_REPO_ROOT"]
else:
  # This script is in src/MaxText/examples/, so go up 3 levels to get repo root
  MAXTEXT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
print(f"MaxText repo root: {MAXTEXT_ROOT}")

# Use pre-existing checkpoint from env var or convert from HuggingFace
# Check for MAXTEXT_CHECKPOINT_PATH env var first (for CI/testing)
if "MAXTEXT_CHECKPOINT_PATH" in os.environ:
  CHECKPOINT_LOAD_PATH = os.environ["MAXTEXT_CHECKPOINT_PATH"]
  print(f"Using checkpoint from environment variable: {CHECKPOINT_LOAD_PATH}")
else:
  # Convert checkpoint from HuggingFace to MaxText format BEFORE any JAX initialization
  # This must happen early to avoid TPU conflicts
  MODEL_NAME = "llama3.1-8b-Instruct"  # Use Instruct version for GRPO
  MODEL_CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), "checkpoints", MODEL_NAME)

  if not os.path.exists(MODEL_CHECKPOINT_PATH):
    print(f"Converting checkpoint from HuggingFace to MaxText format at {MODEL_CHECKPOINT_PATH}")
    os.makedirs(MODEL_CHECKPOINT_PATH, exist_ok=True)
    import subprocess

    result = subprocess.run(
        [
            "python3",
            "-m",
            "MaxText.utils.ckpt_conversion.to_maxtext",
            os.path.join(MAXTEXT_ROOT, "src/MaxText/configs/base.yml"),
            f"model_name={MODEL_NAME}",
            f"base_output_directory={MODEL_CHECKPOINT_PATH}",
            f"hf_access_token={os.environ.get('HF_TOKEN', '')}",
            "use_multimodal=false",
            "scan_layers=true",
        ],
        check=True,
    )
    print("Checkpoint conversion completed successfully")
  else:
    print(f"Using existing checkpoint at {MODEL_CHECKPOINT_PATH}")

  CHECKPOINT_LOAD_PATH = os.path.join(MODEL_CHECKPOINT_PATH, "0", "items")

# Initialize JAX/TPU after checkpoint conversion
jax.devices()

# ## Hyperparameters
#
# Let's define the configuration we are going to use. Note that this is by no
@@ -287,12 +332,14 @@ def get_dataset(data_dir, split="train") -> grain.MapDataset:
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  # Use try_gcs=True to leverage Google's cached datasets and avoid slow downloads
  data = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
      try_gcs=True,  # Use GCS cached version if available for faster loading
  )

  loaded_dataset = (
@@ -396,18 +443,17 @@ def get_ref_maxtext_model(config):
model_config = llama3_lib.ModelConfig.llama3_1_8b()

# Load the reference model
# Note: pass the path to your scanned checkpoint for "load_parameters_path".
# To create a scanned checkpoint, you can use /maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py
print(f"Loading reference model checkpoint from: {CHECKPOINT_LOAD_PATH}")
config_ref = pyconfig.initialize(
    [
        "",
        f"{HOME}/maxtext/src/MaxText/configs/base.yml",
        os.path.join(MAXTEXT_ROOT, "src/MaxText/configs/base.yml"),
    ],
    base_output_directory="dummy",  # This is not used in Tunix.
    run_name="test-tunix-maxtext-llama3.1-8b",
    tokenizer_type="tiktoken",
    tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
    load_parameters_path=MODEL_CHECKPOINT_PATH,
    load_parameters_path=CHECKPOINT_LOAD_PATH,
    per_device_batch_size=1,
    max_prefill_predict_length=4,
    max_target_length=1024,
@@ -452,21 +498,20 @@ def get_ref_maxtext_model(config):


# Load the policy model
# Note: pass the path to your scanned checkpoint for "load_parameters_path".
# To create a scanned checkpoint, you can use /maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py
# Note: Model checkpoint will be loaded from converted HuggingFace weights

# TODO: @mazumdera: change this to use lora

config_policy = pyconfig.initialize(
    [
        "",
        f"{HOME}/maxtext/src/MaxText/configs/base.yml",
        os.path.join(MAXTEXT_ROOT, "src/MaxText/configs/base.yml"),
    ],
    base_output_directory="dummy",  # This is not used in Tunix.
    run_name="test-tunix-maxtext-llama3.1-8b",  # This is not used in Tunix.
    tokenizer_type="tiktoken",
    tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer_llama3.tiktoken"),
    load_parameters_path=MODEL_CHECKPOINT_PATH,
    load_parameters_path=CHECKPOINT_LOAD_PATH,
    per_device_batch_size=1,
    max_prefill_predict_length=4,
    max_target_length=1024,
@@ -954,7 +999,7 @@ def main():
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        gradient_accumulation_steps=1,
        # gradient_accumulation_steps is automatically derived for RL training
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
@@ -970,7 +1015,7 @@
            top_k=TOP_K,
        ),
        rollout_vllm_model_version="meta-llama/Meta-Llama-3.1-8B-Instruct",
        rollout_vllm_hbm_utilization=0.2,
        rollout_vllm_hbm_utilization=0.5,  # Increased from 0.2 to allow vLLM to use more memory
        rollout_vllm_tpu_backend_type="jax",
    )
