Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions mlx_lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,15 +770,29 @@ def save_model(
save_path.mkdir(parents=True, exist_ok=True)

weights = dict(tree_flatten(model.parameters()))
shards = make_shards(weights)
shards_count = len(shards)

# Pre-compute shard boundaries using nbytes (no array evaluation needed)
max_file_size_bytes = MAX_FILE_SIZE_GB << 30
shard_key_groups = []
current_keys, current_size = [], 0
total_size = 0
for k, v in weights.items():
if current_size + v.nbytes > max_file_size_bytes and current_keys:
shard_key_groups.append(current_keys)
current_keys, current_size = [], 0
current_keys.append(k)
current_size += v.nbytes
total_size += v.nbytes
if current_keys:
shard_key_groups.append(current_keys)

shards_count = len(shard_key_groups)
shard_file_format = (
"model-{:05d}-of-{:05d}.safetensors"
if shards_count > 1
else "model.safetensors"
)

total_size = sum(v.nbytes for v in weights.values())
index_data = {
"metadata": {
"total_size": total_size,
Expand All @@ -789,23 +803,24 @@ def save_model(
if donate_model:
model.update(tree_map(lambda _: mx.array([]), model.parameters()))

# Write the weights and make sure no references are kept other than the
# necessary ones
weights.clear()
del weights
# Save shards one at a time, popping weights to free references as we go.
# This bounds peak memory to roughly one shard worth of evaluated arrays
# plus any co-evaluated arrays from shared computation graphs.
for i, keys in enumerate(shard_key_groups):
shard = {k: weights.pop(k) for k in keys}

mx.eval(*shard.values())

for i in range(len(shards)):
shard = shards[i]
shards[i] = None
shard_name = shard_file_format.format(i + 1, shards_count)
shard_path = save_path / shard_name

mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})

for weight_name in shard.keys():
for weight_name in shard:
index_data["weight_map"][weight_name] = shard_name
del shard

del weights

index_data["weight_map"] = {
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
}
Expand Down