Skip to content

Commit 44d12e5

Browse files
kernelpool and Awni Hannun authored
Fix batch generation for IQuestLoopCoder model (ml-explore#748)
* Fix batch generation
* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
1 parent 7a86c12 commit 44d12e5

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

mlx_lm/models/iquestloopcoder.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,6 @@ def __call__(
167167
if cache is None:
168168
cache = [None] * (2 * len(self.layers))
169169

170-
offset = cache[0].offset if cache[0] is not None else 0
171170
mask = create_attention_mask(h, cache[0])
172171
window_mask = create_attention_mask(
173172
h, cache[len(self.layers)], window_size=self.loop_window_size
@@ -176,6 +175,7 @@ def __call__(
176175
loop1_kv = []
177176
for layer, c in zip(self.layers, cache):
178177
h_norm = layer.input_layernorm(h)
178+
offset = c.offset if c is not None else 0
179179
q1, k1, v1 = layer.self_attn.get_qkv(h_norm, offset)
180180

181181
if c is not None:
@@ -192,6 +192,7 @@ def __call__(
192192
self.layers, self.gate_projections, cache[len(self.layers) :], loop1_kv
193193
):
194194
h_norm = layer.input_layernorm(h)
195+
offset = c.offset if c is not None else 0
195196
q2, k2, v2 = layer.self_attn.get_qkv(h_norm, offset)
196197
gate = gate_proj(q2)
197198
attn_global = layer.self_attn.attention(q2, k1, v1, mask, cache=c)

0 commit comments

Comments
 (0)