diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml
index 9c753a758..4b84aec7f 100644
--- a/.github/workflows/tpu-tests.yml
+++ b/.github/workflows/tpu-tests.yml
@@ -98,7 +98,8 @@ jobs:
--ignore=tests/generate/vllm_sampler_test.py \
--ignore=tests/generate/vllm_driver_test.py \
--ignore=tests/generate/tokenizer_adapter_test.py \
- --ignore=tests/generate/sglang_jax_sampler_test.py
+ --ignore=tests/generate/sglang_jax_sampler_test.py \
+ --ignore=tests/generate/sglang_jax_lora_test.py
- name: Run tunix SFT tests
run: |
@@ -208,7 +209,7 @@ jobs:
EOF
apt-get update; apt-get install -y less
- cd tunix && python tests/generate/sglang_jax_sampler_test.py
+ cd tunix && python tests/generate/sglang_jax_sampler_test.py && python tests/generate/sglang_jax_lora_test.py
- name: Run tunix SFT integration tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
diff --git a/scripts/grpo_demo_llama3_qwen2.py b/scripts/grpo_demo_llama3_qwen2.py
index fc9acbd02..e1cb62b23 100644
--- a/scripts/grpo_demo_llama3_qwen2.py
+++ b/scripts/grpo_demo_llama3_qwen2.py
@@ -42,6 +42,7 @@
import transformers
from tunix.cli.utils import data as data_lib
from tunix.examples.data import math_dataset
+from tunix.generate import mappings
from tunix.models.llama3 import model as llama_lib
from tunix.models.llama3 import params as llama_params
from tunix.models.qwen2 import model as qwen2_lib
@@ -244,6 +245,34 @@
required=False,
help="Name of dataset, required when data_source is tfds",
)
+parser.add_argument(
+ "--enable-lora",
+ action="store_true",
+ default=False,
+ required=False,
+ help="Enable LoRA.",
+)
+parser.add_argument(
+ "--lora-rank",
+ type=int,
+ default=64,
+ required=False,
+ help="Rank of LoRA.",
+)
+parser.add_argument(
+ "--lora-alpha",
+ type=float,
+ default=64.0,
+ required=False,
+ help="Alpha of LoRA.",
+)
+parser.add_argument(
+ "--lora-target-modules",
+ nargs="+",
+ type=str,
+ default=None,
+ help="List of target modules to apply LoRA",
+)
# Parse arguments
@@ -292,9 +321,14 @@ def validata_args():
SEED = 42
# ====== LoRA ======
-ENABLE_LORA = False
-RANK = 64
-ALPHA = 64.0
+ENABLE_LORA = args.enable_lora
+RANK = args.lora_rank
+ALPHA = args.lora_alpha
+LORA_TARGET_MODULES = args.lora_target_modules
+if ENABLE_LORA and LORA_TARGET_MODULES is None:
+ raise ValueError(
+ f"{LORA_TARGET_MODULES} can not be None when LoRA is enabled!"
+ )
# ====== Sharding ======
if "Qwen2.5-0.5B-Instruct" in args.model_version:
@@ -684,7 +718,7 @@ def get_lora_model(base_model, model_mesh=None):
# Policy model
# TODO(b/434959964): Supports lora in vLLM Jax backend
if ENABLE_LORA:
- training_model = get_lora_model(ref_model, model_mesh=mesh)
+ training_model = get_lora_model(ref_model)
training_mesh = ref_mesh
else:
training_model, training_mesh, _ = get_model(
@@ -1040,13 +1074,16 @@ def evaluate(
def get_rollout_config(engine: str) -> base_rollout.RolloutConfig:
if engine == "sglang_jax":
- return base_rollout.RolloutConfig(
+ config = base_rollout.RolloutConfig(
max_tokens_to_generate=TOTAL_GENERATION_STEPS,
max_prompt_length=MAX_PROMPT_LENGTH,
kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
temperature=TEMPERATURE,
top_p=TOP_P,
top_k=TOP_K,
+ rollout_mapping_config=mappings.MappingConfig.build(
+ model=ref_model, backend="sglang_jax"
+ ),
rollout_sglang_jax_model_version=SGLANGJAX_MODEL_VERSION,
rollout_sglang_jax_mem_fraction_static=0.2,
rollout_sglang_jax_init_with_random_weights=True,
@@ -1057,6 +1094,13 @@ def get_rollout_config(engine: str) -> base_rollout.RolloutConfig:
rollout_sglang_jax_chunked_prefill_size=2048,
rollout_sglang_jax_page_size=64,
)
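+    # LoRA-specific rollout settings; the scaling follows the usual alpha / rank convention.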
+ if ENABLE_LORA:
+ config.rollout_sglang_jax_enable_static_lora = True
+ config.rollout_sglang_jax_lora_target_modules = LORA_TARGET_MODULES
+ config.rollout_sglang_jax_max_lora_rank = RANK
+ config.rollout_sglang_jax_lora_scaling = ALPHA / RANK
+
+ return config
return base_rollout.RolloutConfig(
max_tokens_to_generate=TOTAL_GENERATION_STEPS,
diff --git a/tests/generate/sglang_jax_lora_test.py b/tests/generate/sglang_jax_lora_test.py
new file mode 100644
index 000000000..516399894
--- /dev/null
+++ b/tests/generate/sglang_jax_lora_test.py
@@ -0,0 +1,371 @@
+"""
+Note: This test is based on scripts/grpo_demo_sglang_jax_rollout.py.
+For the meanings of the constants, please refer to that script.
+"""
+
+import csv
+import os
+from pathlib import Path
+import re
+import shutil
+from absl.testing import absltest
+
+import grain
+import huggingface_hub
+import jax
+import kagglehub
+import optax
+from orbax import checkpoint as ocp
+import qwix
+import transformers
+from tunix.generate import mappings
+from tunix.models.llama3 import model as llama_lib
+from tunix.models.llama3 import params as llama3_params_lib
+from tunix.rl import rl_cluster as rl_cluster_lib
+from tunix.rl.grpo.grpo_learner import GRPOConfig
+from tunix.rl.grpo.grpo_learner import GRPOLearner
+from tunix.rl.rollout import base_rollout
+from tunix.rl.utils import VERIFY_UPDATE_PARAMS_KEY
+from tunix.sft import metrics_logger
+
+
+############################################# CONSTANTS ###########################################
+
+TRAIN_DATA_DIR = "./data/train"
+TRAIN_FRACTION = 1.0
+RANK = 64
+ALPHA = 64.0
+TOTAL_TPU_TO_USE = jax.device_count()
+MESH = [
+ (
+ 1,
+ TOTAL_TPU_TO_USE,
+ ),
+ ("fsdp", "tp"),
+]
+MAX_PROMPT_LENGTH = 256
+TOTAL_GENERATION_STEPS = 1024
+TEMPERATURE = 0.9
+TOP_P = 1.0
+TOP_K = 50
+NUM_GENERATIONS = 2
+NUM_ITERATIONS = 1
+BETA = 0.08
+EPSILON = 0.2
+TRAIN_MICRO_BATCH_SIZE = 1
+NUM_BATCHES = 2
+NUM_TEST_BATCHES = 2
+EVAL_EVERY_N_STEPS = 5
+NUM_EPOCHS = 1
+MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
+LEARNING_RATE = 3e-6
+B1 = 0.9
+B2 = 0.99
+WEIGHT_DECAY = 0.1
+WARMUP_STEPS = 0.1 * MAX_STEPS
+MAX_GRAD_NORM = 0.1
+INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
+CKPT_DIR = "/tmp/content/ckpts/"
+SAVE_INTERVAL_STEPS = 500
+MAX_TO_KEEP = 4
+
+############################################# PROMPTS ###########################################
+
+reasoning_start = ""
+reasoning_end = ""
+solution_start = ""
+solution_end = ""
+
+
+SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
+provide your reasoning. Place it between {reasoning_start} and \
+{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
+value) between {solution_start} and {solution_end}."""
+
+TEMPLATE = """user
+{system_prompt}
+
+{question}
+model"""
+
+
+def extract_hash_answer(text: str) -> str | None:
+ if "####" not in text:
+ return None
+ return text.split("####")[1].strip()
+
+
+def download_kaggle_dataset(target_dir="./data/gsm8k"):
+ os.makedirs(target_dir, exist_ok=True)
+ src = kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a")
+ src = Path(src)
+ dst = Path(target_dir)
+
+ for csv_file in src.glob("*.csv"): # match all CSV files
+ shutil.copy2(csv_file, dst / csv_file.name)
+ print(f"Copied {csv_file.name} → {dst/csv_file.name}")
+ return target_dir
+
+
+def get_dataset(data_dir, split="train", source="tfds") -> grain.MapDataset:
+ # Download data
+ if not os.path.exists(data_dir):
+ os.makedirs(data_dir)
+
+ kaggle_dir = download_kaggle_dataset(data_dir)
+ file_name = "main_" + split + ".csv"
+ csv_path = os.path.join(kaggle_dir, file_name) # adjust filename if needed
+
+ data = []
+ with open(csv_path, newline="", encoding="utf-8") as csvfile:
+ reader = csv.DictReader(csvfile)
+ for row in reader:
+ data.append({
+ "question": row["question"],
+ "answer": row["answer"],
+ })
+
+ def _as_text(v):
+ return v if isinstance(v, str) else v.decode("utf-8")
+
+ dataset = (
+ grain.MapDataset.source(data)
+ .shuffle(seed=42)
+ .map(
+ lambda x: {
+ # passed to model forward pass
+ "prompts": model_tokenizer.apply_chat_template(
+ [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": x["question"]},
+ ],
+ tokenize=False,
+ add_generation_prompt=True,
+ ),
+ # passed to reward functions
+ "question": _as_text(x["question"]),
+ # passed to reward functions
+ "answer": extract_hash_answer(_as_text(x["answer"])),
+ }
+ )
+ )
+ return dataset
+
+
+def download_from_huggingface(repo_id: str, model_path: str):
+ """Download checkpoint files from huggingface."""
+ print("Make sure you logged in to the huggingface cli.")
+ all_files = huggingface_hub.list_repo_files(repo_id)
+ filtered_files = [f for f in all_files if not f.startswith("original/")]
+
+ for filename in filtered_files:
+ huggingface_hub.hf_hub_download(
+ repo_id=repo_id, filename=filename, local_dir=model_path
+ )
+ print(f"Downloaded {filtered_files} to: {model_path}")
+
+
+def load_model():
+ model_config = llama_lib.ModelConfig.llama3p2_3b()
+
+ mesh = jax.make_mesh(
+ *MESH,
+ devices=jax.devices()[:TOTAL_TPU_TO_USE],
+ axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]),
+ )
+ model = llama3_params_lib.create_model_from_safe_tensors(
+ model_path, model_config, mesh
+ )
+ return model, mesh, model_config
+
+
+def get_rollout_mesh():
+ mesh = jax.make_mesh(
+ *MESH,
+ devices=jax.devices()[-TOTAL_TPU_TO_USE:],
+ axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]),
+ )
+ return mesh
+
+
+def get_lora_model(base_model):
+ lora_provider = qwix.LoraProvider(
+ module_path=".*gate_proj",
+ rank=RANK,
+ alpha=ALPHA,
+ )
+
+ model_input = base_model.get_model_input()
+ lora_model = qwix.apply_lora_to_model(
+ base_model, lora_provider, **model_input
+ )
+
+ return lora_model
+
+
+model_version = "meta-llama/Llama-3.2-3B-Instruct"
+
+repo_id = model_version
+model_tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)
+
+import tempfile
+
+temp_dir = tempfile.gettempdir()
+model_path = os.path.join(temp_dir, "models", repo_id)
+
+match_format = re.compile(
+ rf"^[\s]{{0,}}"
+ rf"{reasoning_start}.+?{reasoning_end}.*?"
+ rf"{solution_start}(.+?){solution_end}"
+ rf"[\s]{{0,}}$",
+ flags=re.MULTILINE | re.DOTALL,
+)
+
+
+def match_format_exactly(prompts, completions, **kwargs):
+ return [
+ 0 if match_format.search(response) is None else 3.0
+ for response in completions
+ ]
+
+
+class SglangJaxLoRATest(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls) -> None:
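+    # Ask the SGLang-JAX sampler to verify that this Tunix LoRA parameter is
+    # transferred correctly into the corresponding SGLang-JAX LoRA buffer.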
+ VERIFY_UPDATE_PARAMS_VAL = "layers.0.mlp.gate_proj.kernel_lora_a,model.layers.0.mlp.gate_proj.A_buffer"
+ os.environ[VERIFY_UPDATE_PARAMS_KEY] = VERIFY_UPDATE_PARAMS_VAL
+
+ super().setUpClass()
+
+ ## ====================================== Get Dataset ===================================
+ source = "kaggle"
+ cls.dataset = get_dataset(TRAIN_DATA_DIR, "train", source).batch(
+ TRAIN_MICRO_BATCH_SIZE
+ )[:NUM_BATCHES]
+
+ ## ====================================== Get LoRA model =================================
+ download_from_huggingface(repo_id=repo_id, model_path=model_path)
+
+ ref_model, cls.mesh, _ = load_model()
+ rollout_mesh = get_rollout_mesh()
+
+ lora_policy = get_lora_model(ref_model)
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+ mapping_config = mappings.MappingConfig.build(
+ model=ref_model, backend="sglang_jax"
+ )
+
+    ## ========================== Initialize RL cluster and trainer ==========================
+ cluster_config, grpo_config = cls.prepare_configs(
+ cls.mesh, rollout_mesh, mapping_config
+ )
+
+ # RL cluster
+ rl_cluster = rl_cluster_lib.RLCluster(
+ actor=lora_policy,
+ reference=ref_model,
+ tokenizer=tokenizer,
+ cluster_config=cluster_config,
+ )
+
+ # GRPO Trainer
+ cls.grpo_trainer = GRPOLearner(
+ rl_cluster=rl_cluster,
+ reward_fns=[
+ match_format_exactly,
+ ],
+ algo_config=grpo_config,
+ )
+
+ @classmethod
+ def prepare_configs(cls, mesh, rollout_mesh, mapping_config):
+ # Ckpt saving
+ checkpointing_options = ocp.CheckpointManagerOptions(
+ save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
+ )
+
+ # Metrics logger
+ metrics_logging_options = metrics_logger.MetricsLoggerOptions(
+ log_dir="/tmp/content/tmp/tensorboard/grpo", flush_every_n_steps=20
+ )
+
+ optimizer = optax.adamw(
+ learning_rate=optax.schedules.warmup_cosine_decay_schedule(
+ init_value=0.0,
+ peak_value=LEARNING_RATE,
+ warmup_steps=WARMUP_STEPS,
+ decay_steps=MAX_STEPS,
+ end_value=0.0,
+ ),
+ b1=B1,
+ b2=B2,
+ weight_decay=WEIGHT_DECAY,
+ )
+
+ if MAX_GRAD_NORM is not None:
+ optimizer = optax.chain(
+ optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
+ optimizer,
+ )
+
+ # Training config
+ cluster_config = rl_cluster_lib.ClusterConfig(
+ role_to_mesh={
+ rl_cluster_lib.Role.ACTOR: mesh,
+ rl_cluster_lib.Role.REFERENCE: mesh,
+ rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
+ },
+ rollout_engine="sglang_jax",
+ offload_to_cpu=False,
+ training_config=rl_cluster_lib.RLTrainingConfig(
+ actor_optimizer=optimizer,
+ eval_every_n_steps=EVAL_EVERY_N_STEPS,
+ max_steps=MAX_STEPS,
+ mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
+ train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
+ metrics_logging_options=metrics_logging_options,
+ checkpoint_root_directory=CKPT_DIR,
+ checkpointing_options=checkpointing_options,
+ ),
+ rollout_config=base_rollout.RolloutConfig(
+ max_tokens_to_generate=TOTAL_GENERATION_STEPS,
+ max_prompt_length=MAX_PROMPT_LENGTH,
+ kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
+ temperature=TEMPERATURE,
+ top_p=TOP_P,
+ top_k=TOP_K,
+ rollout_mapping_config=mapping_config,
+ rollout_sglang_jax_model_version=model_path,
+ rollout_sglang_jax_context_length=2048,
+ rollout_sglang_jax_mem_fraction_static=0.3,
+ rollout_sglang_jax_init_with_random_weights=True,
+ rollout_sglang_jax_disable_radix_cache=True,
+ rollout_sglang_jax_enable_deterministic_sampling=False,
+ rollout_sglang_jax_use_sort_for_toppk_minp=True,
+ rollout_sglang_jax_enable_static_lora=True,
+ rollout_sglang_jax_enable_single_process=True,
+ rollout_sglang_jax_lora_target_modules=["all"],
+ rollout_sglang_jax_max_lora_rank=RANK,
+ rollout_sglang_jax_lora_scaling=ALPHA / RANK,
+ rollout_sglang_jax_precompile_bs_paddings=[8],
+ rollout_sglang_jax_precompile_token_paddings=[2048],
+ ),
+ )
+
+ grpo_config = GRPOConfig(
+ num_generations=NUM_GENERATIONS,
+ num_iterations=NUM_ITERATIONS,
+ beta=BETA,
+ epsilon=EPSILON,
+ )
+
+ return (cluster_config, grpo_config)
+
+ def test_lora(self):
+ with self.mesh:
+ self.grpo_trainer.train(self.dataset)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/tunix/generate/mappings.py b/tunix/generate/mappings.py
index e71887eaf..97e7999cc 100644
--- a/tunix/generate/mappings.py
+++ b/tunix/generate/mappings.py
@@ -47,6 +47,11 @@ def to_hf_transpose_keys(cls, backend: str | None = None):
result = cls.mapping_for(backend).get('to_hf_transpose_keys')
return result or None
+ @classmethod
+ def lora_to_hf_transpose_keys(cls, backend: str | None = None):
+ result = cls.mapping_for(backend).get('lora_to_hf_transpose_keys')
+ return result or None
+
@classmethod
def to_hf_hook_fns(cls, backend: str | None = None):
return cls.mapping_for(backend).get('to_hf_hook_fns')
@@ -66,6 +71,7 @@ class MappingConfig:
lora_to_hf_mappings: Optional[Dict[str, Any]] = None
to_hf_hook_fns: Optional[Dict[str, Any]] = None
to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
+ lora_to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
@classmethod
def build(
@@ -90,6 +96,7 @@ def build(
'lora_to_hf_mappings',
'to_hf_hook_fns',
'to_hf_transpose_keys',
+ 'lora_to_hf_transpose_keys',
)
values: Dict[str, Any] = {}
@@ -116,6 +123,7 @@ def build(
lora_to_hf_mappings=resolved.get('lora_to_hf_mappings'),
to_hf_hook_fns=resolved.get('to_hf_hook_fns'),
to_hf_transpose_keys=resolved.get('to_hf_transpose_keys'),
+ lora_to_hf_transpose_keys=resolved.get('lora_to_hf_transpose_keys'),
)
@classmethod
@@ -143,6 +151,7 @@ def maybe_call(attr: str):
lora_to_hf_mappings=maybe_call('lora_to_hf_mappings'),
to_hf_hook_fns=maybe_call('to_hf_hook_fns'),
to_hf_transpose_keys=maybe_call('to_hf_transpose_keys'),
+ lora_to_hf_transpose_keys=maybe_call('lora_to_hf_transpose_keys'),
)
for key, value in overrides.items():
diff --git a/tunix/generate/sglang_jax_sampler.py b/tunix/generate/sglang_jax_sampler.py
index cfa92da8a..8fd58b755 100644
--- a/tunix/generate/sglang_jax_sampler.py
+++ b/tunix/generate/sglang_jax_sampler.py
@@ -15,39 +15,53 @@
"""Sampler for sglang-jax-style autoregressive decoding using JAX and NNX models."""
import dataclasses
+import logging
import math
import os
+import re
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
-from absl import logging
from flax import nnx
import jax
import jax.numpy as jnp
import jaxtyping
import numpy as np
from sgl_jax.srt.entrypoints.engine import Engine
+from sgl_jax.srt.utils.common_utils import SUPPORTED_LORA_TARGET_MODULES
from tunix.generate import base_sampler
from tunix.generate import mappings
from tunix.generate import utils
import tunix.generate.tokenizer_adapter as tok_adapter
from tunix.rl import reshard
+from tunix.rl.utils import VERIFY_UPDATE_PARAMS_KEY
@dataclasses.dataclass
class SglangJaxConfig:
- model_version: str
- context_length: int
mesh: jax.sharding.Mesh
- mem_fraction_static: float
- init_with_random_weights: bool
- disable_radix_cache: bool
- enable_deterministic_sampling: bool
mapping_config: mappings.MappingConfig
+
+ model_version: str
+ context_length: int
+
+ mem_fraction_static: float = 0.2
+ init_with_random_weights: bool = True
+ disable_radix_cache: bool = True
+ enable_deterministic_sampling: bool = False
# Note: use_sort_for_toppk_minp may be removed in the future. It depends on SGLang-Jax.
use_sort_for_toppk_minp: bool = True
- precompile_bs_paddings: Optional[List] = None
- precompile_token_paddings: Optional[List] = None
- chunked_prefill_size: int = -1
+ enable_static_lora: bool = False
+  enable_single_process: bool = (
+      True  # Note: single-process mode is required when running on Pathways.
+  )
+
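+  # Static-LoRA settings; these must be set when enable_static_lora is True.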
+ lora_target_modules: Optional[List[str]] = None
+ max_lora_rank: Optional[int] = None
+ lora_scaling: Optional[float] = None
+
+ precompile_token_paddings: Optional[List[int]] = None
+ precompile_bs_paddings: Optional[List[int]] = None
+ chunked_prefill_size: Optional[int] = -1
page_size: int = 64
@@ -76,10 +90,27 @@ def __init__(
self.args = self._sglang_jax_config(config)
self.engine = Engine(**self.args)
- self.mappings = config.mapping_config.to_hf_mappings
+ self.to_hf_key_mappings = config.mapping_config.to_hf_mappings
+ self.to_hf_key_mappings = update_hf_key_mappings_with_lora(
+ self.to_hf_key_mappings,
+ self.args["enable_static_lora"],
+ self.args["lora_target_modules"],
+ )
self.to_hf_transpose_keys = config.mapping_config.to_hf_transpose_keys
self.to_hf_hook_fns = config.mapping_config.to_hf_hook_fns
+ if config.mapping_config.lora_to_hf_mappings:
+ self.to_hf_key_mappings |= config.mapping_config.lora_to_hf_mappings
+
+ if config.mapping_config.lora_to_hf_transpose_keys:
+ self.to_hf_transpose_keys |= (
+ config.mapping_config.lora_to_hf_transpose_keys
+ )
+
+ self._logger = logging.getLogger(self.__class__.__name__)
+
+ self._logger.debug(f"{self.to_hf_key_mappings=}")
+
# TODO(b/434969743): Optimize weight sharing between trainer and sglang-jax sampler.
# TODO(b/434975493): Consider Release KV cache on the fly
def update_params(
@@ -91,13 +122,81 @@ def update_params(
new_state = utils.transfer_state_with_mappings(
src_state=updated_weights,
dst_state=self.transformer_state,
- key_mappings=self.mappings,
+ key_mappings=self.to_hf_key_mappings,
transpose_keys=self.to_hf_transpose_keys,
reshard_fn=reshard.reshard_pytree,
+ rollout_engine="sglang_jax",
)
new_model_state_leaves, _ = jax.tree_util.tree_flatten(new_state)
self._model_runner.model_state_leaves = new_model_state_leaves
+ flatten_src_to_tgt_module_name = os.getenv(VERIFY_UPDATE_PARAMS_KEY, None)
+ if flatten_src_to_tgt_module_name is not None:
+ # update state before verification
+ nnx.update(self._model_runner.model, new_state)
+
+ self.verify_update_params(
+ updated_weights,
+ self.transformer_state,
+ flatten_src_to_tgt_module_name,
+ )
+
+ def verify_update_params(
+ self,
+ src_state: nnx.State,
+ tgt_state: nnx.State,
+ flatten_src_to_tgt_module_name: str,
+ ):
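+    """Verify that a single parameter was transferred correctly.
+
+    `flatten_src_to_tgt_module_name` is a comma-separated "src_path,tgt_path"
+    pair (read from the VERIFY_UPDATE_PARAMS_KEY environment variable). The
+    source value is transposed the same way as during the transfer and then
+    compared element-wise against the target value.
+    """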
+ self._logger.debug(
+ f"[verify_update_params] {flatten_src_to_tgt_module_name} is required"
+ " to verify"
+ )
+ src_tgt = flatten_src_to_tgt_module_name.split(",")
+ assert len(src_tgt) == 2
+ flatten_src_module_name = src_tgt[0]
+ flatten_tgt_module_name = src_tgt[1]
+ flatten_src_state_list = src_state.flat_state()
+ src_value = None
+ for keys, param in flatten_src_state_list:
+ path = ".".join(str(key) for key in keys)
+ if path == flatten_src_module_name:
+ src_value = param.value if hasattr(param, "value") else param
+ if src_value is None:
+ raise ValueError(
+ f"{flatten_src_module_name=} does not exist in src: {src_state=}"
+ )
+
+ flatten_tgt_state_list = tgt_state.flat_state()
+ tgt_value = None
+ for keys, param in flatten_tgt_state_list:
+ path = ".".join(str(key) for key in keys)
+ if path == flatten_tgt_module_name:
+ tgt_value = param.value if hasattr(param, "value") else param
+
+ if tgt_value is None:
+ raise ValueError(
+ f"{flatten_tgt_module_name=} does not exist in tgt: {tgt_state=}"
+ )
+
+ if "lora" in flatten_src_module_name:
+ for r_key in self.to_hf_transpose_keys:
+ if re.match(r_key, flatten_src_module_name):
+ logging.info(
+ "Applying LoRA transpose on %s in verification",
+ flatten_src_module_name,
+ )
+ transposed_src_value = jnp.transpose(
+ src_value[None, :, :], self.to_hf_transpose_keys[r_key]
+ )
+ else:
+ transposed_src_value = src_value
+
+ src_value_np, tgt_value_np = jax.device_get(
+ transposed_src_value
+ ), jax.device_get(tgt_value)
+ if not np.array_equal(src_value_np, tgt_value_np):
+ raise ValueError(f"{src_value_np=} is not equal to {tgt_value_np=}")
+
def load_checkpoint(self, path_or_weights: str | jaxtyping.PyTree):
# TODO(b/434741253): Consider support orbax checkpoint loading
if isinstance(path_or_weights, jaxtyping.PyTree):
@@ -114,21 +213,50 @@ def _sglang_jax_config(self, config: SglangJaxConfig):
args = {}
args["model_path"] = config.model_version
args["context_length"] = config.context_length
- args["tp_size"] = self._find_tp_size(config.mesh)
- args["device_indexes"] = config.mesh.device_ids.flatten().tolist()
+
args["mem_fraction_static"] = config.mem_fraction_static
- args["enable_single_process"] = True
- if config.disable_radix_cache:
- args["disable_radix_cache"] = True
- if config.enable_deterministic_sampling:
- args["enable_deterministic_sampling"] = True
if config.init_with_random_weights:
args["load_format"] = "dummy"
+ args["disable_radix_cache"] = config.disable_radix_cache
+ args["enable_deterministic_sampling"] = config.enable_deterministic_sampling
args["use_sort_for_toppk_minp"] = config.use_sort_for_toppk_minp
- args["precompile_bs_paddings"] = config.precompile_bs_paddings
- args["precompile_token_paddings"] = config.precompile_token_paddings
+ args["enable_static_lora"] = config.enable_static_lora
+ args["enable_single_process"] = config.enable_single_process
+
+ if config.enable_static_lora:
+ assert (
+ config.lora_target_modules is not None
+ and config.max_lora_rank is not None
+ and config.lora_scaling is not None
+ )
+ # check whether the lora_target_modules are valid
+ if config.lora_target_modules == ["all"]:
+ config.lora_target_modules = SUPPORTED_LORA_TARGET_MODULES
+ else:
+ for module in config.lora_target_modules:
+ if module not in SUPPORTED_LORA_TARGET_MODULES:
+ raise ValueError(
+ f"{module} in lora_target_modules does not exist in"
+ f" {SUPPORTED_LORA_TARGET_MODULES}"
+ )
+ args["lora_target_modules"] = config.lora_target_modules
+ args["max_lora_rank"] = config.max_lora_rank
+ args["max_loras_per_batch"] = 1
+ args["lora_scaling"] = config.lora_scaling
+
+ if config.precompile_token_paddings is not None:
+ assert isinstance(config.precompile_token_paddings, List)
+ args["precompile_token_paddings"] = config.precompile_token_paddings
+ if config.precompile_bs_paddings is not None:
+ assert isinstance(config.precompile_bs_paddings, List)
+ args["precompile_bs_paddings"] = config.precompile_bs_paddings
args["chunked_prefill_size"] = config.chunked_prefill_size
+
+    # Default arguments derived from known configuration or left for tuning.
args["page_size"] = config.page_size
+ args["tp_size"] = self._find_tp_size(config.mesh)
+ args["device_indexes"] = config.mesh.device_ids.flatten().tolist()
+
return args
@property
@@ -256,3 +384,29 @@ def __call__(
padded_prompt_tokens=all_input_ids,
logprobs=None,
)
+
+
+def update_hf_key_mappings_with_lora(
+ mappings: Optional[Dict[str, Any]] = None,
+ enable_static_lora: bool = False,
+ lora_target_modules: Optional[List] = None,
+):
+ if (
+ mappings is None
+ or not enable_static_lora
+ or lora_target_modules is None
+ or len(lora_target_modules) == 0
+ ):
+ return mappings
+
+  # Note: SGLang-JAX implements LoRA by wrapping the base layer, so the target
+  # paths in the mappings need to be rewritten, e.g. from
+  # 'model.layers.*.mlp.gate_proj.weight' to
+  # 'model.layers.*.mlp.gate_proj.base_layer.weight'.
+ for module in lora_target_modules:
+ for src_path, tgt_params in mappings.items():
+ if module in src_path:
+ tgt_path, sharding = tgt_params
+ keys = tgt_path.split(".")
+ new_tgt_path = ".".join(keys[:-1]) + ".base_layer." + keys[-1]
+ mappings[src_path] = (new_tgt_path, sharding)
+ break
+ return mappings
diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py
index 7e4bfde80..7669e631a 100644
--- a/tunix/generate/utils.py
+++ b/tunix/generate/utils.py
@@ -464,6 +464,7 @@ def _apply_transpose(
val: jnp.ndarray,
src_key: str,
transpose_keys: Optional[Dict[str, Tuple[int, ...]]],
+ rollout_engine: Optional[str],
) -> jnp.ndarray:
"""Apply transpose operation if configured for this key."""
if not transpose_keys:
@@ -479,6 +480,16 @@ def _apply_transpose(
if target_key != '':
logging.debug('Applying transpose on %s', src_key)
return jnp.transpose(val, transpose_keys[target_key])
+
+  # For LoRA
+  # Note: The following code only takes effect for the SGLang-JAX rollout and
+  # may not apply to other rollout engines.
+
+ if rollout_engine == 'sglang_jax' and 'lora' in all_key:
+ for r_key in transpose_keys:
+ if re.compile(rf'{r_key}').match(all_key):
+ logging.debug('Applying LoRA transpose on %s', src_key)
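+        # Prepend a leading singleton axis so the value matches the
+        # (1, ...) layout of SGLang-JAX's static-LoRA buffers before
+        # applying the configured permutation.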
+ return jnp.transpose(val[None, :, :], transpose_keys[r_key])
+
return val
@@ -603,7 +614,6 @@ def _align_shape(
def _apply_dtype_cast(
val: jnp.ndarray, tgt_dtype: jnp.dtype, src_key: str
) -> jnp.ndarray:
-
if val.dtype != tgt_dtype:
logging.warning(
'Type mismatch on %s: %s -> %s',
@@ -622,6 +632,7 @@ def transfer_state_with_mappings(
key_mapping_hook_fns=None,
transpose_keys=None,
reshard_fn=None,
+ rollout_engine=None,
):
"""Transfer state using mappings, with optional transpose and shard logic.
@@ -643,8 +654,10 @@ def transfer_state_with_mappings(
"""
# Get flat target state
tgt_flat_list = dst_state.flat_state()
+
# Build sharding dictionary if resharding is needed
sharding_dict = None
+
if reshard_fn:
sharding_dict = {
key: (
@@ -667,7 +680,7 @@ def transfer_state_with_mappings(
tgt_param,
) in unscanned_src_to_tgt_flat.items():
# Apply transpose if configured
- val = _apply_transpose(val, flat_src_key, transpose_keys)
+ val = _apply_transpose(val, flat_src_key, transpose_keys, rollout_engine)
# Apply optional hook function
if key_mapping_hook_fns and flat_src_key in key_mapping_hook_fns:
@@ -725,6 +738,7 @@ def transfer_state_directly(
dst_state: The destination state to transfer to.
reshard_fn: A function to shard the values.
"""
+
def safe_has_key(obj: Mapping[str, Any], key: str) -> bool:
if isinstance(obj, dict):
return key in obj
@@ -732,12 +746,16 @@ def safe_has_key(obj: Mapping[str, Any], key: str) -> bool:
return hasattr(obj, key)
# Unwrap Source (Remove 'base' wrapper from MaxText)
- if isinstance(src_state, (dict, nnx.State, nnx.Dict)) and safe_has_key(src_state, 'base'):
+ if isinstance(src_state, (dict, nnx.State, nnx.Dict)) and safe_has_key(
+ src_state, 'base'
+ ):
logging.info("Unwrapping 'base' key from source state.")
src_state = src_state['base']
# Unwrap Target (Remove nested 'model' wrappers from vLLM)
- while isinstance(dst_state, (dict, nnx.State, nnx.Dict)) and safe_has_key(dst_state, 'model'):
+ while isinstance(dst_state, (dict, nnx.State, nnx.Dict)) and safe_has_key(
+ dst_state, 'model'
+ ):
logging.info("Unwrapping nested 'model' key from target state.")
dst_state = dst_state['model']
@@ -763,7 +781,9 @@ def to_pure_spec(node: Any) -> Any:
return node
# Helper: Intersect Trees (Handle KVCache/RNG mismatches)
- def intersect_trees(src: Any, tgt_spec: Any, path: str = "") -> Tuple[Any, Any]:
+ def intersect_trees(
+ src: Any, tgt_spec: Any, path: str = ''
+ ) -> Tuple[Any, Any]:
# Stop recursion if we hit a leaf (non-dict)
if not isinstance(src, dict) or not isinstance(tgt_spec, dict):
return src, tgt_spec
@@ -786,7 +806,7 @@ def intersect_trees(src: Any, tgt_spec: Any, path: str = "") -> Tuple[Any, Any]:
filtered_tgt = {}
for k in common_keys:
- new_path = f"{path}/{k}" if path else k
+ new_path = f'{path}/{k}' if path else k
s_val, t_val = intersect_trees(src[k], tgt_spec[k], new_path)
filtered_src[k] = s_val
filtered_tgt[k] = t_val
diff --git a/tunix/models/llama3/mapping_sglang_jax.py b/tunix/models/llama3/mapping_sglang_jax.py
index fe22f764c..3dd6e47bd 100644
--- a/tunix/models/llama3/mapping_sglang_jax.py
+++ b/tunix/models/llama3/mapping_sglang_jax.py
@@ -5,16 +5,18 @@
import os
from typing import Any, Dict, Tuple
+from tunix.utils.env_utils import SGLANG_JAX_TP_AXIS_NAME
+
Sharding = Tuple[str | None, ...]
MappingEntry = Tuple[str, Sharding]
def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
return {
- 'lm_head.w': ('lm_head.embedding', (None, 'model')),
+ 'lm_head.w': ('lm_head.embedding', (None, SGLANG_JAX_TP_AXIS_NAME)),
'embedder.input_embedding': (
'model.embed_tokens.embedding',
- ('model', None),
+ (SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.input_layernorm.w': (
'model.layers.*.input_layernorm.scale',
@@ -22,15 +24,15 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
),
'layers.*.mlp.down_proj.kernel': (
'model.layers.*.mlp.down_proj.weight',
- ('model', None),
+ (SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.mlp.gate_proj.kernel': (
'model.layers.*.mlp.gate_proj.weight',
- (None, 'model'),
+ (None, SGLANG_JAX_TP_AXIS_NAME),
),
'layers.*.mlp.up_proj.kernel': (
'model.layers.*.mlp.up_proj.weight',
- (None, 'model'),
+ (None, SGLANG_JAX_TP_AXIS_NAME),
),
'layers.*.post_attention_layernorm.w': (
'model.layers.*.post_attention_layernorm.scale',
@@ -38,19 +40,19 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
),
'layers.*.attn.k_proj.w': (
'model.layers.*.self_attn.k_proj.weight',
- (None, 'model', None),
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.attn.o_proj.w': (
'model.layers.*.self_attn.o_proj.weight',
- ('model', None, None),
+ (SGLANG_JAX_TP_AXIS_NAME, None, None),
),
'layers.*.attn.q_proj.w': (
'model.layers.*.self_attn.q_proj.weight',
- (None, 'model', None),
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.attn.v_proj.w': (
'model.layers.*.self_attn.v_proj.weight',
- (None, 'model', None),
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
),
'final_norm.w': ('model.norm.scale', (None,)),
}
@@ -58,7 +60,64 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
def _lora_to_sglang_jax_mappings() -> Dict[str, MappingEntry] | None:
"""The lora parameter key mapping between Tunix vanilla model and Sglang-jax Jax backend"""
- return None
+ return {
+ 'layers.*.mlp.gate_proj.kernel_lora_a': (
+ 'model.layers.*.mlp.gate_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.mlp.gate_proj.kernel_lora_b': (
+ 'model.layers.*.mlp.gate_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.mlp.up_proj.kernel_lora_a': (
+ 'model.layers.*.mlp.up_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.mlp.up_proj.kernel_lora_b': (
+ 'model.layers.*.mlp.up_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.mlp.down_proj.kernel_lora_a': (
+ 'model.layers.*.mlp.down_proj.A_buffer',
+ (None, None, SGLANG_JAX_TP_AXIS_NAME),
+ ),
+ 'layers.*.mlp.down_proj.kernel_lora_b': (
+ 'model.layers.*.mlp.down_proj.B_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.attn.q_proj.w_lora_a': (
+ 'model.layers.*.self_attn.q_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.attn.q_proj.w_lora_b': (
+ 'model.layers.*.self_attn.q_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.attn.k_proj.w_lora_a': (
+ 'model.layers.*.self_attn.k_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.attn.k_proj.w_lora_b': (
+ 'model.layers.*.self_attn.k_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.attn.v_proj.w_lora_a': (
+ 'model.layers.*.self_attn.v_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.attn.v_proj.w_lora_b': (
+ 'model.layers.*.self_attn.v_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.attn.o_proj.w_lora_a': (
+ 'model.layers.*.self_attn.o_proj.A_buffer',
+ (None, None, SGLANG_JAX_TP_AXIS_NAME),
+ ),
+ 'layers.*.attn.o_proj.w_lora_b': (
+ 'model.layers.*.self_attn.o_proj.B_buffer',
+ (None, None, None),
+ ),
+ }
def _to_sglang_jax_transpose_keys():
@@ -67,6 +126,42 @@ def _to_sglang_jax_transpose_keys():
}
+def _to_sglang_jax_lora_transpose_keys():
+ """
+ Tunix -> SGLangJax:
+ gate_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ gate_lora_b: (max_lora_rank, intermediate_size) -> (1, intermediate_size, max_lora_rank)
+ up_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ up_lora_b: (max_lora_rank, intermediate_size) -> (1, intermediate_size, max_lora_rank)
+ down_lora_a: (intermediate_size, max_lora_rank) -> (1, max_lora_rank, intermediate_size)
+ down_lora_b: (max_lora_rank, hidden_size) -> (1, hidden_size, max_lora_rank)
+ q_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ q_lora_b: (max_lora_rank, num_attention_heads, head_dim) -> (1, hidden_size, max_lora_rank)
+ k_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ k_lora_b: (max_lora_rank, num_key_value_heads, head_dim) -> (1, num_key_value_heads*head_dim, max_lora_rank)
+ v_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ v_lora_b: (max_lora_rank, num_key_value_heads, head_dim) -> (1, num_key_value_heads*head_dim, max_lora_rank)
+ o_lora_a: (num_attention_heads, head_dim, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ o_lora_b: (max_lora_rank, hidden_size) -> (1, hidden_size, max_lora_rank)
+ """
+ return {
+ 'layers.*.mlp.gate_proj.kernel_lora_a': (0, 2, 1),
+ 'layers.*.mlp.gate_proj.kernel_lora_b': (0, 2, 1),
+ 'layers.*.mlp.up_proj.kernel_lora_a': (0, 2, 1),
+ 'layers.*.mlp.up_proj.kernel_lora_b': (0, 2, 1),
+ 'layers.*.mlp.down_proj.kernel_lora_a': (0, 2, 1),
+ 'layers.*.mlp.down_proj.kernel_lora_b': (0, 2, 1),
+ 'layers.*.attn.q_proj.w_lora_a': (0, 2, 1),
+ 'layers.*.attn.q_proj.w_lora_b': (0, 2, 3, 1),
+ 'layers.*.attn.k_proj.w_lora_a': (0, 2, 1),
+ 'layers.*.attn.k_proj.w_lora_b': (0, 2, 3, 1),
+ 'layers.*.attn.v_proj.w_lora_a': (0, 2, 1),
+ 'layers.*.attn.v_proj.w_lora_b': (0, 2, 3, 1),
+ 'layers.*.attn.o_proj.w_lora_a': (0, 3, 2, 1),
+ 'layers.*.attn.o_proj.w_lora_b': (0, 2, 1),
+ }
+
+
def _to_sglang_jax_hook_fns() -> Dict[str, Any] | None:
"""Additional parameter manipulation hook between Tunix vanilla model and Sglang Jax backend"""
return None
@@ -76,6 +171,7 @@ def _to_sglang_jax_hook_fns() -> Dict[str, Any] | None:
'to_hf_mappings': _to_sglang_jax_mappings(),
'lora_to_hf_mappings': _lora_to_sglang_jax_mappings(),
'to_hf_transpose_keys': _to_sglang_jax_transpose_keys(),
+ 'lora_to_hf_transpose_keys': _to_sglang_jax_lora_transpose_keys(),
'to_hf_hook_fns': _to_sglang_jax_hook_fns(),
}
diff --git a/tunix/models/qwen2/mapping_sglang_jax.py b/tunix/models/qwen2/mapping_sglang_jax.py
index 5b2ee26e5..71f96441b 100644
--- a/tunix/models/qwen2/mapping_sglang_jax.py
+++ b/tunix/models/qwen2/mapping_sglang_jax.py
@@ -5,16 +5,18 @@
import os
from typing import Any, Dict, Tuple
+from tunix.utils.env_utils import SGLANG_JAX_TP_AXIS_NAME
+
Sharding = Tuple[str | None, ...]
MappingEntry = Tuple[str, Sharding]
def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
return {
- 'lm_head.w': ('lm_head.embedding', (None, 'model')),
+ 'lm_head.w': ('lm_head.embedding', (None, SGLANG_JAX_TP_AXIS_NAME)),
'embedder.input_embedding': (
'model.embed_tokens.embedding',
- ('model', None),
+ (SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.input_layernorm.w': (
'model.layers.*.input_layernorm.scale',
@@ -22,15 +24,15 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
),
'layers.*.mlp.down_proj.kernel': (
'model.layers.*.mlp.down_proj.weight',
- ('model', None),
+ (SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.mlp.gate_proj.kernel': (
'model.layers.*.mlp.gate_proj.weight',
- (None, 'model'),
+ (None, SGLANG_JAX_TP_AXIS_NAME),
),
'layers.*.mlp.up_proj.kernel': (
'model.layers.*.mlp.up_proj.weight',
- (None, 'model'),
+ (None, SGLANG_JAX_TP_AXIS_NAME),
),
'layers.*.post_attention_layernorm.w': (
'model.layers.*.post_attention_layernorm.scale',
@@ -38,7 +40,7 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
),
'layers.*.attn.k_proj.w': (
'model.layers.*.self_attn.k_proj.weight',
- (None, 'model', None),
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.attn.k_bias': (
'model.layers.*.self_attn.k_proj.bias',
@@ -46,11 +48,11 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
),
'layers.*.attn.o_proj.w': (
'model.layers.*.self_attn.o_proj.weight',
- ('model', None, None),
+ (SGLANG_JAX_TP_AXIS_NAME, None, None),
),
'layers.*.attn.q_proj.w': (
'model.layers.*.self_attn.q_proj.weight',
- (None, 'model', None),
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.attn.q_bias': (
'model.layers.*.self_attn.q_proj.bias',
@@ -58,7 +60,7 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
),
'layers.*.attn.v_proj.w': (
'model.layers.*.self_attn.v_proj.weight',
- (None, 'model', None),
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
),
'layers.*.attn.v_bias': (
'model.layers.*.self_attn.v_proj.bias',
@@ -70,7 +72,64 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:
def _lora_to_sglang_jax_mappings() -> Dict[str, MappingEntry] | None:
"""The lora parameter key mapping between Tunix vanilla model and Sglang-jax Jax backend"""
- return None
+ return {
+ 'layers.*.mlp.gate_proj.kernel_lora_a': (
+ 'model.layers.*.mlp.gate_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.mlp.gate_proj.kernel_lora_b': (
+ 'model.layers.*.mlp.gate_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.mlp.up_proj.kernel_lora_a': (
+ 'model.layers.*.mlp.up_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.mlp.up_proj.kernel_lora_b': (
+ 'model.layers.*.mlp.up_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.mlp.down_proj.kernel_lora_a': (
+ 'model.layers.*.mlp.down_proj.A_buffer',
+ (None, None, SGLANG_JAX_TP_AXIS_NAME),
+ ),
+ 'layers.*.mlp.down_proj.kernel_lora_b': (
+ 'model.layers.*.mlp.down_proj.B_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.attn.q_proj.w_lora_a': (
+ 'model.layers.*.self_attn.q_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.attn.q_proj.w_lora_b': (
+ 'model.layers.*.self_attn.q_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.attn.k_proj.w_lora_a': (
+ 'model.layers.*.self_attn.k_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.attn.k_proj.w_lora_b': (
+ 'model.layers.*.self_attn.k_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.attn.v_proj.w_lora_a': (
+ 'model.layers.*.self_attn.v_proj.A_buffer',
+ (None, None, None),
+ ),
+ 'layers.*.attn.v_proj.w_lora_b': (
+ 'model.layers.*.self_attn.v_proj.B_buffer',
+ (None, SGLANG_JAX_TP_AXIS_NAME, None),
+ ),
+ 'layers.*.attn.o_proj.w_lora_a': (
+ 'model.layers.*.self_attn.o_proj.A_buffer',
+ (None, None, SGLANG_JAX_TP_AXIS_NAME),
+ ),
+ 'layers.*.attn.o_proj.w_lora_b': (
+ 'model.layers.*.self_attn.o_proj.B_buffer',
+ (None, None, None),
+ ),
+ }
def _to_sglang_jax_transpose_keys():
@@ -79,6 +138,42 @@ def _to_sglang_jax_transpose_keys():
}
+def _to_sglang_jax_lora_transpose_keys():
+ """
+ Tunix -> SGLangJax:
+ gate_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ gate_lora_b: (max_lora_rank, intermediate_size) -> (1, intermediate_size, max_lora_rank)
+ up_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ up_lora_b: (max_lora_rank, intermediate_size) -> (1, intermediate_size, max_lora_rank)
+ down_lora_a: (intermediate_size, max_lora_rank) -> (1, max_lora_rank, intermediate_size)
+ down_lora_b: (max_lora_rank, hidden_size) -> (1, hidden_size, max_lora_rank)
+ q_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ q_lora_b: (max_lora_rank, num_attention_heads, head_dim) -> (1, hidden_size, max_lora_rank)
+ k_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ k_lora_b: (max_lora_rank, num_key_value_heads, head_dim) -> (1, num_key_value_heads*head_dim, max_lora_rank)
+ v_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ v_lora_b: (max_lora_rank, num_key_value_heads, head_dim) -> (1, num_key_value_heads*head_dim, max_lora_rank)
+ o_lora_a: (num_attention_heads, head_dim, max_lora_rank) -> (1, max_lora_rank, hidden_size)
+ o_lora_b: (max_lora_rank, hidden_size) -> (1, hidden_size, max_lora_rank)
+ """
+ return {
+ 'layers.*.mlp.gate_proj.kernel_lora_a': (0, 2, 1),
+ 'layers.*.mlp.gate_proj.kernel_lora_b': (0, 2, 1),
+ 'layers.*.mlp.up_proj.kernel_lora_a': (0, 2, 1),
+ 'layers.*.mlp.up_proj.kernel_lora_b': (0, 2, 1),
+ 'layers.*.mlp.down_proj.kernel_lora_a': (0, 2, 1),
+ 'layers.*.mlp.down_proj.kernel_lora_b': (0, 2, 1),
+ 'layers.*.attn.q_proj.w_lora_a': (0, 2, 1),
+ 'layers.*.attn.q_proj.w_lora_b': (0, 2, 3, 1),
+ 'layers.*.attn.k_proj.w_lora_a': (0, 2, 1),
+ 'layers.*.attn.k_proj.w_lora_b': (0, 2, 3, 1),
+ 'layers.*.attn.v_proj.w_lora_a': (0, 2, 1),
+ 'layers.*.attn.v_proj.w_lora_b': (0, 2, 3, 1),
+ 'layers.*.attn.o_proj.w_lora_a': (0, 3, 2, 1),
+ 'layers.*.attn.o_proj.w_lora_b': (0, 2, 1),
+ }
+
+
def _to_sglang_jax_hook_fns() -> Dict[str, Any] | None:
"""Additional parameter manipulation hook between Tunix vanilla model and Sglang Jax backend"""
return None
@@ -88,6 +183,7 @@ def _to_sglang_jax_hook_fns() -> Dict[str, Any] | None:
'to_hf_mappings': _to_sglang_jax_mappings(),
'lora_to_hf_mappings': _lora_to_sglang_jax_mappings(),
'to_hf_transpose_keys': _to_sglang_jax_transpose_keys(),
+ 'lora_to_hf_transpose_keys': _to_sglang_jax_lora_transpose_keys(),
'to_hf_hook_fns': _to_sglang_jax_hook_fns(),
}
diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py
index d1371e20d..b911a89cd 100644
--- a/tunix/rl/rl_learner.py
+++ b/tunix/rl/rl_learner.py
@@ -20,8 +20,7 @@
from concurrent import futures
import itertools
import math
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Sequence
-from typing import Generic, TypeVar
+from typing import Any, Callable, Dict, Generic, Iterable, Iterator, List, Sequence, TypeVar
from absl import logging
import jax
diff --git a/tunix/rl/rollout/base_rollout.py b/tunix/rl/rollout/base_rollout.py
index 5e9c9d0c2..7f095c9a0 100644
--- a/tunix/rl/rollout/base_rollout.py
+++ b/tunix/rl/rollout/base_rollout.py
@@ -162,15 +162,31 @@ class RolloutConfig:
# Whether to enable deterministic sampling for SG-Lang JAX rollout engine.
rollout_sglang_jax_enable_deterministic_sampling: bool = False
- # List of token buckets for jax jit
- rollout_sglang_jax_precompile_token_paddings: Optional[List] = None
+  # Whether to use the sort or the mask implementation in the sampler;
+  # sort gives better evaluation results.
+ rollout_sglang_jax_use_sort_for_toppk_minp: bool = True
- # List of batch sizes buckets for jax jit
- rollout_sglang_jax_precompile_bs_paddings: Optional[List] = None
+  # Whether to enable static LoRA in the SGLang-JAX rollout engine.
+ rollout_sglang_jax_enable_static_lora: bool = False
- # The maximum number of tokens in a chunk for the chunked prefill.
- # Setting this to -1 means disabling chunked prefill.
- rollout_sglang_jax_chunked_prefill_size: int = -1
+  # Whether to run in single-process mode; this is required on Pathways.
+ rollout_sglang_jax_enable_single_process: bool = True
+
+  # The modules to which LoRA is applied.
+ rollout_sglang_jax_lora_target_modules: Optional[List[str]] = None
+
+  # The maximum LoRA rank.
+ rollout_sglang_jax_max_lora_rank: Optional[int] = None
+
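+  # The LoRA scaling factor applied in the rollout engine (typically alpha / rank).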
+ rollout_sglang_jax_lora_scaling: Optional[float] = None
+
+  # Batch-size padding buckets used for JAX jit precompilation.
+ rollout_sglang_jax_precompile_bs_paddings: Optional[List[int]] = None
+
+  # Token padding buckets used for prefill jit precompilation.
+ rollout_sglang_jax_precompile_token_paddings: Optional[List[int]] = None
+
+  # The maximum number of tokens in a chunk for chunked prefill.
+  # Setting this to -1 disables chunked prefill.
+ rollout_sglang_jax_chunked_prefill_size: Optional[int] = -1
# The number of tokens in a page
rollout_sglang_jax_page_size: int = 64
diff --git a/tunix/rl/rollout/sglang_jax_rollout.py b/tunix/rl/rollout/sglang_jax_rollout.py
index 8850e6466..26471921b 100644
--- a/tunix/rl/rollout/sglang_jax_rollout.py
+++ b/tunix/rl/rollout/sglang_jax_rollout.py
@@ -45,13 +45,19 @@ def __init__(
tokenizer=tokenizer,
config=sglang_jax_sampler.SglangJaxConfig(
mesh=mesh,
- context_length=rollout_config.rollout_sglang_jax_context_length,
+ mapping_config=mapping_config,
model_version=rollout_config.rollout_sglang_jax_model_version,
+ context_length=rollout_config.rollout_sglang_jax_context_length,
mem_fraction_static=rollout_config.rollout_sglang_jax_mem_fraction_static,
init_with_random_weights=rollout_config.rollout_sglang_jax_init_with_random_weights,
disable_radix_cache=rollout_config.rollout_sglang_jax_disable_radix_cache,
enable_deterministic_sampling=rollout_config.rollout_sglang_jax_enable_deterministic_sampling,
- mapping_config=mapping_config,
+ use_sort_for_toppk_minp=rollout_config.rollout_sglang_jax_use_sort_for_toppk_minp,
+ enable_static_lora=rollout_config.rollout_sglang_jax_enable_static_lora,
+ enable_single_process=rollout_config.rollout_sglang_jax_enable_single_process,
+ lora_target_modules=rollout_config.rollout_sglang_jax_lora_target_modules,
+ max_lora_rank=rollout_config.rollout_sglang_jax_max_lora_rank,
+ lora_scaling=rollout_config.rollout_sglang_jax_lora_scaling,
precompile_bs_paddings=rollout_config.rollout_sglang_jax_precompile_bs_paddings,
precompile_token_paddings=rollout_config.rollout_sglang_jax_precompile_token_paddings,
chunked_prefill_size=rollout_config.rollout_sglang_jax_chunked_prefill_size,
@@ -119,3 +125,6 @@ def eos_id(self) -> int:
def model(self) -> nnx.Module:
return self._sampler.transformer
+
+ def model_state(self):
+ return self._sampler.transformer_state
diff --git a/tunix/rl/utils.py b/tunix/rl/utils.py
index 1888368f7..f2e6fa1c2 100644
--- a/tunix/rl/utils.py
+++ b/tunix/rl/utils.py
@@ -258,3 +258,6 @@ def get_partition_spec(
return sharding.spec
else:
return jax.sharding.PartitionSpec()
+
+
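+# Environment variable holding a comma-separated "src_module,tgt_module" pair;
+# when set, the SGLang-JAX sampler verifies that parameter after update_params.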
+VERIFY_UPDATE_PARAMS_KEY = "VERIFY_UPDATE_PARAMS_SRC_TO_TGT_MODULE_NAME"
diff --git a/tunix/utils/env_utils.py b/tunix/utils/env_utils.py
index 1a058a02e..b49f05af4 100644
--- a/tunix/utils/env_utils.py
+++ b/tunix/utils/env_utils.py
@@ -14,6 +14,8 @@
"""Environment utils."""
+import os
+
import flax
@@ -31,3 +33,6 @@ def is_internal_env():
return True
except ImportError:
return False
+
+
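+# Mesh axis name used for tensor-parallel sharding in the SGLang-JAX weight mappings.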
+SGLANG_JAX_TP_AXIS_NAME = os.getenv('SGLANG_JAX_TP_AXIS_NAME', 'tensor')