|
| 1 | +# Copyright © 2026 Apple Inc. |
| 2 | + |
| 3 | +from dataclasses import dataclass |
| 4 | +from functools import partial |
| 5 | +from typing import Any, Dict, List, Optional, Tuple, Union |
| 6 | + |
| 7 | +import mlx.core as mx |
| 8 | +import mlx.nn as nn |
| 9 | +from mlx.nn.layers.distributed import shard_linear |
| 10 | + |
| 11 | +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention |
| 12 | +from .cache import KVCache, RotatingKVCache |
| 13 | +from .rope_utils import initialize_rope |
| 14 | + |
| 15 | + |
@partial(mx.compile, shapeless=True)
def _compute_gate(query: mx.array, weight: mx.array, bias: mx.array) -> mx.array:
    """Return a sigmoid gate computed from a per-head projection of ``query``.

    Args:
        query: Query tensor. NOTE(review): assumed to be laid out as
            (batch, heads, seq, head_dim) -- confirm against the caller.
        weight: Per-head projection vectors, indexed here as
            ``weight[:, None, :]`` (i.e. treated as 2-D).
        bias: Per-head bias, broadcast through two trailing singleton axes.

    Returns:
        ``sigmoid(query @ w + bias)`` where ``w`` is ``weight`` with a
        singleton axis inserted and its last two axes swapped.
    """
    # Insert a singleton axis, then swap the last two axes so each head's
    # vector becomes a column that the matmul can contract against.
    head_columns = weight[:, None, :].swapaxes(-1, -2)
    logits = query @ head_columns
    logits = logits + bias[..., None, None]
    return mx.sigmoid(logits)
| 21 | + |
| 22 | + |
@partial(mx.compile, shapeless=True)
def _silu_mul(gate: mx.array, up: mx.array) -> mx.array:
    """Return ``silu(gate) * up`` as a single compiled elementwise op."""
    activated = nn.silu(gate)
    return activated * up
| 26 | + |
| 27 | + |
@partial(mx.compile, shapeless=True)
def _mix_attention(
    gate: mx.array, attn_global: mx.array, attn_local: mx.array
) -> mx.array:
    """Blend two attention outputs with ``gate`` as the interpolation weight.

    Returns ``gate * attn_global + (1 - gate) * attn_local``; a gate of 1
    selects the global output and a gate of 0 selects the local output.
    """
    global_term = gate * attn_global
    local_term = (1 - gate) * attn_local
    return global_term + local_term
| 33 | + |
| 34 | + |
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration fields consumed by the model classes in this file.

    Fields without a default must be supplied; field order is part of the
    dataclass constructor signature and must not be rearranged.
    """

    # Required fields.
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    head_dim: int
    num_key_value_heads: int

    # Optional fields (positional encoding, biases, embedding tying).
    max_position_embeddings: int = 131072
    attention_bias: bool = False
    mlp_bias: bool = False
    rope_theta: float = 500000.0
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False

    # Loop settings: the model asserts loop_num == 2; loop_window_size is
    # used for the windowed attention mask and the rotating cache size.
    loop_num: int = 2
    loop_window_size: int = 64
| 54 | + |
| 55 | + |
class LoopGateProjection(nn.Module):
    """Per-head gate: one ``head_dim`` vector and one scalar bias per head.

    Both parameters start as zeros, so before any weights are loaded the
    gate evaluates to ``sigmoid(0)`` everywhere.
    """

    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        weight_shape = (num_heads, head_dim)
        bias_shape = (num_heads,)
        self.weight = mx.zeros(weight_shape)
        self.bias = mx.zeros(bias_shape)

    def __call__(self, query: mx.array) -> mx.array:
        """Return the sigmoid gate for ``query`` using this layer's parameters."""
        gate = _compute_gate(query, self.weight, self.bias)
        return gate
