diff --git a/scripts/grpo_demo_sglang_jax_rollout.py b/scripts/grpo_demo_sglang_jax_rollout.py index a77989d0..37f347fa 100644 --- a/scripts/grpo_demo_sglang_jax_rollout.py +++ b/scripts/grpo_demo_sglang_jax_rollout.py @@ -115,12 +115,13 @@ # ====== Training ====== TRAIN_MICRO_BATCH_SIZE = 1 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results. -NUM_BATCHES = 3738 +# NUM_BATCHES = 3738 +NUM_BATCHES = 10 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be # increased to a max. of 330 (if batch size is 4). NUM_TEST_BATCHES = 2 -EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. +EVAL_EVERY_N_STEPS = 5 # this doesn't matter if `TRAIN_FRACTION = 1.0`. NUM_EPOCHS = 1 # can potentially train for more epochs # Number of training steps. @@ -386,22 +387,22 @@ def download_from_huggingface(repo_id: str, model_path: str): # -def get_lora_model(base_model, mesh): - # lora_provider = qwix.LoraProvider( - # module_path=( - # ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|" - # ".*attn_vec_einsum" - # ), - # rank=RANK, - # alpha=ALPHA, - # ) - # - # model_input = base_model.get_model_input() - # lora_model = qwix.apply_lora_to_model( - # base_model, lora_provider, **model_input - # ) - lora_model = base_model - return lora_model +# def get_lora_model(base_model, mesh): +# # lora_provider = qwix.LoraProvider( +# # module_path=( +# # ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|" +# # ".*attn_vec_einsum" +# # ), +# # rank=RANK, +# # alpha=ALPHA, +# # ) +# # +# # model_input = base_model.get_model_input() +# # lora_model = qwix.apply_lora_to_model( +# # base_model, lora_provider, **model_input +# # ) +# lora_model = base_model +# return lora_model # Reference model @@ -453,7 +454,39 @@ def get_rollout_mesh(): # Policy model rollout_mesh = get_rollout_mesh() -lora_policy = ref_model + +def get_lora_model(base_model, model_mesh=None): + """Creates a LoRA model from a base model. + + Args: + base_model: The base model to apply LoRA to. + model_mesh: The mesh to use for sharding the model. + + Returns: + A LoRA model. 
+ """ + # if isinstance(base_model, llama_lib.Llama3): + # module_path = ( + # ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" + # ) + # else: + # module_path = ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" + + lora_provider = qwix.LoraProvider( + module_path=".*gate_proj", + rank=RANK, + alpha=ALPHA, + ) + + model_input = base_model.get_model_input() + lora_model = qwix.apply_lora_to_model( + base_model, lora_provider, **model_input + ) + + return lora_model + + +lora_policy = get_lora_model(ref_model, mesh) print("after lora_policy") show_hbm_usage() # nnx.display(lora_policy) @@ -734,6 +767,11 @@ def evaluate( disable_radix_cache=True, enable_deterministic_sampling=False, mapping_config=mapping_config, + enable_lora=True, + lora_target_modules=["gate_proj"], + max_lora_rank=RANK, + precompile_bs_paddings=[2], + precompile_token_paddings=[2048], ) # sampler = sampler_lib.SglangJaxSampler( @@ -811,12 +849,20 @@ def evaluate( temperature=TEMPERATURE, top_p=TOP_P, top_k=TOP_K, + rollout_mapping_config=mapping_config, rollout_sglang_jax_model_version=model_path, rollout_sglang_jax_context_length=2048, rollout_sglang_jax_mem_fraction_static=0.4, rollout_sglang_jax_init_with_random_weights=True, rollout_sglang_jax_disable_radix_cache=True, rollout_sglang_jax_enable_deterministic_sampling=False, + rollout_sglang_jax_use_sort_for_toppk_minp=True, + rollout_sglang_jax_enable_lora=True, + rollout_sglang_jax_enable_single_process=True, + rollout_sglang_jax_lora_target_modules=["gate_proj"], + rollout_sglang_jax_max_lora_rank=RANK, + rollout_sglang_jax_precompile_bs_paddings=[8], + rollout_sglang_jax_precompile_token_paddings=[2048], ), ) diff --git a/tunix/generate/mappings.py b/tunix/generate/mappings.py index e71887ea..6c490e2e 100644 --- a/tunix/generate/mappings.py +++ b/tunix/generate/mappings.py @@ -17,6 +17,7 @@ def _backend_registry(cls) -> Dict[str, Any]: module = cls.__module__ package_name = module.rsplit('.', 1)[0] if '.' in module else module package = importlib.import_module(package_name) + print(f'[_backend_registry] {package=}') return getattr(package, 'BACKEND_MAPPINGS', {}) @classmethod diff --git a/tunix/generate/sglang_jax_sampler.py b/tunix/generate/sglang_jax_sampler.py index 83a5ccb5..aa7b0719 100644 --- a/tunix/generate/sglang_jax_sampler.py +++ b/tunix/generate/sglang_jax_sampler.py @@ -35,16 +35,28 @@ @dataclasses.dataclass class SglangJaxConfig: - model_version: str - context_length: int mesh: jax.sharding.Mesh - mem_fraction_static: float - init_with_random_weights: bool - disable_radix_cache: bool - enable_deterministic_sampling: bool mapping_config: mappings.MappingConfig + + model_version: str + context_length: int + + mem_fraction_static: float = 0.2 + init_with_random_weights: bool = True + disable_radix_cache: bool = True + enable_deterministic_sampling: bool = False # Note: use_sort_for_toppk_minp may be removed in the future. It depends on SGLang-Jax. use_sort_for_toppk_minp: bool = True + enable_lora: bool = False + enable_single_process: bool = ( + True # Note: this is required when you run it in pathways. 
+  )
+
+  # lora_config: Optional[Dict[str, Any]] = None
+  lora_target_modules: Optional[List[str]] = None
+  max_lora_rank: Optional[int] = None
+  precompile_token_paddings: Optional[List[int]] = None
+  precompile_bs_paddings: Optional[List[int]] = None
 
 
 class SglangJaxSampler(base_sampler.BaseSampler):  # pylint: disable=invalid-name
@@ -68,14 +80,20 @@ def __init__(
       tokenizer (Any): A tokenizer compatible with the model.
       config: The sglang-jax related configurations
     """
+    print(f"[SglangJaxSampler] {config=}")
     self.tokenizer = tok_adapter.TokenizerAdapter(tokenizer)
     self.args = self._sglang_jax_config(config)
     self.engine = Engine(**self.args)
-    self.mappings = config.mapping_config.to_hf_mappings
+    self.to_hf_key_mappings = config.mapping_config.to_hf_mappings
     self.to_hf_transpose_keys = config.mapping_config.to_hf_transpose_keys
     self.to_hf_hook_fns = config.mapping_config.to_hf_hook_fns
 
+    if config.mapping_config.lora_to_hf_mappings:
+      self.to_hf_key_mappings |= config.mapping_config.lora_to_hf_mappings
+
+    print(f"[SglangJaxSampler initialization] {self.to_hf_key_mappings=}")
+
   # TODO(b/434969743): Optimize weight sharing between trainer and sglang-jax sampler.
   # TODO(b/434975493): Consider Release KV cache on the fly
   def update_params(
@@ -87,7 +105,7 @@ def update_params(
     new_state = utils.transfer_state_with_mappings(
         src_state=updated_weights,
         dst_state=self.transformer_state,
-        key_mappings=self.mappings,
+        key_mappings=self.to_hf_key_mappings,
         transpose_keys=self.to_hf_transpose_keys,
         reshard_fn=reshard.reshard_pytree,
     )
@@ -109,21 +127,38 @@ def _find_tp_size(self, mesh: jax.sharding.Mesh) -> int:
 
   def _sglang_jax_config(self, config: SglangJaxConfig):
     args = {}
     args["model_path"] = config.model_version
-    args["precompile_bs_paddings"] = [1, 64]
-    args["precompile_token_paddings"] = [8192]
-    args["page_size"] = 64
     args["context_length"] = config.context_length
-    args["tp_size"] = self._find_tp_size(config.mesh)
-    args["device_indexes"] = config.mesh.device_ids.flatten().tolist()
+
     args["mem_fraction_static"] = config.mem_fraction_static
-    args["enable_single_process"] = True
-    if config.disable_radix_cache:
-      args["disable_radix_cache"] = True
-    if config.enable_deterministic_sampling:
-      args["enable_deterministic_sampling"] = True
     if config.init_with_random_weights:
       args["load_format"] = "dummy"
+    args["disable_radix_cache"] = config.disable_radix_cache
+    args["enable_deterministic_sampling"] = config.enable_deterministic_sampling
     args["use_sort_for_toppk_minp"] = config.use_sort_for_toppk_minp
+    args["enable_lora"] = config.enable_lora
+    args["enable_single_process"] = config.enable_single_process
+
+    if config.enable_lora:
+      assert (
+          config.lora_target_modules is not None
+          and config.max_lora_rank is not None
+      )
+      args["lora_target_modules"] = config.lora_target_modules
+      args["max_lora_rank"] = config.max_lora_rank
+      args["max_loras_per_batch"] = 1
+
+    if config.precompile_token_paddings is not None:
+      assert isinstance(config.precompile_token_paddings, List)
+      args["precompile_token_paddings"] = config.precompile_token_paddings
+    if config.precompile_bs_paddings is not None:
+      assert isinstance(config.precompile_bs_paddings, List)
+      args["precompile_bs_paddings"] = config.precompile_bs_paddings
+
+    # Default arguments: page_size may need tuning; tp_size and device_indexes are derived from the mesh.
+ args["page_size"] = 64 + args["tp_size"] = self._find_tp_size(config.mesh) + args["device_indexes"] = config.mesh.device_ids.flatten().tolist() + return args @property diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py index ffee724f..f9b790e3 100644 --- a/tunix/generate/utils.py +++ b/tunix/generate/utils.py @@ -378,6 +378,7 @@ def build_flat_dict( # Sort layers for key, (layers, paths, sharding) in new_flat_dict.items(): + print(f'[build_flat_dict][src] {key=}, [target]{paths=}') if isinstance(layers, list): layers.sort(key=lambda x: x[0]) paths.sort(key=lambda x: x[0]) @@ -426,6 +427,8 @@ def _unroll_scanned_layers( unscanned_flat = {} + # print(f"{src_to_tgt_map=}") + for src_keys, src_val in src_state.flat_state(): src_key = '.'.join(str(k) for k in src_keys) @@ -650,6 +653,8 @@ def transfer_state_with_mappings( for key, tgt_params in tgt_flat_list } + print(f'[src]{key_mappings=}') + # Build source-to-target mapping src_to_tgt_map = build_flat_dict(tgt_flat_list, key_mappings) diff --git a/tunix/models/llama3/mapping_sglang_jax.py b/tunix/models/llama3/mapping_sglang_jax.py index fe22f764..f887ae63 100644 --- a/tunix/models/llama3/mapping_sglang_jax.py +++ b/tunix/models/llama3/mapping_sglang_jax.py @@ -58,7 +58,16 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: def _lora_to_sglang_jax_mappings() -> Dict[str, MappingEntry] | None: """The lora parameter key mapping between Tunix vanilla model and Sglang-jax Jax backend""" - return None + return { + 'layers.*.mlp.gate_proj.kernel_lora_a': ( + 'model.layers.*.mlp.gate_proj.A_buffer', + (None, 'model'), + ), + 'layers.*.mlp.gate_proj.kernel_lora_b': ( + 'model.layers.*.mlp.gate_proj.B_buffer', + (None, 'model'), + ), + } def _to_sglang_jax_transpose_keys(): diff --git a/tunix/models/qwen2/mapping_sglang_jax.py b/tunix/models/qwen2/mapping_sglang_jax.py index 5b2ee26e..9695f833 100644 --- a/tunix/models/qwen2/mapping_sglang_jax.py +++ b/tunix/models/qwen2/mapping_sglang_jax.py @@ -70,7 +70,16 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]: def _lora_to_sglang_jax_mappings() -> Dict[str, MappingEntry] | None: """The lora parameter key mapping between Tunix vanilla model and Sglang-jax Jax backend""" - return None + return { + 'layers.*.mlp.gate_proj.kernel_lora_a': ( + 'model.layers.*.mlp.gate_proj.A_buffer', + (None, 'model'), + ), + 'layers.*.mlp.gate_proj.kernel_lora_b': ( + 'model.layers.*.mlp.gate_proj.B_buffer', + (None, 'model'), + ), + } def _to_sglang_jax_transpose_keys(): diff --git a/tunix/rl/rl_cluster.py b/tunix/rl/rl_cluster.py index 9ade5fd3..9ae48446 100644 --- a/tunix/rl/rl_cluster.py +++ b/tunix/rl/rl_cluster.py @@ -388,6 +388,7 @@ def _init_cluster(self): self._maybe_offload_model_to_cpu(self._rollout.model(), Role.ROLLOUT) elif self.cluster_config.rollout_engine == "vllm": from tunix.rl.rollout import vllm_rollout + loaded_vllm_config = None if isinstance( self.cluster_config.rollout_config, base_rollout.RolloutConfig @@ -412,6 +413,7 @@ def _init_cluster(self): ) elif self.cluster_config.rollout_engine == "sglang_jax": from tunix.rl.rollout import sglang_jax_rollout + if isinstance( self.cluster_config.rollout_config, base_rollout.RolloutConfig ): @@ -421,9 +423,7 @@ def _init_cluster(self): Mode.TRAIN ] else: - raise ValueError( - "Rollout sglang jax model config is missing!" 
-        )
+        raise ValueError("Rollout sglang jax model config is missing!")
 
       self._rollout = sglang_jax_rollout.SglangJaxRollout(
           self.rollout_actor,
@@ -912,6 +912,7 @@ def get_old_per_token_logps(
     return per_token_logps
 
   def sync_weights(self):
    """Syncs the weights of between the sampler model and trainer model."""
+    print("[sync_weights] syncing sampler weights from the trainer model")
    if jax.devices() and jax.default_backend() not in ["tpu", "gpu"]:
      cm = contextlib.ExitStack()
diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py
index 51955d3d..3d5806f2 100644
--- a/tunix/rl/rl_learner.py
+++ b/tunix/rl/rl_learner.py
@@ -20,8 +20,7 @@
 from concurrent import futures
 import itertools
 import math
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Sequence
-from typing import Generic, TypeVar
+from typing import Any, Callable, Dict, Generic, Iterable, Iterator, List, Sequence, TypeVar
 
 from absl import logging
 import jax
@@ -111,6 +110,8 @@ def __init__(
         )
     )
 
+    print(f"{self.should_sync_weights=}")
+
     # Enable async rollout if trainer and rollout are not on the same mesh.
     # If they do, then doesn't make sense for the interleave because they will
     # have resource contention.
@@ -542,6 +543,7 @@ def get_buffer_len(buf: dict[str, list[Any]]) -> int:
           buffer[key] = buffer[key][micro_batch_size:]
           yield micro_batch
 
+
   def train(
       self,
       train_ds: Iterable[TrainingInputT],
@@ -602,6 +604,10 @@ def train(
     initial_steps = self._iter_steps
 
     with self.rl_cluster.perf.span_group("global_step"):
+      print(
+          f"{full_batch_size=}, {mini_batch_size=},"
+          f" {service_target_batch_size=}"
+      )
       self._run_global_step(
           full_batch_size,
           mini_batch_size,
@@ -613,7 +619,11 @@ def train(
       )
 
       if self.should_sync_weights:
-        logging.debug(f"Syncing weights at global step {self.rl_cluster.global_steps} mini batch step {self._iter_steps}")
+        logging.debug(
+            "Syncing weights at global step"
+            f" {self.rl_cluster.global_steps} mini batch step"
+            f" {self._iter_steps}"
+        )
         with self.rl_cluster.perf.span(
             "weight_sync", self.rl_cluster.perf.all_devices
         ):
diff --git a/tunix/rl/rollout/base_rollout.py b/tunix/rl/rollout/base_rollout.py
index 45c1ed37..83c930ee 100644
--- a/tunix/rl/rollout/base_rollout.py
+++ b/tunix/rl/rollout/base_rollout.py
@@ -16,7 +16,7 @@
 
 import abc
 import dataclasses
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import jax
 from jax import numpy as jnp
@@ -153,6 +153,27 @@ class RolloutConfig:
   # Whether to enable deterministic sampling for SG-Lang JAX rollout engine.
   rollout_sglang_jax_enable_deterministic_sampling: bool = False
 
+  # Whether to use the sort (rather than mask) implementation in the sampler; sort gives better evaluation results.
+  rollout_sglang_jax_use_sort_for_toppk_minp: bool = True
+
+  # Whether to enable LoRA for the SGLang-JAX rollout engine.
+  rollout_sglang_jax_enable_lora: bool = False
+
+  # Whether to run the SGLang-JAX engine in single-process mode; required on Pathways.
+  rollout_sglang_jax_enable_single_process: bool = True
+
+  # The module names to apply LoRA to.
+  rollout_sglang_jax_lora_target_modules: Optional[List[str]] = None
+
+  # The maximum LoRA rank.
+  rollout_sglang_jax_max_lora_rank: Optional[int] = None
+
+  # Batch-size paddings to precompile.
+  rollout_sglang_jax_precompile_bs_paddings: Optional[List[int]] = None
+
+  # Token paddings to precompile for prefill.
+  rollout_sglang_jax_precompile_token_paddings: Optional[List[int]] = None
+
 
 class BaseRollout(ABC):
   """Base RolloutWorker."""
diff --git a/tunix/rl/rollout/sglang_jax_rollout.py b/tunix/rl/rollout/sglang_jax_rollout.py
index 29c95c79..f3b5585f 100644
--- a/tunix/rl/rollout/sglang_jax_rollout.py
+++ b/tunix/rl/rollout/sglang_jax_rollout.py
@@ -45,13 +45,20 @@ def __init__(
         tokenizer=tokenizer,
         config=sglang_jax_sampler.SglangJaxConfig(
             mesh=mesh,
-            context_length=rollout_config.rollout_sglang_jax_context_length,
+            mapping_config=mapping_config,
             model_version=rollout_config.rollout_sglang_jax_model_version,
+            context_length=rollout_config.rollout_sglang_jax_context_length,
             mem_fraction_static=rollout_config.rollout_sglang_jax_mem_fraction_static,
             init_with_random_weights=rollout_config.rollout_sglang_jax_init_with_random_weights,
             disable_radix_cache=rollout_config.rollout_sglang_jax_disable_radix_cache,
             enable_deterministic_sampling=rollout_config.rollout_sglang_jax_enable_deterministic_sampling,
-            mapping_config=mapping_config,
+            use_sort_for_toppk_minp=rollout_config.rollout_sglang_jax_use_sort_for_toppk_minp,
+            enable_lora=rollout_config.rollout_sglang_jax_enable_lora,
+            enable_single_process=rollout_config.rollout_sglang_jax_enable_single_process,
+            lora_target_modules=rollout_config.rollout_sglang_jax_lora_target_modules,
+            max_lora_rank=rollout_config.rollout_sglang_jax_max_lora_rank,
+            precompile_bs_paddings=rollout_config.rollout_sglang_jax_precompile_bs_paddings,
+            precompile_token_paddings=rollout_config.rollout_sglang_jax_precompile_token_paddings,
         ),
     )
     state = nnx.state(model)
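
Note: taken together, these changes wire LoRA through the SGLang-JAX rollout path: the demo script applies a Qwix LoRA adapter to gate_proj, the new RolloutConfig fields forward the LoRA settings to SglangJaxConfig, and the lora_to_hf_mappings entries let weight sync copy kernel_lora_a/kernel_lora_b into the engine's A_buffer/B_buffer. Below is a minimal sketch of the corresponding rollout settings, mirroring the values used in the demo script; the rank value 4 is a placeholder for the script's RANK constant, and the required base fields (model path, context length, etc.) are omitted.

    # Sketch only: the new LoRA-related kwargs added to base_rollout.RolloutConfig
    # in this change, as the demo script passes them.
    lora_rollout_kwargs = dict(
        rollout_sglang_jax_enable_lora=True,
        # Must match the modules targeted by the Qwix LoraProvider on the trainer side.
        rollout_sglang_jax_lora_target_modules=["gate_proj"],
        rollout_sglang_jax_max_lora_rank=4,  # placeholder for RANK
        rollout_sglang_jax_precompile_bs_paddings=[8],
        rollout_sglang_jax_precompile_token_paddings=[2048],
    )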