5 changes: 3 additions & 2 deletions .github/workflows/tpu-tests.yml
@@ -98,7 +98,8 @@ jobs:
             --ignore=tests/generate/vllm_sampler_test.py \
             --ignore=tests/generate/vllm_driver_test.py \
             --ignore=tests/generate/tokenizer_adapter_test.py \
-            --ignore=tests/generate/sglang_jax_sampler_test.py
+            --ignore=tests/generate/sglang_jax_sampler_test.py \
+            --ignore=tests/generate/sglang_jax_lora_test.py

       - name: Run tunix SFT tests
         run: |
@@ -208,7 +209,7 @@ jobs:
           EOF
           apt-get update; apt-get install -y less

-          cd tunix && python tests/generate/sglang_jax_sampler_test.py
+          cd tunix && python tests/generate/sglang_jax_sampler_test.py && python tests/generate/sglang_jax_lora_test.py
       - name: Run tunix SFT integration tests
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
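For local verification, the two sglang-jax tests can be run the same way the CI step does. A minimal sketch, assuming a `tunix` checkout in the current directory and an environment with the sglang-jax dependencies installed:

```python
# Mirrors the CI step `cd tunix && python <test> && python <test>`:
# check=True raises on the first failing test, like `&&` short-circuiting.
import subprocess

for test in (
    "tests/generate/sglang_jax_sampler_test.py",
    "tests/generate/sglang_jax_lora_test.py",
):
    subprocess.run(["python", test], check=True, cwd="tunix")
```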
54 changes: 49 additions & 5 deletions scripts/grpo_demo_llama3_qwen2.py
@@ -42,6 +42,7 @@
 import transformers
 from tunix.cli.utils import data as data_lib
 from tunix.examples.data import math_dataset
+from tunix.generate import mappings
 from tunix.models.llama3 import model as llama_lib
 from tunix.models.llama3 import params as llama_params
 from tunix.models.qwen2 import model as qwen2_lib
@@ -244,6 +245,34 @@
     required=False,
     help="Name of dataset, required when data_source is tfds",
 )
+parser.add_argument(
+    "--enable-lora",
+    action="store_true",
+    default=False,
+    required=False,
+    help="Enable LoRA.",
+)
+parser.add_argument(
+    "--lora-rank",
+    type=int,
+    default=64,
+    required=False,
+    help="Rank of LoRA.",
+)
+parser.add_argument(
+    "--lora-alpha",
+    type=float,
+    default=64.0,
+    required=False,
+    help="Alpha of LoRA.",
+)
+parser.add_argument(
+    "--lora-target-modules",
+    nargs="+",
+    type=str,
+    default=None,
+    help="List of target modules to apply LoRA to.",
+)


 # Parse arguments
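To illustrate how the new flags parse, here is a stand-alone sketch that reconstructs just these four arguments; the module names `q_proj` and `v_proj` are hypothetical placeholders, not defaults from the script:

```python
import argparse

# Reconstruction of only the four new LoRA flags, for illustration.
parser = argparse.ArgumentParser()
parser.add_argument("--enable-lora", action="store_true", default=False)
parser.add_argument("--lora-rank", type=int, default=64)
parser.add_argument("--lora-alpha", type=float, default=64.0)
parser.add_argument("--lora-target-modules", nargs="+", type=str, default=None)

# argparse maps dashed flags to underscored attributes (lora_rank, etc.).
args = parser.parse_args(
    ["--enable-lora", "--lora-rank", "16",
     "--lora-target-modules", "q_proj", "v_proj"]  # hypothetical modules
)
assert args.enable_lora and args.lora_rank == 16
assert args.lora_target_modules == ["q_proj", "v_proj"]
```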
@@ -292,9 +321,14 @@ def validata_args():
 SEED = 42

 # ====== LoRA ======
-ENABLE_LORA = False
-RANK = 64
-ALPHA = 64.0
+ENABLE_LORA = args.enable_lora
+RANK = args.lora_rank
+ALPHA = args.lora_alpha
+LORA_TARGET_MODULES = args.lora_target_modules
+if ENABLE_LORA and LORA_TARGET_MODULES is None:
+  raise ValueError(
+      "--lora-target-modules cannot be None when LoRA is enabled!"
+  )

 # ====== Sharding ======
 if "Qwen2.5-0.5B-Instruct" in args.model_version:
@@ -684,7 +718,7 @@ def get_lora_model(base_model, model_mesh=None):
 # Policy model
 # TODO(b/434959964): Supports lora in vLLM Jax backend
 if ENABLE_LORA:
-  training_model = get_lora_model(ref_model, model_mesh=mesh)
+  training_model = get_lora_model(ref_model)
   training_mesh = ref_mesh
 else:
   training_model, training_mesh, _ = get_model(
@@ -1040,13 +1074,16 @@ def evaluate(

 def get_rollout_config(engine: str) -> base_rollout.RolloutConfig:
   if engine == "sglang_jax":
-    return base_rollout.RolloutConfig(
+    config = base_rollout.RolloutConfig(
         max_tokens_to_generate=TOTAL_GENERATION_STEPS,
         max_prompt_length=MAX_PROMPT_LENGTH,
         kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
         temperature=TEMPERATURE,
         top_p=TOP_P,
         top_k=TOP_K,
+        rollout_mapping_config=mappings.MappingConfig.build(
+            model=ref_model, backend="sglang_jax"
+        ),
         rollout_sglang_jax_model_version=SGLANGJAX_MODEL_VERSION,
         rollout_sglang_jax_mem_fraction_static=0.2,
         rollout_sglang_jax_init_with_random_weights=True,
@@ -1057,6 +1094,13 @@ def get_rollout_config(engine: str) -> base_rollout.RolloutConfig:
         rollout_sglang_jax_chunked_prefill_size=2048,
         rollout_sglang_jax_page_size=64,
     )
+    if ENABLE_LORA:
+      config.rollout_sglang_jax_enable_static_lora = True
+      config.rollout_sglang_jax_lora_target_modules = LORA_TARGET_MODULES
+      config.rollout_sglang_jax_max_lora_rank = RANK
+      config.rollout_sglang_jax_lora_scaling = ALPHA / RANK
+
+    return config

   return base_rollout.RolloutConfig(
       max_tokens_to_generate=TOTAL_GENERATION_STEPS,
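One detail worth noting: the rollout LoRA scaling is derived as `ALPHA / RANK`, so the two flags interact. A quick arithmetic check under the script's defaults:

```python
# Scaling passed to the sglang-jax rollout: alpha / rank.
RANK, ALPHA = 64, 64.0      # script defaults (--lora-rank, --lora-alpha)
assert ALPHA / RANK == 1.0  # defaults give identity scaling
assert 64.0 / 16 == 4.0     # lowering rank at fixed alpha raises scaling
```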