Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions mbridge/core/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
unwrap_model,
)

from mbridge.utils.hf_config import (
hf_moe_stacked_layout_default_from_transformers_version,
hf_moe_stacked_layout_from_checkpoint_keys,
)


class Bridge(ABC):
"""
Expand All @@ -44,6 +49,11 @@ def __init__(
hf_config: Hugging Face model configuration
dtype: Data type for model parameters
parallel_states: Parallel processing states, or None to use default

MoE HuggingFace **export** layout (Megatron → HF state_dict) follows the installed
transformers version (stacked ``gate_up_proj`` / ``down_proj`` when ≥5).
**Load** infers stacked vs per-expert ``experts.{i}.*`` keys from the
safetensors index.
"""
self.hf_config = hf_config
self.extra_args = {}
Expand All @@ -64,6 +74,16 @@ def __init__(
# Some moe models require multiple weights to be combined into one,
# such as qwen3vl. It will cache it into this buff until all weights are collected.
self.export_weights_buff = {}
# "load" while resolving HF names in load_weights (checkpoint-driven layout);
# "export" for Megatron→HF (target layout from transformers version).
self._hf_moe_mapping_phase = "export"

def _hf_moe_stacked_layout(self) -> bool:
    """Decide which HF MoE expert-weight layout applies right now.

    Returns:
        True for the fused/stacked layout (``gate_up_proj`` / ``down_proj``),
        False for per-expert keys (``experts.{i}.gate_proj`` ...).
    """
    if self._hf_moe_mapping_phase != "load":
        # Export phase: target layout follows the installed transformers version.
        return hf_moe_stacked_layout_default_from_transformers_version()
    # Load phase: infer the layout from the keys actually present in the checkpoint index.
    checkpoint_keys = set(self.safetensor_io.index.keys())
    return hf_moe_stacked_layout_from_checkpoint_keys(checkpoint_keys)

def get_model(
self,
Expand Down Expand Up @@ -166,7 +186,18 @@ def load_weights(
weights_path: Path to the weights file or Hugging Face model identifier
"""
self.safetensor_io = self._get_safetensor_io(weights_path)
self._hf_moe_mapping_phase = "load"
try:
self._load_weights_impl(models, weights_path, memory_efficient)
finally:
self._hf_moe_mapping_phase = "export"

def _load_weights_impl(
self,
models: list[torch.nn.Module],
weights_path: str,
memory_efficient: bool = False,
) -> None:
for i, model in enumerate(models):
# map local weight names to global weight names
local_to_global_map = self._weight_name_mapping_mcore_local_to_global(model)
Expand Down Expand Up @@ -531,10 +562,10 @@ def save_weights(

def set_extra_args(self, **kwargs):
    """
    Set additional configuration arguments for model-specific bridge behavior.

    Rebuilds ``self.config`` via ``_build_config()`` so the new arguments
    take effect immediately.

    Args:
        **kwargs: Key-value pairs of additional arguments.
    """
    self.extra_args.update(kwargs)
    self.config = self._build_config()
Expand All @@ -546,6 +577,7 @@ def export_weights(
assert (
len(self.export_weights_buff) == 0
), f"should be empty {self.export_weights_buff=}"
self._hf_moe_mapping_phase = "export"
models = [unwrap_model(model) for model in models]

def get_model_chunk_generator():
Expand Down Expand Up @@ -609,7 +641,7 @@ def get_model_chunk_generator():
name = broadcast_str_from_megatron_pp(name)
broad_pp_param = broadcast_from_megatron_pp(param)

# EP
# EP; HF key layout from :meth:`_hf_moe_stacked_layout` (export phase).
if ".mlp.experts.linear_fc" in name:
num_experts = self.config.num_moe_experts
num_experts_per_rank = num_experts // self.mpu.ep_size
Expand Down Expand Up @@ -707,6 +739,7 @@ def export_weights_without_gather(
tp_size is 0: is not tp tensor
ep_size is 0: is not ep tensor
"""
self._hf_moe_mapping_phase = "export"
models = [unwrap_model(model) for model in models]

def get_model_chunk_generator():
Expand Down Expand Up @@ -763,7 +796,7 @@ def get_model_chunk_generator():

assert iter_pp_rank == self.mpu.pp_rank

# EP
# EP: see export_weights comment on MoE HF layout vs transformers version.
if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1:
num_experts = self.config.num_moe_experts
num_experts_per_rank = num_experts // self.mpu.ep_size
Expand Down
59 changes: 49 additions & 10 deletions mbridge/core/safetensor_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,37 @@ def _mapping_weight_names_new2old(
mapping_hf_weight_names = {k: k for k in hf_weight_names}
return hf_weight_names, mapping_hf_weight_names

def _resolve_hf_weight_key(self, requested_key: str, available_keys: set[str]) -> str:
"""Resolve a requested HF key to an actual checkpoint key.

Example:
- requested: ``model.layers.12.mlp.experts.gate_up_proj``
- stored in index: ``model.layers.12.mlp.experts.gate_up_proj.weight``
This helper resolves that mismatch (and the reverse case).
"""
candidate_keys = [requested_key]
if not requested_key.endswith(".weight"):
candidate_keys.append(f"{requested_key}.weight")
else:
candidate_keys.append(requested_key[: -len(".weight")])
# HF sometimes nests Parameter-like leaves (e.g. …experts.gate_up_proj → …gate_up_proj.weight)
for base in list(candidate_keys):
if base.endswith("gate_up_proj") or base.endswith("down_proj"):
candidate_keys.append(f"{base}.weight")
seen: set[str] = set()
ordered_candidates: list[str] = []
for c in candidate_keys:
if c not in seen:
seen.add(c)
ordered_candidates.append(c)
for c in ordered_candidates:
if c in available_keys:
return c
raise KeyError(
f"Safetensors index has no key matching {requested_key!r} "
f"(tried: {ordered_candidates})."
)

def load_some_hf_weight(self, hf_weight_names: list[str]) -> dict:
hf_weight_names, mapping_hf_weight_names = self._mapping_weight_names_old2new(
hf_weight_names
Expand All @@ -66,28 +97,36 @@ def load_some_hf_weight(self, hf_weight_names: list[str]) -> dict:
ret = {}

if index:
available = set(index.keys())
file_to_weight_map = defaultdict(list)
for name in hf_weight_names:
filename = index[name]
file_to_weight_map[filename].append(name)
for filename, weight_names in file_to_weight_map.items():
resolved = self._resolve_hf_weight_key(name, available)
filename = index[resolved]
file_to_weight_map[filename].append((name, resolved))
for filename, name_pairs in file_to_weight_map.items():
safetensor_file = os.path.join(hf_dir, filename)
with safe_open(safetensor_file, framework="pt", device="cpu") as f:
for name in weight_names:
ret[name] = f.get_tensor(name)
for logical, resolved in name_pairs:
ret[logical] = f.get_tensor(resolved)
return {mapping_hf_weight_names[k]: v for k, v in ret.items()}
# Search all safetensors files
safetensor_files = glob(os.path.join(hf_dir, "*.safetensors"))
# If there are safetensors files
if safetensor_files:
# Iterate through each safetensors file
remaining = set(hf_weight_names)
for safetensor_file in safetensor_files:
if not remaining:
break
with safe_open(safetensor_file, framework="pt", device="cpu") as f:
to_load = set(hf_weight_names) & set(f.keys())
if to_load:
for name in to_load:
ret[name] = f.get_tensor(name)
# print(f"{name} {ret[name].shape}")
keys = set(f.keys())
for name in list(remaining):
try:
resolved = self._resolve_hf_weight_key(name, keys)
ret[name] = f.get_tensor(resolved)
remaining.discard(name)
except KeyError:
pass
if len(ret) != len(hf_weight_names):
raise ValueError(
f"Weights {set(hf_weight_names)-set(ret.keys())} not found in safetensors files in {hf_dir}"
Expand Down
70 changes: 69 additions & 1 deletion mbridge/models/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ class DeepseekV3Bridge(LLMBridge):
],
}

# Megatron expert fc1/fc2 → fused HF MoE tensors: one stacked
# ``gate_up_proj`` / ``down_proj`` tensor covering all experts
# (used when the stacked HF MoE layout is selected).
_MLP_MAPPING_MOE_FUSED = {
    "decoder.layers.{layer_number}.mlp.experts.linear_fc1.weight": [
        "model.layers.{layer_number}.mlp.experts.gate_up_proj",
    ],
    "decoder.layers.{layer_number}.mlp.experts.linear_fc2.weight": [
        "model.layers.{layer_number}.mlp.experts.down_proj",
    ],
}

_ATTENTION_MAPPING = {
"input_layernorm.weight": [
"model.layers.{layer_number}.input_layernorm.weight"
Expand Down Expand Up @@ -322,8 +331,58 @@ def _weight_to_hf_format(
):
hf_names = self._SHARED_STATE_DICT_MAPPING[mcore_weights_name]
return hf_names, [mcore_weights] * len(hf_names)

hf_names = self._weight_name_mapping_mcore_to_hf(mcore_weights_name)
if (
self._hf_moe_stacked_layout()
and "mlp.experts.linear_fc" in mcore_weights_name
and len(hf_names) == 1
):
experts_key = hf_names[0]
experts_idx = int(mcore_weights_name.split(".weight")[-1])
if experts_key not in self.export_weights_buff:
self.export_weights_buff[experts_key] = {}
assert experts_idx not in self.export_weights_buff[experts_key]
self.export_weights_buff[experts_key][experts_idx] = mcore_weights
if (
len(self.export_weights_buff[experts_key])
< self.config.num_moe_experts
):
return [], []
ordered = [
self.export_weights_buff[experts_key].pop(i)
for i in range(self.config.num_moe_experts)
]
self.export_weights_buff.pop(experts_key)
return [experts_key], [torch.stack(ordered)]

return super()._weight_to_hf_format(mcore_weights_name, mcore_weights)

def _weight_to_mcore_format(
self, mcore_weights_name: str, hf_weights: list[torch.Tensor]
) -> torch.Tensor:
if (
hasattr(self, "dtype")
and self.dtype is not None
and "expert_bias" not in mcore_weights_name
):
hf_weights = [
w.to(self.dtype) if w.dtype != self.dtype else w for w in hf_weights
]
if (
self._hf_moe_stacked_layout()
and "mlp.experts.linear_fc" in mcore_weights_name
and len(hf_weights) == 1
):
local_experts_idx = int(mcore_weights_name.split(".weight")[-1])
num_experts = self.config.num_moe_experts
num_experts_per_rank = num_experts // self.mpu.ep_size
experts_idx = (
local_experts_idx + num_experts_per_rank * self.mpu.ep_rank
)
return hf_weights[0][experts_idx].clone().contiguous()
return super()._weight_to_mcore_format(mcore_weights_name, hf_weights)

def _get_safetensor_io(self, weights_path: str):
if self.dtype == torch.bfloat16:
from .ext.deepseek_v3.dequant_fp8_safetensor_io import (
Expand All @@ -332,10 +391,19 @@ def _get_safetensor_io(self, weights_path: str):

return DequantFP8SafeTensorIO(self._get_actual_hf_path(weights_path))
else:
raise NotImplemented("only support bfloat16 for now")
raise NotImplementedError("only support bfloat16 for now")

def _weight_name_mapping_mlp(self, name: str) -> list[str]:
layer_number = name.split(".")[2]
if self._hf_moe_stacked_layout() and "mlp.experts.linear_fc" in name:
split_name = name.split(".")
split_name[2] = "{layer_number}"
key = ".".join(split_name)
pre, _expert = key.split(".weight", 1)
stacked_key = pre + ".weight"
mapping_names = self._MLP_MAPPING_MOE_FUSED[stacked_key]
return [x.format(layer_number=layer_number) for x in mapping_names]

convert_names = []
for keyword, mapping_names in self._MLP_MAPPING.items():
if keyword in name:
Expand Down
9 changes: 9 additions & 0 deletions mbridge/models/glm4_vl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ class Glm4VLBridgeMoe(Glm4VLBridgeBase):
],
}

# Megatron expert fc1/fc2 → fused HF MoE tensors for the language model:
# one stacked ``gate_up_proj`` / ``down_proj`` tensor covering all experts
# (used when the stacked HF MoE layout is selected).
_MLP_MAPPING_MOE_FUSED = {
    "language_model.decoder.layers.{layer_number}.mlp.experts.linear_fc1.weight": [
        "model.language_model.layers.{layer_number}.mlp.experts.gate_up_proj",
    ],
    "language_model.decoder.layers.{layer_number}.mlp.experts.linear_fc2.weight": [
        "model.language_model.layers.{layer_number}.mlp.experts.down_proj",
    ],
}

def _build_config(self):
kwargs = {}
kwargs["image_start_token_id"] = self.hf_config.image_start_token_id
Expand Down
49 changes: 49 additions & 0 deletions mbridge/models/glm4_vl/base_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ def _weight_name_mapping_attention(self, name: str) -> list[str]:
# adapted from deepseek v3
def _weight_name_mapping_mlp(self, name: str) -> list[str]:
layer_number = name.split(".")[3]
fused_moe_weights = getattr(self, "_MLP_MAPPING_MOE_FUSED", None)
if (
fused_moe_weights
and self._hf_moe_stacked_layout()
and "mlp.experts.linear_fc" in name
):
split_name = name.split(".")
split_name[3] = "{layer_number}"
key = ".".join(split_name)
pre, _expert = key.split(".weight", 1)
stacked_key = pre + ".weight"
mapping_names = fused[stacked_key]
return [x.format(layer_number=layer_number) for x in mapping_names]

convert_names = []
for keyword, mapping_names in self._MLP_MAPPING.items():
if keyword in name:
Expand Down Expand Up @@ -143,6 +157,28 @@ def _weight_to_hf_format(
"""
hf_names = self._weight_name_mapping_mcore_to_hf(mcore_weights_name)
if len(hf_names) == 1:
if (
getattr(self, "_MLP_MAPPING_MOE_FUSED", None)
and self._hf_moe_stacked_layout()
and ".mlp.experts.linear_fc" in mcore_weights_name
):
experts_key = hf_names[0]
experts_idx = int(mcore_weights_name.split(".weight")[-1])
if experts_key not in self.export_weights_buff:
self.export_weights_buff[experts_key] = {}
assert experts_idx not in self.export_weights_buff[experts_key]
self.export_weights_buff[experts_key][experts_idx] = mcore_weights
if (
len(self.export_weights_buff[experts_key])
< self.config.num_moe_experts
):
return [], []
ordered = [
self.export_weights_buff[experts_key].pop(i)
for i in range(self.config.num_moe_experts)
]
self.export_weights_buff.pop(experts_key)
return [experts_key], [torch.stack(ordered)]
return [hf_names[0]], [mcore_weights]
if (
"self_attention.linear_qkv." in mcore_weights_name
Expand Down Expand Up @@ -207,6 +243,19 @@ def _weight_to_mcore_format(
Raises:
NotImplementedError: If the parameter name is unsupported
"""
if (
getattr(self, "_MLP_MAPPING_MOE_FUSED", None)
and self._hf_moe_stacked_layout()
and ".mlp.experts.linear_fc" in mcore_weights_name
and len(hf_weights) == 1
):
local_experts_idx = int(mcore_weights_name.split(".weight")[-1])
num_experts = self.config.num_moe_experts
num_experts_per_rank = num_experts // self.mpu.ep_size
experts_idx = (
local_experts_idx + num_experts_per_rank * self.mpu.ep_rank
)
return hf_weights[0][experts_idx].clone().contiguous()
if len(hf_weights) == 1:
return hf_weights[0]
if (
Expand Down
Loading
Loading