Skip to content

Commit bcf6306

Browse files
authored
Fix save/load of CacheList (ml-explore#886)
1 parent 1974376 commit bcf6306

2 files changed

Lines changed: 58 additions & 7 deletions

File tree

mlx_lm/models/cache.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -742,16 +742,24 @@ def trim(self, n):
742742

743743
@property
744744
def state(self):
745-
return [s for c in self.caches for s in c.state]
745+
return [c.state for c in self.caches]
746746

747747
@state.setter
748748
def state(self, v):
749-
state_lens = [len(c.state) for c in self.caches]
750-
start = 0
751-
for c in self.caches:
752-
l = len(c.state)
753-
c.state = v[start : start + l]
754-
start += l
749+
for c, s in zip(self.caches, v):
750+
c.state = s
751+
752+
@property
753+
def meta_state(self):
754+
return (
755+
[type(c).__name__ for c in self.caches],
756+
[c.meta_state for c in self.caches],
757+
)
758+
759+
@meta_state.setter
760+
def meta_state(self, v):
761+
for c, m in zip(self.caches, v[1]):
762+
c.meta_state = m
755763

756764
def filter(self, batch_indices):
757765
"""
@@ -793,6 +801,14 @@ def size(self):
793801
def empty(self):
794802
return self.caches[0].empty()
795803

804+
@classmethod
805+
def from_state(cls, state, meta_state):
806+
obj = cls.__new__(cls)
807+
obj.caches = [
808+
globals()[c].from_state(s, m) for s, c, m in zip(state, *meta_state)
809+
]
810+
return obj
811+
796812

797813
def dynamic_roll(x, shifts, axis):
798814
n = x.shape[axis]

tests/test_prompt_cache.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,41 @@ def test_save_load_mixed_cache(self):
132132
self.assertTrue(mx.array_equal(k, lk))
133133
self.assertTrue(mx.array_equal(v, lv))
134134

135+
def test_save_load_cache_list(self):
136+
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
137+
138+
cache = [
139+
ArraysCache(size=2),
140+
KVCache(),
141+
RotatingKVCache(8),
142+
ArraysCache(size=2),
143+
ChunkedKVCache(256),
144+
]
145+
for c in cache:
146+
if isinstance(c, ArraysCache):
147+
c[0] = mx.random.uniform(shape=(4, 4, 4))
148+
c[1] = mx.random.uniform(shape=(4, 4, 4))
149+
else:
150+
x = mx.random.uniform(shape=(4, 4, 7, 4))
151+
y = mx.random.uniform(shape=(4, 4, 7, 4))
152+
c.update_and_fetch(x, y)
153+
cache = [CacheList(*cache)]
154+
155+
save_prompt_cache(cache_file, cache)
156+
loaded_cache = load_prompt_cache(cache_file)
157+
for c, lc in zip(cache[0].caches, loaded_cache[0].caches):
158+
if isinstance(c, ArraysCache):
159+
self.assertTrue(mx.array_equal(c[0], lc[0]))
160+
self.assertTrue(mx.array_equal(c[1], lc[1]))
161+
else:
162+
x = mx.random.uniform(shape=(4, 4, 1, 4))
163+
y = mx.random.uniform(shape=(4, 4, 1, 4))
164+
k, v = c.update_and_fetch(x, y)
165+
lk, lv = lc.update_and_fetch(x, y)
166+
self.assertEqual(c.offset, lc.offset)
167+
self.assertTrue(mx.array_equal(k, lk))
168+
self.assertTrue(mx.array_equal(v, lv))
169+
135170
def test_save_load_arrays_cache(self):
136171
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
137172

0 commit comments

Comments
 (0)