Skip to content

Commit 509f5ae

Browse files
author
Awni Hannun
authored
Fix sliding window batching (ml-explore#738)
1 parent 0f76343 commit 509f5ae

3 files changed

Lines changed: 8 additions & 8 deletions

File tree

mlx_lm/generate.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1071,7 +1071,8 @@ def _process_prompts(self, prompts):
10711071
prompt_cache = _merge_caches(caches)
10721072

10731073
for c in prompt_cache:
1074-
c.prepare(lengths=lengths, right_padding=padding)
1074+
# subtract one from lengths since we don't process the last token during prefill
1075+
c.prepare(lengths=[l - 1 for l in lengths], right_padding=padding)
10751076

10761077
while inputs.shape[1] > 1:
10771078
n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1)
@@ -1096,6 +1097,7 @@ def _process_prompts(self, prompts):
10961097
y, logprobs = self._step(
10971098
inputs, prompt_cache, samplers, logits_processors, tokens
10981099
)
1100+
10991101
mx.async_eval(y, logprobs)
11001102

11011103
return Batch(

mlx_lm/models/cache.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,12 +1157,10 @@ def extract(self, idx):
11571157
cache.keys = mx.roll(cache.keys, -self._idx, axis=2)
11581158
cache.values = mx.roll(cache.values, -self._idx, axis=2)
11591159
cache._idx = self.max_size
1160-
if padding > 0:
1161-
cache.keys = mx.contiguous(cache.keys[:, :, padding : cache._idx])
1162-
cache.values = mx.contiguous(cache.values[:, :, padding : cache._idx])
1160+
cache.keys = mx.contiguous(cache.keys[:, :, padding : cache._idx])
1161+
cache.values = mx.contiguous(cache.values[:, :, padding : cache._idx])
11631162
cache.offset = offset
11641163
cache._idx = cache.keys.shape[2]
1165-
11661164
return cache
11671165

11681166
@classmethod
@@ -1185,8 +1183,8 @@ def merge(cls, caches):
11851183
keys = mx.zeros((B, H, max_length, Dk), dtype=dt)
11861184
values = mx.zeros((B, H, max_length, Dv), dtype=dt)
11871185
for i, (p, c) in enumerate(zip(padding, caches)):
1188-
keys[i : i + 1, :, p : p + c.offset] = c._temporal_order(c.keys)
1189-
values[i : i + 1, :, p : p + c.offset] = c._temporal_order(c.values)
1186+
keys[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.keys)
1187+
values[i : i + 1, :, p : p + c._idx] = c._temporal_order(c.values)
11901188

11911189
cache = cls(caches[0].max_size, padding)
11921190
cache.keys = keys

tests/test_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def test_batch_continued_generation(self):
473473
self.model,
474474
stop_tokens=self.tokenizer.eos_token_ids,
475475
max_tokens=10,
476-
prefill_batch_size=1,
476+
prefill_batch_size=4,
477477
prefill_step_size=8,
478478
completion_batch_size=2,
479479
)

0 commit comments

Comments
 (0)