| 66 | + |
| 67 | + |
class Attention(nn.Module):
    """Attention projections, rotary embedding and attention kernel.

    The class deliberately exposes ``get_qkv`` and ``attention`` as separate
    steps (and has no ``__call__``) so the caller can reuse keys/values
    between passes.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        hidden = args.hidden_size
        n_heads = args.num_attention_heads
        n_kv_heads = args.num_key_value_heads
        head_dim = args.head_dim
        use_bias = args.attention_bias

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5

        q_width = n_heads * head_dim
        kv_width = n_kv_heads * head_dim
        self.q_proj = nn.Linear(hidden, q_width, bias=use_bias)
        self.k_proj = nn.Linear(hidden, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(hidden, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, hidden, bias=use_bias)

        self.rope = initialize_rope(
            head_dim,
            args.rope_theta,
            traditional=False,
            scaling_config=args.rope_scaling,
            max_position_embeddings=args.max_position_embeddings,
        )

    def get_qkv(
        self, x: mx.array, offset: int = 0
    ) -> Tuple[mx.array, mx.array, mx.array]:
        """Project ``x`` to per-head queries, keys and values.

        Args:
            x: Input of shape (batch, seq, hidden).
            offset: Position offset passed to the rotary embedding.

        Returns:
            ``(queries, keys, values)``, each laid out as
            (batch, heads, seq, per-head dim). The rotary embedding is
            applied to queries and keys only.
        """
        batch, seq_len, _ = x.shape

        def split_heads(projection, head_count):
            # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
            flat = projection(x)
            return flat.reshape(batch, seq_len, head_count, -1).transpose(0, 2, 1, 3)

        queries = split_heads(self.q_proj, self.n_heads)
        keys = split_heads(self.k_proj, self.n_kv_heads)
        values = split_heads(self.v_proj, self.n_kv_heads)

        queries = self.rope(queries, offset=offset)
        keys = self.rope(keys, offset=offset)

        return queries, keys, values

    def attention(
        self,
        queries: mx.array,
        keys: mx.array,
        values: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run scaled dot-product attention with this layer's scale factor.

        ``cache`` is forwarded unchanged to ``scaled_dot_product_attention``;
        this method does not update it.
        """
        attended = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        return attended
