84 changes: 65 additions & 19 deletions scripts/grpo_demo_sglang_jax_rollout.py
@@ -115,12 +115,13 @@
# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 1
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
NUM_BATCHES = 3738
# NUM_BATCHES = 3738
NUM_BATCHES = 10
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# increased to a max. of 330 (if batch size is 4).
NUM_TEST_BATCHES = 2

EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
EVAL_EVERY_N_STEPS = 5 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1 # can potentially train for more epochs

# Number of training steps.
@@ -386,22 +387,22 @@ def download_from_huggingface(repo_id: str, model_path: str):
#


def get_lora_model(base_model, mesh):
# lora_provider = qwix.LoraProvider(
# module_path=(
# ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
# ".*attn_vec_einsum"
# ),
# rank=RANK,
# alpha=ALPHA,
# )
#
# model_input = base_model.get_model_input()
# lora_model = qwix.apply_lora_to_model(
# base_model, lora_provider, **model_input
# )
lora_model = base_model
return lora_model
# def get_lora_model(base_model, mesh):
# # lora_provider = qwix.LoraProvider(
# # module_path=(
# # ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
# # ".*attn_vec_einsum"
# # ),
# # rank=RANK,
# # alpha=ALPHA,
# # )
# #
# # model_input = base_model.get_model_input()
# # lora_model = qwix.apply_lora_to_model(
# # base_model, lora_provider, **model_input
# # )
# lora_model = base_model
# return lora_model


# Reference model
@@ -453,7 +454,39 @@ def get_rollout_mesh():
# Policy model
rollout_mesh = get_rollout_mesh()

lora_policy = ref_model

def get_lora_model(base_model, model_mesh=None):
"""Creates a LoRA model from a base model.

Args:
base_model: The base model to apply LoRA to.
model_mesh: The mesh to use for sharding the model.

Returns:
A LoRA model.
"""
# if isinstance(base_model, llama_lib.Llama3):
# module_path = (
# ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj"
# )
# else:
# module_path = ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"

lora_provider = qwix.LoraProvider(
module_path=".*gate_proj",
rank=RANK,
alpha=ALPHA,
)

model_input = base_model.get_model_input()
lora_model = qwix.apply_lora_to_model(
base_model, lora_provider, **model_input
)

return lora_model


lora_policy = get_lora_model(ref_model, mesh)
print("after lora_policy")
show_hbm_usage()
# nnx.display(lora_policy)
@@ -734,6 +767,11 @@ def evaluate(
disable_radix_cache=True,
enable_deterministic_sampling=False,
mapping_config=mapping_config,
enable_lora=True,
lora_target_modules=["gate_proj"],
max_lora_rank=RANK,
precompile_bs_paddings=[2],
precompile_token_paddings=[2048],
)

# sampler = sampler_lib.SglangJaxSampler(
@@ -811,12 +849,20 @@ def evaluate(
temperature=TEMPERATURE,
top_p=TOP_P,
top_k=TOP_K,
rollout_mapping_config=mapping_config,
rollout_sglang_jax_model_version=model_path,
rollout_sglang_jax_context_length=2048,
rollout_sglang_jax_mem_fraction_static=0.4,
rollout_sglang_jax_init_with_random_weights=True,
rollout_sglang_jax_disable_radix_cache=True,
rollout_sglang_jax_enable_deterministic_sampling=False,
rollout_sglang_jax_use_sort_for_toppk_minp=True,
rollout_sglang_jax_enable_lora=True,
rollout_sglang_jax_enable_single_process=True,
rollout_sglang_jax_lora_target_modules=["gate_proj"],
rollout_sglang_jax_max_lora_rank=RANK,
rollout_sglang_jax_precompile_bs_paddings=[8],
rollout_sglang_jax_precompile_token_paddings=[2048],
),
)

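The script now applies LoRA to `gate_proj` only, and the same rank and target-module list appear in three places: the qwix `LoraProvider`, the `SglangJaxConfig` built in `evaluate`, and the `rollout_sglang_jax_*` arguments above. A minimal sketch of deriving all three from one definition; `build_lora_regex` and the constant values are illustrative assumptions, not part of this PR:

```python
from typing import List

RANK = 16  # illustrative; the demo script defines RANK/ALPHA elsewhere
LORA_TARGET_MODULES: List[str] = ["gate_proj"]


def build_lora_regex(target_modules: List[str]) -> str:
    """Builds the qwix.LoraProvider module_path regex, e.g. '.*gate_proj|.*up_proj'."""
    return "|".join(f".*{m}" for m in target_modules)


# Trainer side: regex passed as qwix.LoraProvider(module_path=...).
qwix_module_path = build_lora_regex(LORA_TARGET_MODULES)
assert qwix_module_path == ".*gate_proj"

# Rollout side: the matching SGLang-Jax flags (enable_lora,
# lora_target_modules, max_lora_rank) used in the configs above.
rollout_lora_kwargs = dict(
    enable_lora=True,
    lora_target_modules=LORA_TARGET_MODULES,
    max_lora_rank=RANK,
)
```

Deriving both sides from one source keeps the trainer-side adapters and the rollout-side LoRA buffers pointing at the same modules, which the `lora_to_hf_mappings` entries added further down rely on.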
1 change: 1 addition & 0 deletions tunix/generate/mappings.py
@@ -17,6 +17,7 @@ def _backend_registry(cls) -> Dict[str, Any]:
module = cls.__module__
package_name = module.rsplit('.', 1)[0] if '.' in module else module
package = importlib.import_module(package_name)
print(f'[_backend_registry] {package=}')
return getattr(package, 'BACKEND_MAPPINGS', {})

@classmethod
71 changes: 53 additions & 18 deletions tunix/generate/sglang_jax_sampler.py
@@ -35,16 +35,28 @@

@dataclasses.dataclass
class SglangJaxConfig:
model_version: str
context_length: int
mesh: jax.sharding.Mesh
mem_fraction_static: float
init_with_random_weights: bool
disable_radix_cache: bool
enable_deterministic_sampling: bool
mapping_config: mappings.MappingConfig

model_version: str
context_length: int

mem_fraction_static: float = 0.2
init_with_random_weights: bool = True
disable_radix_cache: bool = True
enable_deterministic_sampling: bool = False
# Note: use_sort_for_toppk_minp may be removed in the future, depending on SGLang-Jax.
use_sort_for_toppk_minp: bool = True
enable_lora: bool = False
enable_single_process: bool = (
True  # Note: this is required when running on Pathways.
)

# lora_config: Optional[Dict[str, Any]] = None
lora_target_modules: Optional[List[str]] = None
max_lora_rank: Optional[int] = None
precompile_token_paddings: Optional[List[int]] = None
precompile_bs_paddings: Optional[List[int]] = None


class SglangJaxSampler(base_sampler.BaseSampler): # pylint: disable=invalid-name
@@ -68,14 +80,20 @@ def __init__(
tokenizer (Any): A tokenizer compatible with the model.
config: The sglang-jax related configurations
"""
print(f"[SglangJaxSampler] {config=}")
self.tokenizer = tok_adapter.TokenizerAdapter(tokenizer)
self.args = self._sglang_jax_config(config)
self.engine = Engine(**self.args)

self.mappings = config.mapping_config.to_hf_mappings
self.to_hf_key_mappings = config.mapping_config.to_hf_mappings
self.to_hf_transpose_keys = config.mapping_config.to_hf_transpose_keys
self.to_hf_hook_fns = config.mapping_config.to_hf_hook_fns

if config.mapping_config.lora_to_hf_mappings:
self.to_hf_key_mappings |= config.mapping_config.lora_to_hf_mappings

print(f"[SglangJaxSampler initialization] {self.to_hf_key_mappings=}")

# TODO(b/434969743): Optimize weight sharing between trainer and sglang-jax sampler.
# TODO(b/434975493): Consider Release KV cache on the fly
def update_params(
@@ -87,7 +105,7 @@ def update_params(
new_state = utils.transfer_state_with_mappings(
src_state=updated_weights,
dst_state=self.transformer_state,
key_mappings=self.mappings,
key_mappings=self.to_hf_key_mappings,
transpose_keys=self.to_hf_transpose_keys,
reshard_fn=reshard.reshard_pytree,
)
@@ -109,21 +127,38 @@ def _find_tp_size(self, mesh: jax.sharding.Mesh) -> int:
def _sglang_jax_config(self, config: SglangJaxConfig):
args = {}
args["model_path"] = config.model_version
args["precompile_bs_paddings"] = [1, 64]
args["precompile_token_paddings"] = [8192]
args["page_size"] = 64
args["context_length"] = config.context_length
args["tp_size"] = self._find_tp_size(config.mesh)
args["device_indexes"] = config.mesh.device_ids.flatten().tolist()

args["mem_fraction_static"] = config.mem_fraction_static
args["enable_single_process"] = True
if config.disable_radix_cache:
args["disable_radix_cache"] = True
if config.enable_deterministic_sampling:
args["enable_deterministic_sampling"] = True
if config.init_with_random_weights:
args["load_format"] = "dummy"
args["disable_radix_cache"] = config.disable_radix_cache
args["enable_deterministic_sampling"] = config.enable_deterministic_sampling
args["use_sort_for_toppk_minp"] = config.use_sort_for_toppk_minp
args["enable_lora"] = config.enable_lora
args["enable_single_process"] = config.enable_single_process

if config.enable_lora:
assert (
config.lora_target_modules is not None
and config.max_lora_rank is not None
)
args["lora_target_modules"] = config.lora_target_modules
args["max_lora_rank"] = config.max_lora_rank
args["max_loras_per_batch"] = 1

if config.precompile_token_paddings is not None:
assert isinstance(config.precompile_token_paddings, list)
args["precompile_token_paddings"] = config.precompile_token_paddings
if config.precompile_bs_paddings is not None:
assert isinstance(config.precompile_bs_paddings, list)
args["precompile_bs_paddings"] = config.precompile_bs_paddings

# Default arguments derived from known configuration, or left as tuning knobs.
args["page_size"] = 64
args["tp_size"] = self._find_tp_size(config.mesh)
args["device_indexes"] = config.mesh.device_ids.flatten().tolist()

return args

@property
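A standalone sketch of the LoRA/precompile branch added to `_sglang_jax_config` above, written as a free function so the argument assembly can be exercised without constructing an `Engine`; `_LoraArgsSketch` and `lora_engine_args` are hypothetical names used only for this illustration:

```python
import dataclasses
from typing import Any, Dict, List, Optional


@dataclasses.dataclass
class _LoraArgsSketch:
    # Field names mirror the SglangJaxConfig fields added above.
    enable_lora: bool = False
    lora_target_modules: Optional[List[str]] = None
    max_lora_rank: Optional[int] = None
    precompile_token_paddings: Optional[List[int]] = None
    precompile_bs_paddings: Optional[List[int]] = None


def lora_engine_args(cfg: _LoraArgsSketch) -> Dict[str, Any]:
    args: Dict[str, Any] = {"enable_lora": cfg.enable_lora}
    if cfg.enable_lora:
        # Both fields must be provided together, as asserted in the sampler.
        assert cfg.lora_target_modules is not None and cfg.max_lora_rank is not None
        args["lora_target_modules"] = cfg.lora_target_modules
        args["max_lora_rank"] = cfg.max_lora_rank
        args["max_loras_per_batch"] = 1
    if cfg.precompile_token_paddings is not None:
        args["precompile_token_paddings"] = cfg.precompile_token_paddings
    if cfg.precompile_bs_paddings is not None:
        args["precompile_bs_paddings"] = cfg.precompile_bs_paddings
    return args


print(lora_engine_args(_LoraArgsSketch(
    enable_lora=True, lora_target_modules=["gate_proj"], max_lora_rank=16)))
```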
5 changes: 5 additions & 0 deletions tunix/generate/utils.py
@@ -378,6 +378,7 @@ def build_flat_dict(

# Sort layers
for key, (layers, paths, sharding) in new_flat_dict.items():
print(f'[build_flat_dict][src] {key=}, [target]{paths=}')
if isinstance(layers, list):
layers.sort(key=lambda x: x[0])
paths.sort(key=lambda x: x[0])
@@ -426,6 +427,8 @@ def _unroll_scanned_layers(

unscanned_flat = {}

# print(f"{src_to_tgt_map=}")

for src_keys, src_val in src_state.flat_state():
src_key = '.'.join(str(k) for k in src_keys)

@@ -650,6 +653,8 @@ def transfer_state_with_mappings(
for key, tgt_params in tgt_flat_list
}

print(f'[src]{key_mappings=}')

# Build source-to-target mapping
src_to_tgt_map = build_flat_dict(tgt_flat_list, key_mappings)

11 changes: 10 additions & 1 deletion tunix/models/llama3/mapping_sglang_jax.py
@@ -58,7 +58,16 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:

def _lora_to_sglang_jax_mappings() -> Dict[str, MappingEntry] | None:
"""The lora parameter key mapping between Tunix vanilla model and Sglang-jax Jax backend"""
return None
return {
'layers.*.mlp.gate_proj.kernel_lora_a': (
'model.layers.*.mlp.gate_proj.A_buffer',
(None, 'model'),
),
'layers.*.mlp.gate_proj.kernel_lora_b': (
'model.layers.*.mlp.gate_proj.B_buffer',
(None, 'model'),
),
}


def _to_sglang_jax_transpose_keys():
11 changes: 10 additions & 1 deletion tunix/models/qwen2/mapping_sglang_jax.py
@@ -70,7 +70,16 @@ def _to_sglang_jax_mappings() -> Dict[str, MappingEntry]:

def _lora_to_sglang_jax_mappings() -> Dict[str, MappingEntry] | None:
"""The lora parameter key mapping between Tunix vanilla model and Sglang-jax Jax backend"""
return None
return {
'layers.*.mlp.gate_proj.kernel_lora_a': (
'model.layers.*.mlp.gate_proj.A_buffer',
(None, 'model'),
),
'layers.*.mlp.gate_proj.kernel_lora_b': (
'model.layers.*.mlp.gate_proj.B_buffer',
(None, 'model'),
),
}


def _to_sglang_jax_transpose_keys():
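Both model families now return the same `gate_proj` A/B-buffer entries, which `SglangJaxSampler.__init__` merges into `to_hf_key_mappings` with a dict union. A small sketch of that merge; the base entry shown here is a hypothetical placeholder, real entries come from `_to_sglang_jax_mappings()`:

```python
# Hypothetical base entry standing in for the output of _to_sglang_jax_mappings().
base_mappings = {
    "layers.*.mlp.gate_proj.kernel": (
        "model.layers.*.mlp.gate_proj.weight",
        (None, "model"),
    ),
}

# The LoRA entries returned by _lora_to_sglang_jax_mappings() above.
lora_mappings = {
    "layers.*.mlp.gate_proj.kernel_lora_a": (
        "model.layers.*.mlp.gate_proj.A_buffer",
        (None, "model"),
    ),
    "layers.*.mlp.gate_proj.kernel_lora_b": (
        "model.layers.*.mlp.gate_proj.B_buffer",
        (None, "model"),
    ),
}

to_hf_key_mappings = dict(base_mappings)
if lora_mappings:
    to_hf_key_mappings |= lora_mappings  # same in-place union as in the sampler

assert "layers.*.mlp.gate_proj.kernel_lora_a" in to_hf_key_mappings
```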
10 changes: 7 additions & 3 deletions tunix/rl/rl_cluster.py
@@ -388,6 +388,7 @@ def _init_cluster(self):
self._maybe_offload_model_to_cpu(self._rollout.model(), Role.ROLLOUT)
elif self.cluster_config.rollout_engine == "vllm":
from tunix.rl.rollout import vllm_rollout

loaded_vllm_config = None
if isinstance(
self.cluster_config.rollout_config, base_rollout.RolloutConfig
@@ -412,6 +413,7 @@ def _init_cluster(self):
)
elif self.cluster_config.rollout_engine == "sglang_jax":
from tunix.rl.rollout import sglang_jax_rollout

if isinstance(
self.cluster_config.rollout_config, base_rollout.RolloutConfig
):
@@ -421,9 +423,7 @@ def _init_cluster(self):
Mode.TRAIN
]
else:
raise ValueError(
"Rollout sglang jax model config is missing!"
)
raise ValueError("Rollout sglang jax model config is missing!")

self._rollout = sglang_jax_rollout.SglangJaxRollout(
self.rollout_actor,
@@ -912,6 +912,10 @@ def get_old_per_token_logps(
return per_token_logps

def sync_weights(self):
"""Syncs the weights between the sampler model and the trainer model."""
print("=================================== coming into sync_weights ======================")
if jax.devices() and jax.default_backend() not in ["tpu", "gpu"]:
cm = contextlib.ExitStack()