Skip to content

Commit 0f76343

Browse files
kernelpool and Awni Hannun authored
Add IQuest Coder V1 Loop variant (ml-explore#716)
* Add IQuest Coder V1 Loop variant * Minor tweaks * Fix cache population * Clean up nested for loop * Simplify * Bug fix in prefill * Address feedback * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
1 parent 298b67c commit 0f76343

2 files changed

Lines changed: 308 additions & 0 deletions

File tree

mlx_lm/models/iquestloopcoder.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
# Copyright © 2026 Apple Inc.
2+
3+
from dataclasses import dataclass
4+
from functools import partial
5+
from typing import Any, Dict, List, Optional, Tuple, Union
6+
7+
import mlx.core as mx
8+
import mlx.nn as nn
9+
from mlx.nn.layers.distributed import shard_linear
10+
11+
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
12+
from .cache import KVCache, RotatingKVCache
13+
from .rope_utils import initialize_rope
14+
15+
16+
@partial(mx.compile, shapeless=True)
def _compute_gate(query: mx.array, weight: mx.array, bias: mx.array) -> mx.array:
    """Compute a sigmoid gate for every head and query position.

    Each head's query vectors are dotted with that head's row of ``weight``,
    the head's entry of ``bias`` is added, and the result goes through a
    sigmoid.

    Args:
        query: Query vectors. ``Attention.get_qkv`` produces them laid out as
            ``(batch, heads, seq, head_dim)``.
        weight: Per-head gate vectors; ``LoopGateProjection`` allocates them
            as ``(heads, head_dim)``.
        bias: Per-head gate offsets; ``LoopGateProjection`` allocates them as
            ``(heads,)``.

    Returns:
        The gate values, with a trailing axis of size 1 left by the matmul.
    """
    # (heads, head_dim) -> (heads, head_dim, 1): one column vector per head.
    head_columns = weight[:, None, :].swapaxes(-1, -2)
    # (heads,) -> (heads, 1, 1) so the offset broadcasts over seq and the
    # trailing singleton axis.
    head_offsets = bias[..., None, None]
    return mx.sigmoid(query @ head_columns + head_offsets)
21+
22+
23+
@partial(mx.compile, shapeless=True)
def _silu_mul(gate: mx.array, up: mx.array) -> mx.array:
    """Apply SiLU to ``gate`` and multiply the result elementwise by ``up``."""
    activated = nn.silu(gate)
    return activated * up
26+
27+
28+
@partial(mx.compile, shapeless=True)
def _mix_attention(
    gate: mx.array, attn_global: mx.array, attn_local: mx.array
) -> mx.array:
    """Blend two attention outputs with a gate.

    Computes ``gate * attn_global + (1 - gate) * attn_local``: ``gate``
    weights the global result and its complement weights the local result.
    """
    global_part = gate * attn_global
    local_part = (1 - gate) * attn_local
    return global_part + local_part
33+
34+
35+
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration for the IQuest Loop Coder model.

    Field order is significant because the dataclass accepts positional
    arguments; the first nine fields are required, the rest have defaults.
    """

    # --- required ---------------------------------------------------------
    model_type: str
    hidden_size: int  # width of the embeddings and of every RMSNorm
    num_hidden_layers: int  # number of TransformerBlocks (and gate projections)
    intermediate_size: int  # hidden width of the MLP
    num_attention_heads: int  # query heads
    rms_norm_eps: float
    vocab_size: int
    head_dim: int  # per-head width; also sets the attention scale
    num_key_value_heads: int  # key/value heads

    # --- optional ---------------------------------------------------------
    max_position_embeddings: int = 131072  # forwarded to initialize_rope
    attention_bias: bool = False  # bias on the q/k/v/o projections
    mlp_bias: bool = False  # bias on the gate/up/down projections
    rope_theta: float = 500000.0  # forwarded to initialize_rope
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False  # reuse embed_tokens as the output head
    loop_num: int = 2  # IQuestLoopCoderModel only accepts 2
    loop_window_size: int = 64  # window of the second pass's local attention
54+
55+
56+
class LoopGateProjection(nn.Module):
    """Per-head gate parameters used to blend global and local attention.

    Holds one ``head_dim``-sized vector and one scalar offset per head, both
    initialised to zero, and turns queries into sigmoid gate values.
    """

    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        weight_shape = (num_heads, head_dim)
        bias_shape = (num_heads,)
        self.weight = mx.zeros(weight_shape)
        self.bias = mx.zeros(bias_shape)

    def __call__(self, query: mx.array) -> mx.array:
        """Return the gate for ``query`` (see ``_compute_gate``)."""
        gate = _compute_gate(query, self.weight, self.bias)
        return gate
66+
67+
68+
class Attention(nn.Module):
    """Attention projections, RoPE and the attention kernel, exposed in pieces.

    There is no ``__call__``: the owning model calls ``get_qkv``,
    ``attention`` and ``o_proj`` separately.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        model_dim = args.hidden_size
        num_q_heads = args.num_attention_heads
        num_kv_heads = args.num_key_value_heads
        per_head = args.head_dim

        self.n_heads = num_q_heads
        self.n_kv_heads = num_kv_heads
        self.head_dim = per_head
        self.scale = per_head**-0.5

        use_bias = args.attention_bias
        q_width = num_q_heads * per_head
        kv_width = num_kv_heads * per_head
        self.q_proj = nn.Linear(model_dim, q_width, bias=use_bias)
        self.k_proj = nn.Linear(model_dim, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(model_dim, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, model_dim, bias=use_bias)

        self.rope = initialize_rope(
            per_head,
            args.rope_theta,
            traditional=False,
            scaling_config=args.rope_scaling,
            max_position_embeddings=args.max_position_embeddings,
        )

    def get_qkv(
        self, x: mx.array, offset: int = 0
    ) -> Tuple[mx.array, mx.array, mx.array]:
        """Project ``x`` to per-head queries, keys and values.

        Args:
            x: Input of shape ``(batch, seq, hidden)``.
            offset: Position offset handed to RoPE.

        Returns:
            ``(queries, keys, values)``, each laid out as
            ``(batch, heads, seq, head_dim)``. RoPE is applied to queries and
            keys only.
        """
        batch, seq_len, _ = x.shape

        def split_heads(projected: mx.array, heads: int) -> mx.array:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return projected.reshape(batch, seq_len, heads, -1).transpose(0, 2, 1, 3)

        queries = split_heads(self.q_proj(x), self.n_heads)
        keys = split_heads(self.k_proj(x), self.n_kv_heads)
        values = split_heads(self.v_proj(x), self.n_kv_heads)

        queries = self.rope(queries, offset=offset)
        keys = self.rope(keys, offset=offset)
        return queries, keys, values

    def attention(
        self,
        queries: mx.array,
        keys: mx.array,
        values: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run scaled dot-product attention with this module's scale.

        ``mask`` and ``cache`` are forwarded unchanged to
        ``scaled_dot_product_attention``.
        """
        return scaled_dot_product_attention(
            queries,
            keys,
            values,
            cache=cache,
            scale=self.scale,
            mask=mask,
        )
115+
116+
117+
class MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        model_dim = args.hidden_size
        inner_dim = args.intermediate_size
        use_bias = args.mlp_bias
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=use_bias)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=use_bias)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=use_bias)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the gated feed-forward transform to ``x``."""
        gated = _silu_mul(self.gate_proj(x), self.up_proj(x))
        return self.down_proj(gated)
128+
129+
130+
class TransformerBlock(nn.Module):
    """Container for one layer's attention, MLP and the two RMSNorms.

    The block defines no ``__call__``; ``IQuestLoopCoderModel`` invokes the
    sub-modules directly.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        width = args.hidden_size
        eps = args.rms_norm_eps

        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(width, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=eps)
139+
140+
141+
class IQuestLoopCoderModel(nn.Module):
    """Two-pass ("loop") transformer stack.

    The same ``layers`` are run twice over the hidden state:

    * Pass 1 is ordinary attention; the keys/values each layer attends over
      are remembered.
    * Pass 2 recomputes queries/keys/values from the updated hidden state and
      mixes two attention results per layer with a learned per-head gate:
      "global" attention of the new queries over the pass-1 keys/values, and
      "local" attention over the pass-2 keys/values using the windowed mask.

    Raises:
        ValueError: If ``args.loop_num`` is not 2; the forward pass below is
            written for exactly two passes.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        if args.loop_num != 2:
            raise ValueError(f"Only loop_num=2 is supported, got {args.loop_num}")
        self.args = args
        self.vocab_size = args.vocab_size
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        # One gate per layer, used only in the second pass.
        self.gate_projections = [
            LoopGateProjection(args.num_attention_heads, args.head_dim)
            for _ in range(args.num_hidden_layers)
        ]
        self.loop_num = args.loop_num
        self.loop_window_size = args.loop_window_size

    @staticmethod
    def _finish_block(
        layer: Any, h: mx.array, attn: mx.array, B: int, L: int
    ) -> mx.array:
        """Merge heads, apply ``o_proj`` and the MLP, each with a residual add."""
        merged = attn.transpose(0, 2, 1, 3).reshape(B, L, -1)
        h = h + layer.self_attn.o_proj(merged)
        return h + layer.mlp(layer.post_attention_layernorm(h))

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[List[Any]] = None,
    ):
        """Run both passes and return the final normalised hidden states.

        Args:
            inputs: Token ids; the first two axes are ``(batch, seq)``.
            cache: Optional list of ``2 * len(self.layers)`` cache entries.
                The first half serves pass 1 and the second half serves
                pass 2. ``None`` runs without caching.

        Returns:
            Hidden states after the final RMSNorm.
        """
        B, L = inputs.shape[:2]
        h = self.embed_tokens(inputs)

        n_layers = len(self.layers)
        if cache is None:
            cache = [None] * (2 * n_layers)

        # RoPE offset and both masks are derived once and shared by all layers.
        offset = cache[0].offset if cache[0] is not None else 0
        mask = create_attention_mask(h, cache[0])
        window_mask = create_attention_mask(
            h, cache[n_layers], window_size=self.loop_window_size
        )

        # ---- pass 1: plain attention; remember each layer's keys/values ----
        loop1_kv = []
        # zip stops after len(self.layers) entries, i.e. the pass-1 caches.
        for layer, c in zip(self.layers, cache):
            q1, k1, v1 = layer.self_attn.get_qkv(layer.input_layernorm(h), offset)

            if c is not None:
                k1, v1 = c.update_and_fetch(k1, v1)
            # Stored after the cache update so pass 2 sees the same keys/values
            # pass 1 attended over.
            loop1_kv.append((k1, v1))

            out = layer.self_attn.attention(q1, k1, v1, mask, cache=c)
            h = self._finish_block(layer, h, out, B, L)

        # ---- pass 2: gated mix of global and windowed-local attention ------
        for layer, gate_proj, c, (k1, v1) in zip(
            self.layers, self.gate_projections, cache[n_layers:], loop1_kv
        ):
            q2, k2, v2 = layer.self_attn.get_qkv(layer.input_layernorm(h), offset)
            gate = gate_proj(q2)

            # NOTE(review): ``c`` is the pass-2 cache while k1/v1 came from the
            # pass-1 cache -- confirm scaled_dot_product_attention does not
            # need the cache that produced the keys/values.
            attn_global = layer.self_attn.attention(q2, k1, v1, mask, cache=c)

            if c is not None:
                k2, v2 = c.update_and_fetch(k2, v2)
            attn_local = layer.self_attn.attention(
                q2, k2, v2, window_mask, cache=c
            )

            mixed = _mix_attention(gate, attn_global, attn_local)
            h = self._finish_block(layer, h, mixed, B, L)

        return self.norm(h)
216+
217+
218+
class Model(nn.Module):
    """Language-model wrapper: backbone plus output projection.

    When ``args.tie_word_embeddings`` is false a separate ``lm_head`` is
    created; otherwise the embedding matrix is reused as the output layer.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = IQuestLoopCoderModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return the output-projection result for ``inputs``.

        ``cache`` is forwarded unchanged to the backbone.
        """
        hidden = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            return self.model.embed_tokens.as_linear(hidden)
        return self.lm_head(hidden)

    def shard(self, group: Optional[mx.distributed.Group] = None):
        """Shard every layer's projections and gate parameters across ``group``.

        q/k/v and the MLP gate/up projections are sharded "all-to-sharded";
        ``o_proj`` and ``down_proj`` are sharded "sharded-to-all". Head counts
        and the per-head gate parameters are reduced to this rank's share.

        Args:
            group: Distributed group; defaults to ``mx.distributed.init()``.
        """
        group = group or mx.distributed.init()
        N = group.size()
        rank = group.rank()

        attn_plan = (
            ("q_proj", "all-to-sharded"),
            ("k_proj", "all-to-sharded"),
            ("v_proj", "all-to-sharded"),
            ("o_proj", "sharded-to-all"),
        )
        mlp_plan = (
            ("gate_proj", "all-to-sharded"),
            ("down_proj", "sharded-to-all"),
            ("up_proj", "all-to-sharded"),
        )

        for layer, gate_proj in zip(self.model.layers, self.model.gate_projections):
            attn = layer.self_attn
            for name, mode in attn_plan:
                setattr(attn, name, shard_linear(getattr(attn, name), mode, group=group))
            attn.n_heads //= N
            attn.n_kv_heads //= N

            mlp = layer.mlp
            for name, mode in mlp_plan:
                setattr(mlp, name, shard_linear(getattr(mlp, name), mode, group=group))

            # Keep only the gate rows belonging to this rank's heads.
            heads_per_rank = gate_proj.num_heads // N
            first = rank * heads_per_rank
            mine = slice(first, first + heads_per_rank)
            gate_proj.weight = gate_proj.weight[mine, :]
            gate_proj.bias = gate_proj.bias[mine]
            gate_proj.num_heads = heads_per_rank

    @property
    def layers(self):
        """The backbone's transformer blocks."""
        return self.model.layers

    def make_cache(self):
        """Build the cache list expected by the backbone.

        One ``KVCache`` per layer for the first pass, followed by one
        ``RotatingKVCache`` (capped at ``loop_window_size``) per layer for the
        second pass.
        """
        window = self.args.loop_window_size
        first_pass = [KVCache() for _ in self.layers]
        second_pass = [RotatingKVCache(max_size=window) for _ in self.layers]
        return first_pass + second_pass

tests/test_models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,29 @@ def test_internlm3(self):
12521252
model, args.model_type, args.vocab_size, args.num_hidden_layers
12531253
)
12541254

1255+
def test_iquestloopcoder(self):
    # Small two-layer configuration run through the shared model test runner.
    from mlx_lm.models import iquestloopcoder

    config = {
        "model_type": "iquestloopcoder",
        "hidden_size": 256,
        "num_hidden_layers": 2,
        "intermediate_size": 512,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "rms_norm_eps": 1e-5,
        "head_dim": 32,
        "vocab_size": 1000,
        "rope_theta": 500000.0,
        "tie_word_embeddings": False,
        "loop_num": 2,
        "loop_window_size": 32,
    }
    args = iquestloopcoder.ModelArgs(**config)
    model = iquestloopcoder.Model(args)
    self.model_test_runner(
        model,
        args.model_type,
        args.vocab_size,
        args.num_hidden_layers,
    )
1277+
12551278
def test_smollm3(self):
12561279
from mlx_lm.models import smollm3
12571280

0 commit comments

Comments
 (0)