| 115 | + |
| 116 | + |
class MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        width = args.hidden_size
        inner = args.intermediate_size
        use_bias = args.mlp_bias
        self.gate_proj = nn.Linear(width, inner, bias=use_bias)
        self.down_proj = nn.Linear(inner, width, bias=use_bias)
        self.up_proj = nn.Linear(width, inner, bias=use_bias)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the gated feed-forward transform to ``x``."""
        gated = _silu_mul(self.gate_proj(x), self.up_proj(x))
        return self.down_proj(gated)
| 128 | + |
| 129 | + |
class TransformerBlock(nn.Module):
    """Holds one layer's sub-modules: attention, MLP and two RMS norms.

    No ``__call__`` is defined here; the owning model invokes the
    sub-modules directly.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        width = args.hidden_size
        eps = args.rms_norm_eps
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(width, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=eps)
| 139 | + |
| 140 | + |
class IQuestLoopCoderModel(nn.Module):
    """Decoder stack that runs the same layers twice ("loops").

    Pass 1 runs ordinary attention in every layer and records each layer's
    keys/values. Pass 2 re-runs the layers; in each one it attends both to
    the pass-1 keys/values and to its own keys/values under a windowed mask,
    and blends the two results with a learned per-head gate.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.loop_num == 2, f"Only loop_num=2 is supported, got {args.loop_num}"
        depth = args.num_hidden_layers
        self.args = args
        self.vocab_size = args.vocab_size
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [TransformerBlock(args=args) for _ in range(depth)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        # One gate per layer, used only during the second pass.
        self.gate_projections = [
            LoopGateProjection(args.num_attention_heads, args.head_dim)
            for _ in range(depth)
        ]
        self.loop_num = args.loop_num
        self.loop_window_size = args.loop_window_size

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[List[Any]] = None,
    ):
        """Embed ``inputs``, run both passes and return the normed hidden states.

        Args:
            inputs: Token ids; the first two axes are (batch, seq).
            cache: Optional list of ``2 * len(self.layers)`` entries. The
                first half is used by pass 1, the second half by pass 2.
                ``None`` runs without any cache.
        """
        B, L = inputs.shape[:2]
        h = self.embed_tokens(inputs)

        n_layers = len(self.layers)
        if cache is None:
            cache = [None] * (2 * n_layers)

        # The offset is read once, before any cache is updated, so both
        # passes rotate their queries/keys with the same positions.
        first_cache = cache[0]
        offset = 0 if first_cache is None else first_cache.offset
        mask = create_attention_mask(h, first_cache)
        window_mask = create_attention_mask(
            h, cache[n_layers], window_size=self.loop_window_size
        )

        def merge_heads(t: mx.array) -> mx.array:
            # (batch, heads, seq, dim) -> (batch, seq, heads * dim)
            return t.transpose(0, 2, 1, 3).reshape(B, L, -1)

        def finish_layer(layer, hidden: mx.array, attn_out: mx.array) -> mx.array:
            # Attention residual followed by the MLP residual.
            hidden = hidden + layer.self_attn.o_proj(merge_heads(attn_out))
            return hidden + layer.mlp(layer.post_attention_layernorm(hidden))

        # ---- Pass 1: standard attention; remember every layer's K/V. ----
        loop1_kv = []
        for layer, layer_cache in zip(self.layers, cache):
            attn = layer.self_attn
            q1, k1, v1 = attn.get_qkv(layer.input_layernorm(h), offset)

            if layer_cache is not None:
                k1, v1 = layer_cache.update_and_fetch(k1, v1)
            loop1_kv.append((k1, v1))

            h = finish_layer(layer, h, attn.attention(q1, k1, v1, mask, cache=layer_cache))

        # ---- Pass 2: gated mix of global (pass-1 K/V) and local attention. ----
        second_half = cache[n_layers:]
        for layer, gate_proj, layer_cache, (k1, v1) in zip(
            self.layers, self.gate_projections, second_half, loop1_kv
        ):
            attn = layer.self_attn
            q2, k2, v2 = attn.get_qkv(layer.input_layernorm(h), offset)
            gate = gate_proj(q2)

            # NOTE(review): the pass-2 cache object is passed alongside the
            # pass-1 keys/values here -- confirm that is what the attention
            # helper expects.
            attn_global = attn.attention(q2, k1, v1, mask, cache=layer_cache)

            if layer_cache is not None:
                k2, v2 = layer_cache.update_and_fetch(k2, v2)
            attn_local = attn.attention(q2, k2, v2, window_mask, cache=layer_cache)

            h = finish_layer(layer, h, _mix_attention(gate, attn_global, attn_local))

        return self.norm(h)
| 216 | + |
| 217 | + |
class Model(nn.Module):
    """Language-model wrapper: the looped decoder plus an output projection.

    When ``tie_word_embeddings`` is set the embedding matrix is reused as
    the output projection; otherwise a separate bias-free ``lm_head`` is
    created.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = IQuestLoopCoderModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return vocabulary logits for ``inputs``.

        Args:
            inputs: Token ids, forwarded to the inner model.
            cache: Optional cache list as produced by ``make_cache``.
        """
        out = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    def shard(self, group: Optional[mx.distributed.Group] = None):
        """Shard the attention, MLP and gate parameters across ``group``.

        Args:
            group: Distributed group to shard over; when falsy,
                ``mx.distributed.init()`` is used.

        Raises:
            ValueError: If any layer's attention-head, key/value-head or
                gate-head count is not divisible by the group size. The
                check runs before anything is modified, so a failed call
                leaves the model untouched.
        """
        group = group or mx.distributed.init()
        N = group.size()
        rank = group.rank()

        layers = self.model.layers
        gates = self.model.gate_projections

        # Validate up front: the floor divisions below would otherwise
        # silently produce head counts (possibly zero) that no longer match
        # the sharded projection widths.
        for index, (layer, gate_proj) in enumerate(zip(layers, gates)):
            head_counts = (
                ("attention heads", layer.self_attn.n_heads),
                ("key/value heads", layer.self_attn.n_kv_heads),
                ("gate heads", gate_proj.num_heads),
            )
            for label, count in head_counts:
                if count % N != 0:
                    raise ValueError(
                        f"Cannot shard layer {index}: {count} {label} "
                        f"are not divisible by group size {N}."
                    )

        for layer, gate_proj in zip(layers, gates):
            attn = layer.self_attn
            attn.q_proj = shard_linear(attn.q_proj, "all-to-sharded", group=group)
            attn.k_proj = shard_linear(attn.k_proj, "all-to-sharded", group=group)
            attn.v_proj = shard_linear(attn.v_proj, "all-to-sharded", group=group)
            attn.o_proj = shard_linear(attn.o_proj, "sharded-to-all", group=group)
            attn.n_heads //= N
            attn.n_kv_heads //= N

            mlp = layer.mlp
            mlp.gate_proj = shard_linear(mlp.gate_proj, "all-to-sharded", group=group)
            mlp.down_proj = shard_linear(mlp.down_proj, "sharded-to-all", group=group)
            mlp.up_proj = shard_linear(mlp.up_proj, "all-to-sharded", group=group)

            # The gate is per-head, so each rank keeps only the contiguous
            # slice of heads it owns.
            heads_per_rank = gate_proj.num_heads // N
            start = rank * heads_per_rank
            end = start + heads_per_rank
            gate_proj.weight = gate_proj.weight[start:end, :]
            gate_proj.bias = gate_proj.bias[start:end]
            gate_proj.num_heads = heads_per_rank

    @property
    def layers(self):
        """The decoder layers of the inner model."""
        return self.model.layers

    def make_cache(self):
        """Build the cache list expected by ``__call__``.

        Returns one ``KVCache`` per layer followed by one
        ``RotatingKVCache`` (sized to ``loop_window_size``) per layer.
        """
        global_caches = [KVCache() for _ in self.layers]
        local_caches = [
            RotatingKVCache(max_size=self.args.loop_window_size) for _ in self.layers
        ]
        return global_caches + local_caches
0 commit comments