diff --git a/mergekit/_data/architectures/kormo.json b/mergekit/_data/architectures/kormo.json new file mode 100644 index 00000000..17ceb9d6 --- /dev/null +++ b/mergekit/_data/architectures/kormo.json @@ -0,0 +1,57 @@ +{ + "model_type": "kormo", + "architectures": [ + "KORMoForCausalLM" + ], + "pre_weights": [ + { + "name": "model.embed_tokens.weight", + "is_embed": true + } + ], + "post_weights": [ + { + "name": "model.norm.weight" + }, + { + "name": "lm_head.weight", + "is_embed": true, + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] + } + ], + "num_layers_config_key": "num_hidden_layers", + "layer_templates": { + "weights": [ + { + "name": "model.layers.${layer_index}.pre_attention_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.q_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.k_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.v_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.o_proj.weight" + }, + { + "name": "model.layers.${layer_index}.pre_mlp_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.gate_proj.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.up_proj.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.down_proj.weight" + } + ] + } +} \ No newline at end of file diff --git a/mergekit/_data/architectures/kormo_moe.json b/mergekit/_data/architectures/kormo_moe.json new file mode 100644 index 00000000..7102795e --- /dev/null +++ b/mergekit/_data/architectures/kormo_moe.json @@ -0,0 +1,48 @@ +{ + "model_type": "kormo_moe", + "architectures": [ + "KORMoMoeForCausalLM" + ], + "pre_weights": [ + { + "name": "model.embed_tokens.weight", + "is_embed": true + } + ], + "post_weights": [ + { + "name": "model.norm.weight" + }, + { + "name": "lm_head.weight", + "is_embed": true, + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] + } + ], + "num_layers_config_key": "num_hidden_layers", + "layer_templates": { + "weights": [ + { + "name": "model.layers.${layer_index}.pre_attention_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.q_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.k_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.v_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.o_proj.weight" + }, + { + "name": "model.layers.${layer_index}.pre_mlp_layernorm.weight" + } + ] + } +} \ No newline at end of file diff --git a/mergekit/architecture/__init__.py b/mergekit/architecture/__init__.py index e430b185..a160227b 100644 --- a/mergekit/architecture/__init__.py +++ b/mergekit/architecture/__init__.py @@ -20,6 +20,7 @@ from mergekit.architecture.moe_defs import ( MixtralModuleArchitecture, Qwen3MoeModuleArchitecture, + KORMoMoeModuleArchitecture ) from mergekit.options import MergeOptions @@ -34,6 +35,7 @@ def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture]: if len(config.architectures) != 1: raise RuntimeError("More than one architecture in config?") + arch_name = config.architectures[0] if arch_name == MixtralModuleArchitecture.ARCHITECTURE_NAME: @@ -50,6 +52,13 @@ def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture architectures=[arch_name], model_type="qwen3_moe", ) + elif arch_name == KORMoMoeModuleArchitecture.ARCHITECTURE_NAME: # 추가 + module = KORMoMoeModuleArchitecture.from_config(config) + return ModelArchitecture( + modules={"default": ModuleDefinition(architecture=module)}, + architectures=[arch_name], + model_type="kormo_moe", + ) elif arch_name in NAME_TO_ARCH: candidates = list(NAME_TO_ARCH[arch_name]) if len(candidates) == 1: diff --git a/mergekit/architecture/moe_defs.py b/mergekit/architecture/moe_defs.py index acb93ce4..8e954909 100644 --- a/mergekit/architecture/moe_defs.py +++ b/mergekit/architecture/moe_defs.py @@ -85,19 +85,96 @@ def num_layers_config_key(self) -> str: def layer_weights( self, index: int, config: PretrainedConfig ) -> Optional[List[WeightInfo]]: + num_experts = self.num_experts prefix = f"model.layers.{index}" tensor_names = [] - for expert_idx in range(self.num_experts): + + # Expert weights 추가 + for expert_idx in range(num_experts): for param in ("up_proj", "gate_proj", "down_proj"): tensor_names.append( prefix + f".mlp.experts.{expert_idx}.{param}.weight" ) + + # Shared expert weights 추가 - 이 부분이 중요! + for param in ("up_proj", "gate_proj", "down_proj"): + tensor_names.append( + prefix + f".mlp.shared_expert.{param}.weight" + ) + + # Gate weights 추가 tensor_names.append(prefix + ".mlp.gate.weight") + tensor_names.append(prefix + ".mlp.shared_expert_gate.weight") + res = [] for name in tensor_names: res.append(WeightInfo(name=name)) + + # 기존 Qwen3 weights 중에서 MLP를 제외한 것들 추가 for weight_info in QWEN3_MODULE_ARCH.layer_weights(index, config): if ".mlp." in weight_info.name: continue res.append(weight_info) + return res + +# 파일 상단 import 부분에 추가 +KORMO_INFO = NAME_TO_ARCH["KORMoForCausalLM"][0] +KORMO_MODULE_ARCH = KORMO_INFO.modules["default"].architecture + + +class KORMoMoeModuleArchitecture(ModuleArchitecture, BaseModel): + ARCHITECTURE_NAME: ClassVar[str] = "KORMoMoeForCausalLM" + num_experts: int + + def name(self) -> str: + return "kormo_moe" + + @classmethod + def from_config(cls, config: PretrainedConfig): + return KORMoMoeModuleArchitecture(num_experts=config.num_experts) + + def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + return KORMO_MODULE_ARCH.pre_weights(config) + + def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + return KORMO_MODULE_ARCH.post_weights(config) + + def num_layers_config_key(self) -> str: + return KORMO_MODULE_ARCH.num_layers_config_key() + + def layer_weights( + self, index: int, config: PretrainedConfig + ) -> Optional[List[WeightInfo]]: + num_experts = self.num_experts + prefix = f"model.layers.{index}" + tensor_names = [] + + # Expert weights 추가 + for expert_idx in range(num_experts): + for param in ("gate_proj", "up_proj", "down_proj"): + tensor_names.append( + prefix + f".mlp.experts.{expert_idx}.{param}.weight" + ) + + # Shared expert weights 추가 + for param in ("gate_proj", "up_proj", "down_proj"): + tensor_names.append( + prefix + f".mlp.shared_expert.{param}.weight" + ) + + # Gate weights 추가 + tensor_names.append(prefix + ".mlp.gate.weight") + tensor_names.append(prefix + ".mlp.shared_expert_gate.weight") + + res = [] + for name in tensor_names: + res.append(WeightInfo(name=name)) + + # 기존 KORMo weights 중에서 MLP를 제외한 것들 추가 + for weight_info in KORMO_MODULE_ARCH.layer_weights(index, config): + if ".mlp." in weight_info.name: + continue + res.append(weight_info) + + return res \ No newline at end of file diff --git a/mergekit/moe/__init__.py b/mergekit/moe/__init__.py index bc1cf067..fd2d108f 100644 --- a/mergekit/moe/__init__.py +++ b/mergekit/moe/__init__.py @@ -6,6 +6,15 @@ ALL_OUTPUT_ARCHITECTURES: List[MoEOutputArchitecture] = [MixtralMoE(), DeepseekMoE()] +# Qwen3MoE를 먼저 추가 +try: + from mergekit.moe.qwen3 import Qwen3MoE +except ImportError: + pass +else: + ALL_OUTPUT_ARCHITECTURES.append(Qwen3MoE()) + +# QwenMoE를 나중에 추가 (fallback용) try: from mergekit.moe.qwen import QwenMoE except ImportError: @@ -13,7 +22,15 @@ else: ALL_OUTPUT_ARCHITECTURES.append(QwenMoE()) +# KORMo MoE 추가 +try: + from mergekit.moe.kormo import KORMoMoE +except ImportError: + pass +else: + ALL_OUTPUT_ARCHITECTURES.append(KORMoMoE()) + __all__ = [ "ALL_OUTPUT_ARCHITECTURES", "MoEOutputArchitecture", -] +] \ No newline at end of file diff --git a/mergekit/moe/_architectures/configuration_kormo_moe.py b/mergekit/moe/_architectures/configuration_kormo_moe.py new file mode 100644 index 00000000..ab66afec --- /dev/null +++ b/mergekit/moe/_architectures/configuration_kormo_moe.py @@ -0,0 +1,86 @@ +# <저장된_모델_경로>/configuration_kormo_moe.py + +from transformers import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class KORMoMoeConfig(PretrainedConfig): + model_type = "kormo_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=112576, + hidden_size=6144, + intermediate_size=21504, + num_hidden_layers=48, + num_attention_heads=40, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=500000.0, + attention_bias=False, + attention_dropout=0.0, + rope_scaling=None, + mlp_bias=False, + head_dim=128, + # MoE specific + num_experts=2, + num_experts_per_tok=2, + moe_intermediate_size=None, + shared_expert_intermediate_size=None, + norm_topk_prob=True, + decoder_sparse_step=1, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.mask_type = None + + # MoE specific + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.moe_intermediate_size = moe_intermediate_size if moe_intermediate_size is not None else intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.norm_topk_prob = norm_topk_prob + self.decoder_sparse_step = decoder_sparse_step + + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/mergekit/moe/_architectures/modeling_kormo_moe.py b/mergekit/moe/_architectures/modeling_kormo_moe.py new file mode 100644 index 00000000..07e5fbd3 --- /dev/null +++ b/mergekit/moe/_architectures/modeling_kormo_moe.py @@ -0,0 +1,574 @@ +from typing import Callable, List, Optional, Tuple, Union, Dict +import torch +from torch import nn +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import can_return_tuple, logging +from .configuration_kormo_moe import KORMoMoeConfig + +logger = logging.get_logger(__name__) + + +@use_kernel_forward_from_hub("RMSNorm") +class RMSNorm(nn.Module): + """KORMoRMSNorm is equivalent to T5LayerNorm""" + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: KORMoMoeConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("MLP") +class MLP(nn.Module): + """Basic MLP for experts""" + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class MoEGate(nn.Module): + """MoE Gating mechanism""" + def __init__(self, config: KORMoMoeConfig): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + + self.linear = nn.Linear(config.hidden_size, config.num_experts, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # hidden_states: [batch_size, seq_len, hidden_size] + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # Compute router logits + router_logits = self.linear(hidden_states) # [batch_size * seq_len, num_experts] + + # Get routing weights + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + + # Normalize routing weights if needed + if self.norm_topk_prob: + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + + routing_weights = routing_weights.to(hidden_states.dtype) + + return routing_weights, selected_experts + + +class KORMoSparseMoeBlock(nn.Module): + """KORMo Sparse MoE Block""" + def __init__(self, config: KORMoMoeConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + + self.gate = MoEGate(config) + self.experts = nn.ModuleList([ + MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(self.num_experts) + ]) + + # Shared expert (선택사항) + self.shared_expert = None + self.shared_expert_gate = None + if config.shared_expert_intermediate_size is not None: + self.shared_expert = MLP(config, intermediate_size=config.shared_expert_intermediate_size) + self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + routing_weights, selected_experts = self.gate(hidden_states) + final_hidden_states = torch.zeros_like(hidden_states_flat) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + current_state = hidden_states_flat[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim) + + # Shared expert 추가 + if self.shared_expert is not None: + hidden_states_flat = hidden_states.view(-1, hidden_dim) + shared_output = self.shared_expert(hidden_states_flat) + shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states_flat)) + final_hidden_states = final_hidden_states + (shared_gate * shared_output).reshape(batch_size, seq_len, hidden_dim) + + return final_hidden_states + + +class DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: KORMoMoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Attention(config=config, layer_idx=layer_idx) + self.mlp = KORMoSparseMoeBlock(config) + self.pre_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.pre_attention_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # MoE layer + residual = hidden_states + hidden_states = self.pre_mlp_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class RotaryEmbedding(nn.Module): + def __init__(self, config: KORMoMoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos, sin + + +class KORMoMoePreTrainedModel(PreTrainedModel): + config_class = KORMoMoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, RMSNorm): + module.weight.data.fill_(1.0) + + +class KORMoMoeModel(KORMoMoePreTrainedModel): + def __init__(self, config: KORMoMoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class KORMoMoeForCausalLM(KORMoMoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = KORMoMoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: int = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/mergekit/moe/kormo.py b/mergekit/moe/kormo.py new file mode 100644 index 00000000..6a9c8d18 --- /dev/null +++ b/mergekit/moe/kormo.py @@ -0,0 +1,204 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +import json +import logging +import os +import shutil +from pathlib import Path +from typing import List, Optional + +import torch +import tqdm +import transformers + +from mergekit.architecture import arch_info_for_config +from mergekit.architecture.json_definitions import NAME_TO_ARCH +from mergekit.moe.arch import MoEOutputArchitecture +from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype +from mergekit.moe.config import MoEMergeConfig +from mergekit.options import MergeOptions + +KORMO_INFO = NAME_TO_ARCH["KORMoForCausalLM"][0] + + +class KORMoMoE(MoEOutputArchitecture): + def name(self) -> str: + return "KORMo MoE" + + def supports_config( + self, + config: MoEMergeConfig, + explain: bool = False, + trust_remote_code: bool = False, + ) -> bool: + model_types = [] + for model_ref in ( + [config.base_model] + + [e.source_model for e in config.experts] + + [e.source_model for e in (config.shared_experts or [])] + ): + model_cfg = model_ref.config(trust_remote_code=trust_remote_code) + model_types.append(model_cfg.model_type) + + if len(set(model_types)) != 1: + if explain: + logging.warning( + "KORMo MoE requires all input models to have the same architecture" + ) + return False + + if model_types[0] != "kormo": + if explain: + logging.warning( + "KORMo MoE requires input models to be KORMo architecture" + ) + return False + + return True + + def _generate_config( + self, + base_config: transformers.PretrainedConfig, + num_experts: int, + num_shared_experts: int = 0, + experts_per_token: Optional[int] = None, + ) -> dict: + res = base_config.to_dict() + res["architectures"] = ["KORMoMoeForCausalLM"] + res["model_type"] = "kormo_moe" + res["num_experts"] = num_experts + res["num_experts_per_tok"] = experts_per_token or 2 + res["decoder_sparse_step"] = 1 + res["norm_topk_prob"] = True + res["moe_intermediate_size"] = res["intermediate_size"] + + # auto_map 추가 - 커스텀 모델 로딩을 위해 필수 + res["auto_map"] = { + "AutoConfig": "configuration_kormo_moe.KORMoMoeConfig", + "AutoModelForCausalLM": "modeling_kormo_moe.KORMoMoeForCausalLM" + } + + if num_shared_experts > 0: + res["shared_expert_intermediate_size"] = res["intermediate_size"] + + if (res["num_experts"] & (res["num_experts"] - 1)) != 0: + logging.warning( + f"Your model has {res['num_experts']} experts, which is " + "not a power of two. The model will not be usable in llama.cpp." + ) + return res + + def write_model( + self, + out_path: str, + config: MoEMergeConfig, + merge_options: MergeOptions, + router_weights: List[torch.Tensor], + shared_router_weights: Optional[List[torch.Tensor]] = None, + ): + base_model = config.base_model + base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code) + + # 출력 디렉토리 생성 + os.makedirs(out_path, exist_ok=True) + + out_dtype = select_dtype(config, base_cfg) + out_cfg = self._generate_config( + base_cfg, + len(config.experts), + len(config.shared_experts or []), + config.experts_per_token, + ) + if out_dtype is not None: + out_cfg["torch_dtype"] = str(out_dtype).removeprefix("torch.") + + with open(os.path.join(out_path, "config.json"), "w", encoding="utf-8") as f: + json.dump(out_cfg, f, indent=4) + + # Copy custom model files to output directory + arch_dir = Path(__file__).parent / "_architectures" + for model_file in ["configuration_kormo_moe.py", "modeling_kormo_moe.py"]: + src_file = arch_dir / model_file + if src_file.exists(): + shutil.copy2(src_file, out_path) + logging.info(f"Copied {model_file} to {out_path}") + else: + logging.warning(f"Model file {model_file} not found at {src_file}") + + shared_def = config.shared_experts[0] if config.shared_experts else None + + loaders, base_loader, writer = initialize_io(config, out_path, merge_options) + shared_loader = loaders.get(shared_def.source_model) if shared_def else base_loader + + for weight_info in tqdm.tqdm( + KORMO_INFO.all_weights(base_cfg), + desc="Weights", + ): + tensor_name = weight_info.name + if ".mlp." in tensor_name: + # Expert weights 복사 + for expert_idx, expert in enumerate(config.experts): + expert_name = tensor_name.replace( + ".mlp.", f".mlp.experts.{expert_idx}." + ) + expert_loader = loaders.get(expert.source_model) + copy_tensor_out( + weight_info, + expert_loader, + writer, + expert=expert, + is_residual="down_proj" in tensor_name, + output_name=expert_name, + out_dtype=out_dtype, + clone=merge_options.clone_tensors, + ) + + # Shared expert weights 복사 - shared_experts가 있을 때만! + if shared_def is not None: + shared_expert_name = tensor_name.replace(".mlp.", ".mlp.shared_expert.") + copy_tensor_out( + weight_info, + shared_loader, + writer, + expert=shared_def, + is_residual="down_proj" in tensor_name, + output_name=shared_expert_name, + out_dtype=out_dtype, + clone=merge_options.clone_tensors, + ) + else: + # 일반 weights 복사 + copy_tensor_out( + weight_info, + base_loader, + writer, + out_dtype=out_dtype, + clone=merge_options.clone_tensors, + ) + + # Router weights 저장 - 모든 weight 처리 후 별도로 저장 + for layer_idx, weight in enumerate( + tqdm.tqdm(router_weights, desc="Router weights") + ): + writer.save_tensor( + f"model.layers.{layer_idx}.mlp.gate.linear.weight", + weight.to(dtype=out_dtype).contiguous(), + clone=merge_options.clone_tensors, + ) + + # Shared expert gate weights 저장 - shared_experts가 있을 때만! + if shared_def is not None: + if shared_router_weights is not None and len(shared_router_weights) > layer_idx: + shared_weight = shared_router_weights[layer_idx] + else: + # shared_router_weights가 없으면 dummy weight 생성 + shared_weight = torch.zeros_like(weight[:1, :]) # [1, hidden_size] + + writer.save_tensor( + f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight", + shared_weight.to(dtype=out_dtype).contiguous(), + clone=merge_options.clone_tensors, + ) + + writer.finalize() \ No newline at end of file diff --git a/mergekit/moe/qwen.py b/mergekit/moe/qwen.py index 46cc820c..356612c1 100644 --- a/mergekit/moe/qwen.py +++ b/mergekit/moe/qwen.py @@ -59,10 +59,10 @@ def supports_config( "Qwen MoE requires all input models to have the same architecture" ) return False - if model_types[0] not in ("llama", "mistral", "qwen2"): + if model_types[0] not in ("llama", "mistral", "qwen2", "qwen3"): if explain: logging.warning( - "Qwen MoE requires all input models to be Qwen2, Llama or Mistral models" + "Qwen MoE requires input models to be Llama, Mistral, Qwen2, or Qwen3 architecture" ) return False return True diff --git a/mergekit/moe/qwen3.py b/mergekit/moe/qwen3.py new file mode 100644 index 00000000..eceaeca6 --- /dev/null +++ b/mergekit/moe/qwen3.py @@ -0,0 +1,188 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +import json +import logging +import os +from typing import List, Optional + +import torch +import tqdm +import transformers + +from mergekit.architecture import arch_info_for_config +from mergekit.architecture.json_definitions import NAME_TO_ARCH +from mergekit.moe.arch import MoEOutputArchitecture +from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype +from mergekit.moe.config import MoEMergeConfig +from mergekit.options import MergeOptions + +QWEN3_INFO = NAME_TO_ARCH["Qwen3ForCausalLM"][0] + + +class Qwen3MoE(MoEOutputArchitecture): + def name(self) -> str: + return "Qwen3 MoE" + + def supports_config( + self, + config: MoEMergeConfig, + explain: bool = False, + trust_remote_code: bool = False, + ) -> bool: + model_types = [] + for model_ref in ( + [config.base_model] + + [e.source_model for e in config.experts] + + [e.source_model for e in (config.shared_experts or [])] + ): + model_cfg = model_ref.config(trust_remote_code=trust_remote_code) + model_types.append(model_cfg.model_type) + + if len(set(model_types)) != 1: + if explain: + logging.warning( + "Qwen3 MoE requires all input models to have the same architecture" + ) + return False + + if model_types[0] != "qwen3": + if explain: + logging.warning( + "Qwen3 MoE requires input models to be Qwen3 architecture" + ) + return False + + return True + + def _generate_config( + self, + base_config: transformers.PretrainedConfig, + num_experts: int, + num_shared_experts: int = 0, + experts_per_token: Optional[int] = None, + ) -> dict: + res = base_config.to_dict() + res["architectures"] = ["Qwen3MoeForCausalLM"] + res["model_type"] = "qwen3_moe" + res["num_experts"] = num_experts + res["num_experts_per_tok"] = experts_per_token or 2 + res["decoder_sparse_step"] = 1 + res["norm_topk_prob"] = True + res["sliding_window"] = None + res["use_sliding_window"] = False + res["moe_intermediate_size"] = res["intermediate_size"] + + if num_shared_experts > 0: + res["shared_expert_intermediate_size"] = res["intermediate_size"] + + if (res["num_experts"] & (res["num_experts"] - 1)) != 0: + logging.warning( + f"Your model has {res['num_experts']} experts, which is " + "not a power of two. The model will not be usable in llama.cpp." + ) + return res + + def write_model( + self, + out_path: str, + config: MoEMergeConfig, + merge_options: MergeOptions, + router_weights: List[torch.Tensor], + shared_router_weights: Optional[List[torch.Tensor]] = None, + ): + base_model = config.base_model + base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code) + + # 출력 디렉토리 생성 + os.makedirs(out_path, exist_ok=True) + + out_dtype = select_dtype(config, base_cfg) + out_cfg = self._generate_config( + base_cfg, + len(config.experts), + len(config.shared_experts or []), + config.experts_per_token, + ) + if out_dtype is not None: + out_cfg["torch_dtype"] = str(out_dtype).removeprefix("torch.") + + with open(os.path.join(out_path, "config.json"), "w", encoding="utf-8") as f: + json.dump(out_cfg, f, indent=4) + + shared_def = config.shared_experts[0] if config.shared_experts else None + + loaders, base_loader, writer = initialize_io(config, out_path, merge_options) + shared_loader = loaders.get(shared_def.source_model) if shared_def else base_loader + + for weight_info in tqdm.tqdm( + QWEN3_INFO.all_weights(base_cfg), + desc="Weights", + ): + tensor_name = weight_info.name + if ".mlp." in tensor_name: + # Expert weights 복사 + for expert_idx, expert in enumerate(config.experts): + expert_name = tensor_name.replace( + ".mlp.", f".mlp.experts.{expert_idx}." + ) + expert_loader = loaders.get(expert.source_model) + copy_tensor_out( + weight_info, + expert_loader, + writer, + expert=expert, + is_residual="down_proj" in tensor_name, + output_name=expert_name, + out_dtype=out_dtype, + clone=merge_options.clone_tensors, + ) + + # Shared expert weights 복사 - shared_experts가 있을 때만! + if shared_def is not None: + shared_expert_name = tensor_name.replace(".mlp.", ".mlp.shared_expert.") + copy_tensor_out( + weight_info, + shared_loader, + writer, + expert=shared_def, + is_residual="down_proj" in tensor_name, + output_name=shared_expert_name, + out_dtype=out_dtype, + clone=merge_options.clone_tensors, + ) + else: + # 일반 weights 복사 + copy_tensor_out( + weight_info, + base_loader, + writer, + out_dtype=out_dtype, + clone=merge_options.clone_tensors, + ) + + # Router weights 저장 + for layer_idx, weight in enumerate( + tqdm.tqdm(router_weights, desc="Router weights") + ): + writer.save_tensor( + f"model.layers.{layer_idx}.mlp.gate.weight", + weight.to(dtype=out_dtype).contiguous(), + clone=merge_options.clone_tensors, + ) + + # Shared expert gate weights 저장 - shared_experts가 있을 때만! + if shared_def is not None: + if shared_router_weights is not None and len(shared_router_weights) > layer_idx: + shared_weight = shared_router_weights[layer_idx] + else: + # shared_router_weights가 없으면 dummy weight 생성 + shared_weight = torch.zeros_like(weight[:1, :]) # [1, hidden_size] + + writer.save_tensor( + f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight", + shared_weight.to(dtype=out_dtype).contiguous(), + clone=merge_options.clone_tensors, + ) + + writer.finalize() \ No newline at end of file