diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml
index 9c753a758..4b84aec7f 100644
--- a/.github/workflows/tpu-tests.yml
+++ b/.github/workflows/tpu-tests.yml
@@ -98,7 +98,8 @@ jobs:
           --ignore=tests/generate/vllm_sampler_test.py \
           --ignore=tests/generate/vllm_driver_test.py \
           --ignore=tests/generate/tokenizer_adapter_test.py \
-          --ignore=tests/generate/sglang_jax_sampler_test.py
+          --ignore=tests/generate/sglang_jax_sampler_test.py \
+          --ignore=tests/generate/sglang_jax_lora_test.py

     - name: Run tunix SFT tests
       run: |
@@ -208,7 +209,7 @@ jobs:
         EOF
         apt-get update; apt-get install -y less
-        cd tunix && python tests/generate/sglang_jax_sampler_test.py
+        cd tunix && python tests/generate/sglang_jax_sampler_test.py && python tests/generate/sglang_jax_lora_test.py

     - name: Run tunix SFT integration tests
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
diff --git a/scripts/grpo_demo_llama3_qwen2.py b/scripts/grpo_demo_llama3_qwen2.py
index fc9acbd02..e1cb62b23 100644
--- a/scripts/grpo_demo_llama3_qwen2.py
+++ b/scripts/grpo_demo_llama3_qwen2.py
@@ -42,6 +42,7 @@
 import transformers
 from tunix.cli.utils import data as data_lib
 from tunix.examples.data import math_dataset
+from tunix.generate import mappings
 from tunix.models.llama3 import model as llama_lib
 from tunix.models.llama3 import params as llama_params
 from tunix.models.qwen2 import model as qwen2_lib
@@ -244,6 +245,34 @@
     required=False,
     help="Name of dataset, required when data_source is tfds",
 )
+parser.add_argument(
+    "--enable-lora",
+    action="store_true",
+    default=False,
+    required=False,
+    help="Enable LoRA.",
+)
+parser.add_argument(
+    "--lora-rank",
+    type=int,
+    default=64,
+    required=False,
+    help="Rank of LoRA.",
+)
+parser.add_argument(
+    "--lora-alpha",
+    type=float,
+    default=64.0,
+    required=False,
+    help="Alpha of LoRA.",
+)
+parser.add_argument(
+    "--lora-target-modules",
+    nargs="+",
+    type=str,
+    default=None,
+    help="List of target modules to apply LoRA to.",
+)

 # Parse arguments

@@ -292,9 +321,14 @@ def validata_args():
 SEED = 42

 # ====== LoRA ======
-ENABLE_LORA = False
-RANK = 64
-ALPHA = 64.0
+ENABLE_LORA = args.enable_lora
+RANK = args.lora_rank
+ALPHA = args.lora_alpha
+LORA_TARGET_MODULES = args.lora_target_modules
+if ENABLE_LORA and LORA_TARGET_MODULES is None:
+  raise ValueError(
+      "--lora-target-modules cannot be None when LoRA is enabled!"
+ ) # ====== Sharding ====== if "Qwen2.5-0.5B-Instruct" in args.model_version: @@ -684,7 +718,7 @@ def get_lora_model(base_model, model_mesh=None): # Policy model # TODO(b/434959964): Supports lora in vLLM Jax backend if ENABLE_LORA: - training_model = get_lora_model(ref_model, model_mesh=mesh) + training_model = get_lora_model(ref_model) training_mesh = ref_mesh else: training_model, training_mesh, _ = get_model( @@ -1040,13 +1074,16 @@ def evaluate( def get_rollout_config(engine: str) -> base_rollout.RolloutConfig: if engine == "sglang_jax": - return base_rollout.RolloutConfig( + config = base_rollout.RolloutConfig( max_tokens_to_generate=TOTAL_GENERATION_STEPS, max_prompt_length=MAX_PROMPT_LENGTH, kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256, temperature=TEMPERATURE, top_p=TOP_P, top_k=TOP_K, + rollout_mapping_config=mappings.MappingConfig.build( + model=ref_model, backend="sglang_jax" + ), rollout_sglang_jax_model_version=SGLANGJAX_MODEL_VERSION, rollout_sglang_jax_mem_fraction_static=0.2, rollout_sglang_jax_init_with_random_weights=True, @@ -1057,6 +1094,13 @@ def get_rollout_config(engine: str) -> base_rollout.RolloutConfig: rollout_sglang_jax_chunked_prefill_size=2048, rollout_sglang_jax_page_size=64, ) + if ENABLE_LORA: + config.rollout_sglang_jax_enable_static_lora = True + config.rollout_sglang_jax_lora_target_modules = LORA_TARGET_MODULES + config.rollout_sglang_jax_max_lora_rank = RANK + config.rollout_sglang_jax_lora_scaling = ALPHA / RANK + + return config return base_rollout.RolloutConfig( max_tokens_to_generate=TOTAL_GENERATION_STEPS, diff --git a/tests/generate/sglang_jax_lora_test.py b/tests/generate/sglang_jax_lora_test.py new file mode 100644 index 000000000..516399894 --- /dev/null +++ b/tests/generate/sglang_jax_lora_test.py @@ -0,0 +1,371 @@ +""" +Note: This test is based on scripts/grpo_demo_sglang_jax_rollout.py. +For the meanings of constants, please refer to the above file. 
+""" + +import csv +import os +from pathlib import Path +import re +import shutil +from absl.testing import absltest + +import grain +import huggingface_hub +import jax +import kagglehub +import optax +from orbax import checkpoint as ocp +import qwix +import transformers +from tunix.generate import mappings +from tunix.models.llama3 import model as llama_lib +from tunix.models.llama3 import params as llama3_params_lib +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.grpo.grpo_learner import GRPOConfig +from tunix.rl.grpo.grpo_learner import GRPOLearner +from tunix.rl.rollout import base_rollout +from tunix.rl.utils import VERIFY_UPDATE_PARAMS_KEY +from tunix.sft import metrics_logger + + +############################################# CONSTANTS ########################################### + +TRAIN_DATA_DIR = "./data/train" +TRAIN_FRACTION = 1.0 +RANK = 64 +ALPHA = 64.0 +TOTAL_TPU_TO_USE = jax.device_count() +MESH = [ + ( + 1, + TOTAL_TPU_TO_USE, + ), + ("fsdp", "tp"), +] +MAX_PROMPT_LENGTH = 256 +TOTAL_GENERATION_STEPS = 1024 +TEMPERATURE = 0.9 +TOP_P = 1.0 +TOP_K = 50 +NUM_GENERATIONS = 2 +NUM_ITERATIONS = 1 +BETA = 0.08 +EPSILON = 0.2 +TRAIN_MICRO_BATCH_SIZE = 1 +NUM_BATCHES = 2 +NUM_TEST_BATCHES = 2 +EVAL_EVERY_N_STEPS = 5 +NUM_EPOCHS = 1 +MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS) +LEARNING_RATE = 3e-6 +B1 = 0.9 +B2 = 0.99 +WEIGHT_DECAY = 0.1 +WARMUP_STEPS = 0.1 * MAX_STEPS +MAX_GRAD_NORM = 0.1 +INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/" +CKPT_DIR = "/tmp/content/ckpts/" +SAVE_INTERVAL_STEPS = 500 +MAX_TO_KEEP = 4 + +############################################# PROMPTS ########################################### + +reasoning_start = "" +reasoning_end = "" +solution_start = "" +solution_end = "" + + +SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \ +provide your reasoning. Place it between {reasoning_start} and \ +{reasoning_end}. 
Then, provide the final answer (i.e., just one numerical \ +value) between {solution_start} and {solution_end}.""" + +TEMPLATE = """user +{system_prompt} + +{question} +model""" + + +def extract_hash_answer(text: str) -> str | None: + if "####" not in text: + return None + return text.split("####")[1].strip() + + +def download_kaggle_dataset(target_dir="./data/gsm8k"): + os.makedirs(target_dir, exist_ok=True) + src = kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a") + src = Path(src) + dst = Path(target_dir) + + for csv_file in src.glob("*.csv"): # match all CSV files + shutil.copy2(csv_file, dst / csv_file.name) + print(f"Copied {csv_file.name} → {dst/csv_file.name}") + return target_dir + + +def get_dataset(data_dir, split="train", source="tfds") -> grain.MapDataset: + # Download data + if not os.path.exists(data_dir): + os.makedirs(data_dir) + + kaggle_dir = download_kaggle_dataset(data_dir) + file_name = "main_" + split + ".csv" + csv_path = os.path.join(kaggle_dir, file_name) # adjust filename if needed + + data = [] + with open(csv_path, newline="", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + data.append({ + "question": row["question"], + "answer": row["answer"], + }) + + def _as_text(v): + return v if isinstance(v, str) else v.decode("utf-8") + + dataset = ( + grain.MapDataset.source(data) + .shuffle(seed=42) + .map( + lambda x: { + # passed to model forward pass + "prompts": model_tokenizer.apply_chat_template( + [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": x["question"]}, + ], + tokenize=False, + add_generation_prompt=True, + ), + # passed to reward functions + "question": _as_text(x["question"]), + # passed to reward functions + "answer": extract_hash_answer(_as_text(x["answer"])), + } + ) + ) + return dataset + + +def download_from_huggingface(repo_id: str, model_path: str): + """Download checkpoint files from huggingface.""" + print("Make sure you logged in to the huggingface cli.") + all_files = huggingface_hub.list_repo_files(repo_id) + filtered_files = [f for f in all_files if not f.startswith("original/")] + + for filename in filtered_files: + huggingface_hub.hf_hub_download( + repo_id=repo_id, filename=filename, local_dir=model_path + ) + print(f"Downloaded {filtered_files} to: {model_path}") + + +def load_model(): + model_config = llama_lib.ModelConfig.llama3p2_3b() + + mesh = jax.make_mesh( + *MESH, + devices=jax.devices()[:TOTAL_TPU_TO_USE], + axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]), + ) + model = llama3_params_lib.create_model_from_safe_tensors( + model_path, model_config, mesh + ) + return model, mesh, model_config + + +def get_rollout_mesh(): + mesh = jax.make_mesh( + *MESH, + devices=jax.devices()[-TOTAL_TPU_TO_USE:], + axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]), + ) + return mesh + + +def get_lora_model(base_model): + lora_provider = qwix.LoraProvider( + module_path=".*gate_proj", + rank=RANK, + alpha=ALPHA, + ) + + model_input = base_model.get_model_input() + lora_model = qwix.apply_lora_to_model( + base_model, lora_provider, **model_input + ) + + return lora_model + + +model_version = "meta-llama/Llama-3.2-3B-Instruct" + +repo_id = model_version +model_tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id) + +import tempfile + +temp_dir = tempfile.gettempdir() +model_path = os.path.join(temp_dir, "models", repo_id) + +match_format = re.compile( + rf"^[\s]{{0,}}" + rf"{reasoning_start}.+?{reasoning_end}.*?" 
+ rf"{solution_start}(.+?){solution_end}" + rf"[\s]{{0,}}$", + flags=re.MULTILINE | re.DOTALL, +) + + +def match_format_exactly(prompts, completions, **kwargs): + return [ + 0 if match_format.search(response) is None else 3.0 + for response in completions + ] + + +class SglangJaxLoRATest(absltest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + VERIFY_UPDATE_PARAMS_VAL = "layers.0.mlp.gate_proj.kernel_lora_a,model.layers.0.mlp.gate_proj.A_buffer" + os.environ[VERIFY_UPDATE_PARAMS_KEY] = VERIFY_UPDATE_PARAMS_VAL + + super().setUpClass() + + ## ====================================== Get Dataset =================================== + source = "kaggle" + cls.dataset = get_dataset(TRAIN_DATA_DIR, "train", source).batch( + TRAIN_MICRO_BATCH_SIZE + )[:NUM_BATCHES] + + ## ====================================== Get LoRA model ================================= + download_from_huggingface(repo_id=repo_id, model_path=model_path) + + ref_model, cls.mesh, _ = load_model() + rollout_mesh = get_rollout_mesh() + + lora_policy = get_lora_model(ref_model) + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + mapping_config = mappings.MappingConfig.build( + model=ref_model, backend="sglang_jax" + ) + + ## ========================== Iniitialize RL cluster and trainer ========================== + cluster_config, grpo_config = cls.prepare_configs( + cls.mesh, rollout_mesh, mapping_config + ) + + # RL cluster + rl_cluster = rl_cluster_lib.RLCluster( + actor=lora_policy, + reference=ref_model, + tokenizer=tokenizer, + cluster_config=cluster_config, + ) + + # GRPO Trainer + cls.grpo_trainer = GRPOLearner( + rl_cluster=rl_cluster, + reward_fns=[ + match_format_exactly, + ], + algo_config=grpo_config, + ) + + @classmethod + def prepare_configs(cls, mesh, rollout_mesh, mapping_config): + # Ckpt saving + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP + ) + + # Metrics logger + metrics_logging_options = metrics_logger.MetricsLoggerOptions( + log_dir="/tmp/content/tmp/tensorboard/grpo", flush_every_n_steps=20 + ) + + optimizer = optax.adamw( + learning_rate=optax.schedules.warmup_cosine_decay_schedule( + init_value=0.0, + peak_value=LEARNING_RATE, + warmup_steps=WARMUP_STEPS, + decay_steps=MAX_STEPS, + end_value=0.0, + ), + b1=B1, + b2=B2, + weight_decay=WEIGHT_DECAY, + ) + + if MAX_GRAD_NORM is not None: + optimizer = optax.chain( + optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM), + optimizer, + ) + + # Training config + cluster_config = rl_cluster_lib.ClusterConfig( + role_to_mesh={ + rl_cluster_lib.Role.ACTOR: mesh, + rl_cluster_lib.Role.REFERENCE: mesh, + rl_cluster_lib.Role.ROLLOUT: rollout_mesh, + }, + rollout_engine="sglang_jax", + offload_to_cpu=False, + training_config=rl_cluster_lib.RLTrainingConfig( + actor_optimizer=optimizer, + eval_every_n_steps=EVAL_EVERY_N_STEPS, + max_steps=MAX_STEPS, + mini_batch_size=TRAIN_MICRO_BATCH_SIZE, + train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE, + metrics_logging_options=metrics_logging_options, + checkpoint_root_directory=CKPT_DIR, + checkpointing_options=checkpointing_options, + ), + rollout_config=base_rollout.RolloutConfig( + max_tokens_to_generate=TOTAL_GENERATION_STEPS, + max_prompt_length=MAX_PROMPT_LENGTH, + kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256, + temperature=TEMPERATURE, + top_p=TOP_P, + top_k=TOP_K, + rollout_mapping_config=mapping_config, + rollout_sglang_jax_model_version=model_path, + rollout_sglang_jax_context_length=2048, + 
rollout_sglang_jax_mem_fraction_static=0.3, + rollout_sglang_jax_init_with_random_weights=True, + rollout_sglang_jax_disable_radix_cache=True, + rollout_sglang_jax_enable_deterministic_sampling=False, + rollout_sglang_jax_use_sort_for_toppk_minp=True, + rollout_sglang_jax_enable_static_lora=True, + rollout_sglang_jax_enable_single_process=True, + rollout_sglang_jax_lora_target_modules=["all"], + rollout_sglang_jax_max_lora_rank=RANK, + rollout_sglang_jax_lora_scaling=ALPHA / RANK, + rollout_sglang_jax_precompile_bs_paddings=[8], + rollout_sglang_jax_precompile_token_paddings=[2048], + ), + ) + + grpo_config = GRPOConfig( + num_generations=NUM_GENERATIONS, + num_iterations=NUM_ITERATIONS, + beta=BETA, + epsilon=EPSILON, + ) + + return (cluster_config, grpo_config) + + def test_lora(self): + with self.mesh: + self.grpo_trainer.train(self.dataset) + + +if __name__ == "__main__": + absltest.main() diff --git a/tunix/generate/mappings.py b/tunix/generate/mappings.py index e71887eaf..97e7999cc 100644 --- a/tunix/generate/mappings.py +++ b/tunix/generate/mappings.py @@ -47,6 +47,11 @@ def to_hf_transpose_keys(cls, backend: str | None = None): result = cls.mapping_for(backend).get('to_hf_transpose_keys') return result or None + @classmethod + def lora_to_hf_transpose_keys(cls, backend: str | None = None): + result = cls.mapping_for(backend).get('lora_to_hf_transpose_keys') + return result or None + @classmethod def to_hf_hook_fns(cls, backend: str | None = None): return cls.mapping_for(backend).get('to_hf_hook_fns') @@ -66,6 +71,7 @@ class MappingConfig: lora_to_hf_mappings: Optional[Dict[str, Any]] = None to_hf_hook_fns: Optional[Dict[str, Any]] = None to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None + lora_to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None @classmethod def build( @@ -90,6 +96,7 @@ def build( 'lora_to_hf_mappings', 'to_hf_hook_fns', 'to_hf_transpose_keys', + 'lora_to_hf_transpose_keys', ) values: Dict[str, Any] = {} @@ -116,6 +123,7 @@ def build( lora_to_hf_mappings=resolved.get('lora_to_hf_mappings'), to_hf_hook_fns=resolved.get('to_hf_hook_fns'), to_hf_transpose_keys=resolved.get('to_hf_transpose_keys'), + lora_to_hf_transpose_keys=resolved.get('lora_to_hf_transpose_keys'), ) @classmethod @@ -143,6 +151,7 @@ def maybe_call(attr: str): lora_to_hf_mappings=maybe_call('lora_to_hf_mappings'), to_hf_hook_fns=maybe_call('to_hf_hook_fns'), to_hf_transpose_keys=maybe_call('to_hf_transpose_keys'), + lora_to_hf_transpose_keys=maybe_call('lora_to_hf_transpose_keys'), ) for key, value in overrides.items(): diff --git a/tunix/generate/sglang_jax_sampler.py b/tunix/generate/sglang_jax_sampler.py index cfa92da8a..8fd58b755 100644 --- a/tunix/generate/sglang_jax_sampler.py +++ b/tunix/generate/sglang_jax_sampler.py @@ -15,39 +15,53 @@ """Sampler for sglang-jax-style autoregressive decoding using JAX and NNX models.""" import dataclasses +import logging import math import os +import re from typing import Any, Dict, Iterator, List, Optional, Tuple, Union -from absl import logging from flax import nnx import jax import jax.numpy as jnp import jaxtyping import numpy as np from sgl_jax.srt.entrypoints.engine import Engine +from sgl_jax.srt.utils.common_utils import SUPPORTED_LORA_TARGET_MODULES from tunix.generate import base_sampler from tunix.generate import mappings from tunix.generate import utils import tunix.generate.tokenizer_adapter as tok_adapter from tunix.rl import reshard +from tunix.rl.utils import VERIFY_UPDATE_PARAMS_KEY @dataclasses.dataclass class 
SglangJaxConfig: - model_version: str - context_length: int mesh: jax.sharding.Mesh - mem_fraction_static: float - init_with_random_weights: bool - disable_radix_cache: bool - enable_deterministic_sampling: bool mapping_config: mappings.MappingConfig + + model_version: str + context_length: int + + mem_fraction_static: float = 0.2 + init_with_random_weights: bool = True + disable_radix_cache: bool = True + enable_deterministic_sampling: bool = False # Note: use_sort_for_toppk_minp may be removed in the future. It depends on SGLang-Jax. use_sort_for_toppk_minp: bool = True - precompile_bs_paddings: Optional[List] = None - precompile_token_paddings: Optional[List] = None - chunked_prefill_size: int = -1 + enable_static_lora: bool = False + enable_single_process: bool = ( + True # Note: this is required when you run it in pathways. + ) + + lora_target_modules: Optional[List[str]] = None + max_lora_rank: Optional[int] = None + lora_scaling: Optional[float] = None + + precompile_token_paddings: Optional[List[int]] = None + precompile_bs_paddings: Optional[List[int]] = None + chunked_prefill_size: Optional[int] = -1 page_size: int = 64 @@ -76,10 +90,27 @@ def __init__( self.args = self._sglang_jax_config(config) self.engine = Engine(**self.args) - self.mappings = config.mapping_config.to_hf_mappings + self.to_hf_key_mappings = config.mapping_config.to_hf_mappings + self.to_hf_key_mappings = update_hf_key_mappings_with_lora( + self.to_hf_key_mappings, + self.args["enable_static_lora"], + self.args["lora_target_modules"], + ) self.to_hf_transpose_keys = config.mapping_config.to_hf_transpose_keys self.to_hf_hook_fns = config.mapping_config.to_hf_hook_fns + if config.mapping_config.lora_to_hf_mappings: + self.to_hf_key_mappings |= config.mapping_config.lora_to_hf_mappings + + if config.mapping_config.lora_to_hf_transpose_keys: + self.to_hf_transpose_keys |= ( + config.mapping_config.lora_to_hf_transpose_keys + ) + + self._logger = logging.getLogger(self.__class__.__name__) + + self._logger.debug(f"{self.to_hf_key_mappings=}") + # TODO(b/434969743): Optimize weight sharing between trainer and sglang-jax sampler. 
# TODO(b/434975493): Consider Release KV cache on the fly def update_params( @@ -91,13 +122,81 @@ def update_params( new_state = utils.transfer_state_with_mappings( src_state=updated_weights, dst_state=self.transformer_state, - key_mappings=self.mappings, + key_mappings=self.to_hf_key_mappings, transpose_keys=self.to_hf_transpose_keys, reshard_fn=reshard.reshard_pytree, + rollout_engine="sglang_jax", ) new_model_state_leaves, _ = jax.tree_util.tree_flatten(new_state) self._model_runner.model_state_leaves = new_model_state_leaves + flatten_src_to_tgt_module_name = os.getenv(VERIFY_UPDATE_PARAMS_KEY, None) + if flatten_src_to_tgt_module_name is not None: + # update state before verification + nnx.update(self._model_runner.model, new_state) + + self.verify_update_params( + updated_weights, + self.transformer_state, + flatten_src_to_tgt_module_name, + ) + + def verify_update_params( + self, + src_state: nnx.State, + tgt_state: nnx.State, + flatten_src_to_tgt_module_name: str, + ): + self._logger.debug( + f"[verify_update_params] {flatten_src_to_tgt_module_name} is required" + " to verify" + ) + src_tgt = flatten_src_to_tgt_module_name.split(",") + assert len(src_tgt) == 2 + flatten_src_module_name = src_tgt[0] + flatten_tgt_module_name = src_tgt[1] + flatten_src_state_list = src_state.flat_state() + src_value = None + for keys, param in flatten_src_state_list: + path = ".".join(str(key) for key in keys) + if path == flatten_src_module_name: + src_value = param.value if hasattr(param, "value") else param + if src_value is None: + raise ValueError( + f"{flatten_src_module_name=} does not exist in src: {src_state=}" + ) + + flatten_tgt_state_list = tgt_state.flat_state() + tgt_value = None + for keys, param in flatten_tgt_state_list: + path = ".".join(str(key) for key in keys) + if path == flatten_tgt_module_name: + tgt_value = param.value if hasattr(param, "value") else param + + if tgt_value is None: + raise ValueError( + f"{flatten_tgt_module_name=} does not exist in tgt: {tgt_state=}" + ) + + if "lora" in flatten_src_module_name: + for r_key in self.to_hf_transpose_keys: + if re.match(r_key, flatten_src_module_name): + logging.info( + "Applying LoRA transpose on %s in verification", + flatten_src_module_name, + ) + transposed_src_value = jnp.transpose( + src_value[None, :, :], self.to_hf_transpose_keys[r_key] + ) + else: + transposed_src_value = src_value + + src_value_np, tgt_value_np = jax.device_get( + transposed_src_value + ), jax.device_get(tgt_value) + if not np.array_equal(src_value_np, tgt_value_np): + raise ValueError(f"{src_value_np=} is not equal to {tgt_value_np=}") + def load_checkpoint(self, path_or_weights: str | jaxtyping.PyTree): # TODO(b/434741253): Consider support orbax checkpoint loading if isinstance(path_or_weights, jaxtyping.PyTree): @@ -114,21 +213,50 @@ def _sglang_jax_config(self, config: SglangJaxConfig): args = {} args["model_path"] = config.model_version args["context_length"] = config.context_length - args["tp_size"] = self._find_tp_size(config.mesh) - args["device_indexes"] = config.mesh.device_ids.flatten().tolist() + args["mem_fraction_static"] = config.mem_fraction_static - args["enable_single_process"] = True - if config.disable_radix_cache: - args["disable_radix_cache"] = True - if config.enable_deterministic_sampling: - args["enable_deterministic_sampling"] = True if config.init_with_random_weights: args["load_format"] = "dummy" + args["disable_radix_cache"] = config.disable_radix_cache + args["enable_deterministic_sampling"] = 
config.enable_deterministic_sampling
     args["use_sort_for_toppk_minp"] = config.use_sort_for_toppk_minp
-    args["precompile_bs_paddings"] = config.precompile_bs_paddings
-    args["precompile_token_paddings"] = config.precompile_token_paddings
+    args["enable_static_lora"] = config.enable_static_lora
+    args["enable_single_process"] = config.enable_single_process
+
+    if config.enable_static_lora:
+      assert (
+          config.lora_target_modules is not None
+          and config.max_lora_rank is not None
+          and config.lora_scaling is not None
+      )
+      # Check whether the lora_target_modules are valid.
+      if config.lora_target_modules == ["all"]:
+        config.lora_target_modules = SUPPORTED_LORA_TARGET_MODULES
+      else:
+        for module in config.lora_target_modules:
+          if module not in SUPPORTED_LORA_TARGET_MODULES:
+            raise ValueError(
+                f"{module} in lora_target_modules does not exist in"
+                f" {SUPPORTED_LORA_TARGET_MODULES}"
+            )
+      args["lora_target_modules"] = config.lora_target_modules
+      args["max_lora_rank"] = config.max_lora_rank
+      args["max_loras_per_batch"] = 1
+      args["lora_scaling"] = config.lora_scaling
+
+    if config.precompile_token_paddings is not None:
+      assert isinstance(config.precompile_token_paddings, List)
+      args["precompile_token_paddings"] = config.precompile_token_paddings
+    if config.precompile_bs_paddings is not None:
+      assert isinstance(config.precompile_bs_paddings, List)
+      args["precompile_bs_paddings"] = config.precompile_bs_paddings
     args["chunked_prefill_size"] = config.chunked_prefill_size
+
+    # Default arguments, derived from known configuration or left to be tuned.
     args["page_size"] = config.page_size
+    args["tp_size"] = self._find_tp_size(config.mesh)
+    args["device_indexes"] = config.mesh.device_ids.flatten().tolist()
+
     return args

   @property
@@ -256,3 +384,29 @@ def __call__(
         padded_prompt_tokens=all_input_ids,
         logprobs=None,
     )
+
+
+def update_hf_key_mappings_with_lora(
+    mappings: Optional[Dict[str, Any]] = None,
+    enable_static_lora: bool = False,
+    lora_target_modules: Optional[List] = None,
+):
+  if (
+      mappings is None
+      or not enable_static_lora
+      or lora_target_modules is None
+      or len(lora_target_modules) == 0
+  ):
+    return mappings
+
+  # Note: SGLang-JAX implements LoRA by wrapping the base layer, so the target paths in the mappings need to be updated,
+  # e.g. from 'model.layers.*.mlp.gate_proj.weight' to 'model.layers.*.mlp.gate_proj.base_layer.weight'.
+  for module in lora_target_modules:
+    for src_path, tgt_params in mappings.items():
+      if module in src_path:
+        tgt_path, sharding = tgt_params
+        keys = tgt_path.split(".")
+        new_tgt_path = ".".join(keys[:-1]) + ".base_layer." + keys[-1]
+        mappings[src_path] = (new_tgt_path, sharding)
+        break
+  return mappings
diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py
index 7e4bfde80..7669e631a 100644
--- a/tunix/generate/utils.py
+++ b/tunix/generate/utils.py
@@ -464,6 +464,7 @@ def _apply_transpose(
     val: jnp.ndarray,
     src_key: str,
     transpose_keys: Optional[Dict[str, Tuple[int, ...]]],
+    rollout_engine: Optional[str],
 ) -> jnp.ndarray:
   """Apply transpose operation if configured for this key."""
   if not transpose_keys:
@@ -479,6 +480,16 @@
   if target_key != '':
     logging.debug('Applying transpose on %s', src_key)
     return jnp.transpose(val, transpose_keys[target_key])
+
+  # For LoRA
+  # Note: The following code takes effect for the SGLang-JAX rollout and may not apply to other rollout engines.
+ + if rollout_engine == 'sglang_jax' and 'lora' in all_key: + for r_key in transpose_keys: + if re.compile(rf'{r_key}').match(all_key): + logging.debug('Applying LoRA transpose on %s', src_key) + return jnp.transpose(val[None, :, :], transpose_keys[r_key]) + return val @@ -603,7 +614,6 @@ def _align_shape( def _apply_dtype_cast( val: jnp.ndarray, tgt_dtype: jnp.dtype, src_key: str ) -> jnp.ndarray: - if val.dtype != tgt_dtype: logging.warning( 'Type mismatch on %s: %s -> %s', @@ -622,6 +632,7 @@ def transfer_state_with_mappings( key_mapping_hook_fns=None, transpose_keys=None, reshard_fn=None, + rollout_engine=None, ): """Transfer state using mappings, with optional transpose and shard logic. @@ -643,8 +654,10 @@ def transfer_state_with_mappings( """ # Get flat target state tgt_flat_list = dst_state.flat_state() + # Build sharding dictionary if resharding is needed sharding_dict = None + if reshard_fn: sharding_dict = { key: ( @@ -667,7 +680,7 @@ def transfer_state_with_mappings( tgt_param, ) in unscanned_src_to_tgt_flat.items(): # Apply transpose if configured - val = _apply_transpose(val, flat_src_key, transpose_keys) + val = _apply_transpose(val, flat_src_key, transpose_keys, rollout_engine) # Apply optional hook function if key_mapping_hook_fns and flat_src_key in key_mapping_hook_fns: @@ -725,6 +738,7 @@ def transfer_state_directly( dst_state: The destination state to transfer to. reshard_fn: A function to shard the values. """ + def safe_has_key(obj: Mapping[str, Any], key: str) -> bool: if isinstance(obj, dict): return key in obj @@ -732,12 +746,16 @@ def safe_has_key(obj: Mapping[str, Any], key: str) -> bool: return hasattr(obj, key) # Unwrap Source (Remove 'base' wrapper from MaxText) - if isinstance(src_state, (dict, nnx.State, nnx.Dict)) and safe_has_key(src_state, 'base'): + if isinstance(src_state, (dict, nnx.State, nnx.Dict)) and safe_has_key( + src_state, 'base' + ): logging.info("Unwrapping 'base' key from source state.") src_state = src_state['base'] # Unwrap Target (Remove nested 'model' wrappers from vLLM) - while isinstance(dst_state, (dict, nnx.State, nnx.Dict)) and safe_has_key(dst_state, 'model'): + while isinstance(dst_state, (dict, nnx.State, nnx.Dict)) and safe_has_key( + dst_state, 'model' + ): logging.info("Unwrapping nested 'model' key from target state.") dst_state = dst_state['model'] @@ -763,7 +781,9 @@ def to_pure_spec(node: Any) -> Any: return node # Helper: Intersect Trees (Handle KVCache/RNG mismatches) - def intersect_trees(src: Any, tgt_spec: Any, path: str = "") -> Tuple[Any, Any]: + def intersect_trees( + src: Any, tgt_spec: Any, path: str = '' + ) -> Tuple[Any, Any]: # Stop recursion if we hit a leaf (non-dict) if not isinstance(src, dict) or not isinstance(tgt_spec, dict): return src, tgt_spec @@ -786,7 +806,7 @@ def intersect_trees(src: Any, tgt_spec: Any, path: str = "") -> Tuple[Any, Any]: filtered_tgt = {} for k in common_keys: - new_path = f"{path}/{k}" if path else k + new_path = f'{path}/{k}' if path else k s_val, t_val = intersect_trees(src[k], tgt_spec[k], new_path) filtered_src[k] = s_val filtered_tgt[k] = t_val diff --git a/tunix/models/llama3/mapping_sglang_jax.py b/tunix/models/llama3/mapping_sglang_jax.py index fe22f764c..3dd6e47bd 100644 --- a/tunix/models/llama3/mapping_sglang_jax.py +++ b/tunix/models/llama3/mapping_sglang_jax.py @@ -5,16 +5,18 @@ import os from typing import Any, Dict, Tuple +from tunix.utils.env_utils import SGLANG_JAX_TP_AXIS_NAME + Sharding = Tuple[str | None, ...] 
MappingEntry = Tuple[str, Sharding] def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: return { - 'lm_head.w': ('lm_head.embedding', (None, 'model')), + 'lm_head.w': ('lm_head.embedding', (None, SGLANG_JAX_TP_AXIS_NAME)), 'embedder.input_embedding': ( 'model.embed_tokens.embedding', - ('model', None), + (SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.input_layernorm.w': ( 'model.layers.*.input_layernorm.scale', @@ -22,15 +24,15 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: ), 'layers.*.mlp.down_proj.kernel': ( 'model.layers.*.mlp.down_proj.weight', - ('model', None), + (SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.mlp.gate_proj.kernel': ( 'model.layers.*.mlp.gate_proj.weight', - (None, 'model'), + (None, SGLANG_JAX_TP_AXIS_NAME), ), 'layers.*.mlp.up_proj.kernel': ( 'model.layers.*.mlp.up_proj.weight', - (None, 'model'), + (None, SGLANG_JAX_TP_AXIS_NAME), ), 'layers.*.post_attention_layernorm.w': ( 'model.layers.*.post_attention_layernorm.scale', @@ -38,19 +40,19 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: ), 'layers.*.attn.k_proj.w': ( 'model.layers.*.self_attn.k_proj.weight', - (None, 'model', None), + (None, SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.attn.o_proj.w': ( 'model.layers.*.self_attn.o_proj.weight', - ('model', None, None), + (SGLANG_JAX_TP_AXIS_NAME, None, None), ), 'layers.*.attn.q_proj.w': ( 'model.layers.*.self_attn.q_proj.weight', - (None, 'model', None), + (None, SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.attn.v_proj.w': ( 'model.layers.*.self_attn.v_proj.weight', - (None, 'model', None), + (None, SGLANG_JAX_TP_AXIS_NAME, None), ), 'final_norm.w': ('model.norm.scale', (None,)), } @@ -58,7 +60,64 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: def _lora_to_sglang_jax_mappings() -> Dict[str, MappingEntry] | None: """The lora parameter key mapping between Tunix vanilla model and Sglang-jax Jax backend""" - return None + return { + 'layers.*.mlp.gate_proj.kernel_lora_a': ( + 'model.layers.*.mlp.gate_proj.A_buffer', + (None, None, None), + ), + 'layers.*.mlp.gate_proj.kernel_lora_b': ( + 'model.layers.*.mlp.gate_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.mlp.up_proj.kernel_lora_a': ( + 'model.layers.*.mlp.up_proj.A_buffer', + (None, None, None), + ), + 'layers.*.mlp.up_proj.kernel_lora_b': ( + 'model.layers.*.mlp.up_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.mlp.down_proj.kernel_lora_a': ( + 'model.layers.*.mlp.down_proj.A_buffer', + (None, None, SGLANG_JAX_TP_AXIS_NAME), + ), + 'layers.*.mlp.down_proj.kernel_lora_b': ( + 'model.layers.*.mlp.down_proj.B_buffer', + (None, None, None), + ), + 'layers.*.attn.q_proj.w_lora_a': ( + 'model.layers.*.self_attn.q_proj.A_buffer', + (None, None, None), + ), + 'layers.*.attn.q_proj.w_lora_b': ( + 'model.layers.*.self_attn.q_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.attn.k_proj.w_lora_a': ( + 'model.layers.*.self_attn.k_proj.A_buffer', + (None, None, None), + ), + 'layers.*.attn.k_proj.w_lora_b': ( + 'model.layers.*.self_attn.k_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.attn.v_proj.w_lora_a': ( + 'model.layers.*.self_attn.v_proj.A_buffer', + (None, None, None), + ), + 'layers.*.attn.v_proj.w_lora_b': ( + 'model.layers.*.self_attn.v_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.attn.o_proj.w_lora_a': ( + 'model.layers.*.self_attn.o_proj.A_buffer', + (None, None, SGLANG_JAX_TP_AXIS_NAME), + ), + 'layers.*.attn.o_proj.w_lora_b': ( + 
'model.layers.*.self_attn.o_proj.B_buffer', + (None, None, None), + ), + } def _to_sglang_jax_transpose_keys(): @@ -67,6 +126,42 @@ def _to_sglang_jax_transpose_keys(): } +def _to_sglang_jax_lora_transpose_keys(): + """ + Tunix -> SGLangJax: + gate_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + gate_lora_b: (max_lora_rank, intermediate_size) -> (1, intermediate_size, max_lora_rank) + up_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + up_lora_b: (max_lora_rank, intermediate_size) -> (1, intermediate_size, max_lora_rank) + down_lora_a: (intermediate_size, max_lora_rank) -> (1, max_lora_rank, intermediate_size) + down_lora_b: (max_lora_rank, hidden_size) -> (1, hidden_size, max_lora_rank) + q_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + q_lora_b: (max_lora_rank, num_attention_heads, head_dim) -> (1, hidden_size, max_lora_rank) + k_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + k_lora_b: (max_lora_rank, num_key_value_heads, head_dim) -> (1, num_key_value_heads*head_dim, max_lora_rank) + v_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + v_lora_b: (max_lora_rank, num_key_value_heads, head_dim) -> (1, num_key_value_heads*head_dim, max_lora_rank) + o_lora_a: (num_attention_heads, head_dim, max_lora_rank) -> (1, max_lora_rank, hidden_size) + o_lora_b: (max_lora_rank, hidden_size) -> (1, hidden_size, max_lora_rank) + """ + return { + 'layers.*.mlp.gate_proj.kernel_lora_a': (0, 2, 1), + 'layers.*.mlp.gate_proj.kernel_lora_b': (0, 2, 1), + 'layers.*.mlp.up_proj.kernel_lora_a': (0, 2, 1), + 'layers.*.mlp.up_proj.kernel_lora_b': (0, 2, 1), + 'layers.*.mlp.down_proj.kernel_lora_a': (0, 2, 1), + 'layers.*.mlp.down_proj.kernel_lora_b': (0, 2, 1), + 'layers.*.attn.q_proj.w_lora_a': (0, 2, 1), + 'layers.*.attn.q_proj.w_lora_b': (0, 2, 3, 1), + 'layers.*.attn.k_proj.w_lora_a': (0, 2, 1), + 'layers.*.attn.k_proj.w_lora_b': (0, 2, 3, 1), + 'layers.*.attn.v_proj.w_lora_a': (0, 2, 1), + 'layers.*.attn.v_proj.w_lora_b': (0, 2, 3, 1), + 'layers.*.attn.o_proj.w_lora_a': (0, 3, 2, 1), + 'layers.*.attn.o_proj.w_lora_b': (0, 2, 1), + } + + def _to_sglang_jax_hook_fns() -> Dict[str, Any] | None: """Additional parameter manipulation hook between Tunix vanilla model and Sglang Jax backend""" return None @@ -76,6 +171,7 @@ def _to_sglang_jax_hook_fns() -> Dict[str, Any] | None: 'to_hf_mappings': _to_sglang_jax_mappings(), 'lora_to_hf_mappings': _lora_to_sglang_jax_mappings(), 'to_hf_transpose_keys': _to_sglang_jax_transpose_keys(), + 'lora_to_hf_transpose_keys': _to_sglang_jax_lora_transpose_keys(), 'to_hf_hook_fns': _to_sglang_jax_hook_fns(), } diff --git a/tunix/models/qwen2/mapping_sglang_jax.py b/tunix/models/qwen2/mapping_sglang_jax.py index 5b2ee26e5..71f96441b 100644 --- a/tunix/models/qwen2/mapping_sglang_jax.py +++ b/tunix/models/qwen2/mapping_sglang_jax.py @@ -5,16 +5,18 @@ import os from typing import Any, Dict, Tuple +from tunix.utils.env_utils import SGLANG_JAX_TP_AXIS_NAME + Sharding = Tuple[str | None, ...] 
MappingEntry = Tuple[str, Sharding] def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: return { - 'lm_head.w': ('lm_head.embedding', (None, 'model')), + 'lm_head.w': ('lm_head.embedding', (None, SGLANG_JAX_TP_AXIS_NAME)), 'embedder.input_embedding': ( 'model.embed_tokens.embedding', - ('model', None), + (SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.input_layernorm.w': ( 'model.layers.*.input_layernorm.scale', @@ -22,15 +24,15 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: ), 'layers.*.mlp.down_proj.kernel': ( 'model.layers.*.mlp.down_proj.weight', - ('model', None), + (SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.mlp.gate_proj.kernel': ( 'model.layers.*.mlp.gate_proj.weight', - (None, 'model'), + (None, SGLANG_JAX_TP_AXIS_NAME), ), 'layers.*.mlp.up_proj.kernel': ( 'model.layers.*.mlp.up_proj.weight', - (None, 'model'), + (None, SGLANG_JAX_TP_AXIS_NAME), ), 'layers.*.post_attention_layernorm.w': ( 'model.layers.*.post_attention_layernorm.scale', @@ -38,7 +40,7 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: ), 'layers.*.attn.k_proj.w': ( 'model.layers.*.self_attn.k_proj.weight', - (None, 'model', None), + (None, SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.attn.k_bias': ( 'model.layers.*.self_attn.k_proj.bias', @@ -46,11 +48,11 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: ), 'layers.*.attn.o_proj.w': ( 'model.layers.*.self_attn.o_proj.weight', - ('model', None, None), + (SGLANG_JAX_TP_AXIS_NAME, None, None), ), 'layers.*.attn.q_proj.w': ( 'model.layers.*.self_attn.q_proj.weight', - (None, 'model', None), + (None, SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.attn.q_bias': ( 'model.layers.*.self_attn.q_proj.bias', @@ -58,7 +60,7 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: ), 'layers.*.attn.v_proj.w': ( 'model.layers.*.self_attn.v_proj.weight', - (None, 'model', None), + (None, SGLANG_JAX_TP_AXIS_NAME, None), ), 'layers.*.attn.v_bias': ( 'model.layers.*.self_attn.v_proj.bias', @@ -70,7 +72,64 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: def _lora_to_sglang_jax_mappings() -> Dict[str, MappingEntry] | None: """The lora parameter key mapping between Tunix vanilla model and Sglang-jax Jax backend""" - return None + return { + 'layers.*.mlp.gate_proj.kernel_lora_a': ( + 'model.layers.*.mlp.gate_proj.A_buffer', + (None, None, None), + ), + 'layers.*.mlp.gate_proj.kernel_lora_b': ( + 'model.layers.*.mlp.gate_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.mlp.up_proj.kernel_lora_a': ( + 'model.layers.*.mlp.up_proj.A_buffer', + (None, None, None), + ), + 'layers.*.mlp.up_proj.kernel_lora_b': ( + 'model.layers.*.mlp.up_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.mlp.down_proj.kernel_lora_a': ( + 'model.layers.*.mlp.down_proj.A_buffer', + (None, None, SGLANG_JAX_TP_AXIS_NAME), + ), + 'layers.*.mlp.down_proj.kernel_lora_b': ( + 'model.layers.*.mlp.down_proj.B_buffer', + (None, None, None), + ), + 'layers.*.attn.q_proj.w_lora_a': ( + 'model.layers.*.self_attn.q_proj.A_buffer', + (None, None, None), + ), + 'layers.*.attn.q_proj.w_lora_b': ( + 'model.layers.*.self_attn.q_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.attn.k_proj.w_lora_a': ( + 'model.layers.*.self_attn.k_proj.A_buffer', + (None, None, None), + ), + 'layers.*.attn.k_proj.w_lora_b': ( + 'model.layers.*.self_attn.k_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.attn.v_proj.w_lora_a': ( + 'model.layers.*.self_attn.v_proj.A_buffer', + (None, None, None), + 
), + 'layers.*.attn.v_proj.w_lora_b': ( + 'model.layers.*.self_attn.v_proj.B_buffer', + (None, SGLANG_JAX_TP_AXIS_NAME, None), + ), + 'layers.*.attn.o_proj.w_lora_a': ( + 'model.layers.*.self_attn.o_proj.A_buffer', + (None, None, SGLANG_JAX_TP_AXIS_NAME), + ), + 'layers.*.attn.o_proj.w_lora_b': ( + 'model.layers.*.self_attn.o_proj.B_buffer', + (None, None, None), + ), + } def _to_sglang_jax_transpose_keys(): @@ -79,6 +138,42 @@ def _to_sglang_jax_transpose_keys(): } +def _to_sglang_jax_lora_transpose_keys(): + """ + Tunix -> SGLangJax: + gate_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + gate_lora_b: (max_lora_rank, intermediate_size) -> (1, intermediate_size, max_lora_rank) + up_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + up_lora_b: (max_lora_rank, intermediate_size) -> (1, intermediate_size, max_lora_rank) + down_lora_a: (intermediate_size, max_lora_rank) -> (1, max_lora_rank, intermediate_size) + down_lora_b: (max_lora_rank, hidden_size) -> (1, hidden_size, max_lora_rank) + q_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + q_lora_b: (max_lora_rank, num_attention_heads, head_dim) -> (1, hidden_size, max_lora_rank) + k_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + k_lora_b: (max_lora_rank, num_key_value_heads, head_dim) -> (1, num_key_value_heads*head_dim, max_lora_rank) + v_lora_a: (hidden_size, max_lora_rank) -> (1, max_lora_rank, hidden_size) + v_lora_b: (max_lora_rank, num_key_value_heads, head_dim) -> (1, num_key_value_heads*head_dim, max_lora_rank) + o_lora_a: (num_attention_heads, head_dim, max_lora_rank) -> (1, max_lora_rank, hidden_size) + o_lora_b: (max_lora_rank, hidden_size) -> (1, hidden_size, max_lora_rank) + """ + return { + 'layers.*.mlp.gate_proj.kernel_lora_a': (0, 2, 1), + 'layers.*.mlp.gate_proj.kernel_lora_b': (0, 2, 1), + 'layers.*.mlp.up_proj.kernel_lora_a': (0, 2, 1), + 'layers.*.mlp.up_proj.kernel_lora_b': (0, 2, 1), + 'layers.*.mlp.down_proj.kernel_lora_a': (0, 2, 1), + 'layers.*.mlp.down_proj.kernel_lora_b': (0, 2, 1), + 'layers.*.attn.q_proj.w_lora_a': (0, 2, 1), + 'layers.*.attn.q_proj.w_lora_b': (0, 2, 3, 1), + 'layers.*.attn.k_proj.w_lora_a': (0, 2, 1), + 'layers.*.attn.k_proj.w_lora_b': (0, 2, 3, 1), + 'layers.*.attn.v_proj.w_lora_a': (0, 2, 1), + 'layers.*.attn.v_proj.w_lora_b': (0, 2, 3, 1), + 'layers.*.attn.o_proj.w_lora_a': (0, 3, 2, 1), + 'layers.*.attn.o_proj.w_lora_b': (0, 2, 1), + } + + def _to_sglang_jax_hook_fns() -> Dict[str, Any] | None: """Additional parameter manipulation hook between Tunix vanilla model and Sglang Jax backend""" return None @@ -88,6 +183,7 @@ def _to_sglang_jax_hook_fns() -> Dict[str, Any] | None: 'to_hf_mappings': _to_sglang_jax_mappings(), 'lora_to_hf_mappings': _lora_to_sglang_jax_mappings(), 'to_hf_transpose_keys': _to_sglang_jax_transpose_keys(), + 'lora_to_hf_transpose_keys': _to_sglang_jax_lora_transpose_keys(), 'to_hf_hook_fns': _to_sglang_jax_hook_fns(), } diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py index d1371e20d..b911a89cd 100644 --- a/tunix/rl/rl_learner.py +++ b/tunix/rl/rl_learner.py @@ -20,8 +20,7 @@ from concurrent import futures import itertools import math -from typing import Any, Callable, Dict, Iterable, Iterator, List, Sequence -from typing import Generic, TypeVar +from typing import Any, Callable, Dict, Generic, Iterable, Iterator, List, Sequence, TypeVar from absl import logging import jax diff --git a/tunix/rl/rollout/base_rollout.py b/tunix/rl/rollout/base_rollout.py index 
5e9c9d0c2..7f095c9a0 100644
--- a/tunix/rl/rollout/base_rollout.py
+++ b/tunix/rl/rollout/base_rollout.py
@@ -162,15 +162,31 @@ class RolloutConfig:
   # Whether to enable deterministic sampling for SG-Lang JAX rollout engine.
   rollout_sglang_jax_enable_deterministic_sampling: bool = False

-  # List of token buckets for jax jit
-  rollout_sglang_jax_precompile_token_paddings: Optional[List] = None
+  # Whether to use the sort (vs. mask) implementation in the sampler; sort gives better evaluation results.
+  rollout_sglang_jax_use_sort_for_toppk_minp: bool = True

-  # List of batch sizes buckets for jax jit
-  rollout_sglang_jax_precompile_bs_paddings: Optional[List] = None
+  # Whether to enable static LoRA.
+  rollout_sglang_jax_enable_static_lora: bool = False

-  # The maximum number of tokens in a chunk for the chunked prefill.
-  # Setting this to -1 means disabling chunked prefill.
-  rollout_sglang_jax_chunked_prefill_size: int = -1
+  # Whether to use single-controller mode; this is required when running on Pathways.
+  rollout_sglang_jax_enable_single_process: bool = True
+
+  # The modules to apply LoRA to.
+  rollout_sglang_jax_lora_target_modules: Optional[List[str]] = None
+
+  # The LoRA rank.
+  rollout_sglang_jax_max_lora_rank: Optional[int] = None
+
+  rollout_sglang_jax_lora_scaling: Optional[float] = None
+
+  # Batch-size paddings to precompile for JAX JIT.
+  rollout_sglang_jax_precompile_bs_paddings: Optional[List[int]] = None
+
+  # Token paddings to precompile for prefill.
+  rollout_sglang_jax_precompile_token_paddings: Optional[List[int]] = None
+
+  # The maximum number of tokens in a chunk for chunked prefill (-1 disables it).
+  rollout_sglang_jax_chunked_prefill_size: Optional[int] = -1

   # The number of tokens in a page
   rollout_sglang_jax_page_size: int = 64
diff --git a/tunix/rl/rollout/sglang_jax_rollout.py b/tunix/rl/rollout/sglang_jax_rollout.py
index 8850e6466..26471921b 100644
--- a/tunix/rl/rollout/sglang_jax_rollout.py
+++ b/tunix/rl/rollout/sglang_jax_rollout.py
@@ -45,13 +45,19 @@ def __init__(
         tokenizer=tokenizer,
         config=sglang_jax_sampler.SglangJaxConfig(
             mesh=mesh,
-            context_length=rollout_config.rollout_sglang_jax_context_length,
+            mapping_config=mapping_config,
             model_version=rollout_config.rollout_sglang_jax_model_version,
+            context_length=rollout_config.rollout_sglang_jax_context_length,
             mem_fraction_static=rollout_config.rollout_sglang_jax_mem_fraction_static,
             init_with_random_weights=rollout_config.rollout_sglang_jax_init_with_random_weights,
             disable_radix_cache=rollout_config.rollout_sglang_jax_disable_radix_cache,
             enable_deterministic_sampling=rollout_config.rollout_sglang_jax_enable_deterministic_sampling,
-            mapping_config=mapping_config,
+            use_sort_for_toppk_minp=rollout_config.rollout_sglang_jax_use_sort_for_toppk_minp,
+            enable_static_lora=rollout_config.rollout_sglang_jax_enable_static_lora,
+            enable_single_process=rollout_config.rollout_sglang_jax_enable_single_process,
+            lora_target_modules=rollout_config.rollout_sglang_jax_lora_target_modules,
+            max_lora_rank=rollout_config.rollout_sglang_jax_max_lora_rank,
+            lora_scaling=rollout_config.rollout_sglang_jax_lora_scaling,
             precompile_bs_paddings=rollout_config.rollout_sglang_jax_precompile_bs_paddings,
             precompile_token_paddings=rollout_config.rollout_sglang_jax_precompile_token_paddings,
             chunked_prefill_size=rollout_config.rollout_sglang_jax_chunked_prefill_size,
@@ -119,3 +125,6 @@ def eos_id(self) -> int:

   def model(self) -> nnx.Module:
     return self._sampler.transformer
+
+  def model_state(self):
+    return 
self._sampler.transformer_state diff --git a/tunix/rl/utils.py b/tunix/rl/utils.py index 1888368f7..f2e6fa1c2 100644 --- a/tunix/rl/utils.py +++ b/tunix/rl/utils.py @@ -258,3 +258,6 @@ def get_partition_spec( return sharding.spec else: return jax.sharding.PartitionSpec() + + +VERIFY_UPDATE_PARAMS_KEY = "VERIFY_UPDATE_PARAMS_SRC_TO_TGT_MODULE_NAME" diff --git a/tunix/utils/env_utils.py b/tunix/utils/env_utils.py index 1a058a02e..b49f05af4 100644 --- a/tunix/utils/env_utils.py +++ b/tunix/utils/env_utils.py @@ -14,6 +14,8 @@ """Environment utils.""" +import os + import flax @@ -31,3 +33,6 @@ def is_internal_env(): return True except ImportError: return False + + +SGLANG_JAX_TP_AXIS_NAME = os.getenv('SGLANG_JAX_TP_AXIS_NAME', 'tensor')
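
A minimal sketch (not part of the patch, toy shapes only) of the LoRA weight-layout conversion that the new `lora_to_hf_transpose_keys` entries and the `_apply_transpose` change in `tunix/generate/utils.py` describe: a Tunix `kernel_lora_a` of shape (hidden_size, max_lora_rank) becomes an SGLang-JAX `A_buffer` of shape (1, max_lora_rank, hidden_size) by adding a leading axis and transposing with (0, 2, 1).

import jax.numpy as jnp

hidden_size, rank = 8, 4  # illustrative values, not real model dimensions
lora_a = jnp.ones((hidden_size, rank))  # Tunix: layers.*.mlp.gate_proj.kernel_lora_a
# Mirrors the sglang_jax LoRA branch in _apply_transpose: val[None, :, :] then transpose.
a_buffer = jnp.transpose(lora_a[None, :, :], (0, 2, 1))
assert a_buffer.shape == (1, rank, hidden_size)  # SGLang-JAX: ...gate_proj.A_buffer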