From 6be59d4b781861558f689ea599e109e0894b362b Mon Sep 17 00:00:00 2001
From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com>
Date: Mon, 27 Apr 2026 13:23:06 -0400
Subject: [PATCH] fix: stream save_model to prevent OOM on large MoE models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When converting DeepSeek V4 Flash (256 experts × 43 layers) with -q, the
process gets OOM-killed during save. The lazy computation graph from
dequant → stack → quantize creates enormous BF16 intermediates that all
materialize at once when saving.

Build and save shards incrementally: pop weights from the dict as each
shard is constructed, explicitly mx.eval before writing, then free. This
bounds peak memory to ~one shard + one evaluation intermediate instead
of the entire model's lazy graph.
---
 mlx_lm/utils.py | 38 +++++++++++++++++++++++++++-----------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py
index 36114f05e..b2760d0d9 100644
--- a/mlx_lm/utils.py
+++ b/mlx_lm/utils.py
@@ -770,15 +770,29 @@ def save_model(
     save_path.mkdir(parents=True, exist_ok=True)
 
     weights = dict(tree_flatten(model.parameters()))
-    shards = make_shards(weights)
-    shards_count = len(shards)
+
+    # Pre-compute shard boundaries using nbytes (no array evaluation needed)
+    max_file_size_bytes = MAX_FILE_SIZE_GB << 30
+    shard_key_groups = []
+    current_keys, current_size = [], 0
+    total_size = 0
+    for k, v in weights.items():
+        if current_size + v.nbytes > max_file_size_bytes and current_keys:
+            shard_key_groups.append(current_keys)
+            current_keys, current_size = [], 0
+        current_keys.append(k)
+        current_size += v.nbytes
+        total_size += v.nbytes
+    if current_keys:
+        shard_key_groups.append(current_keys)
+
+    shards_count = len(shard_key_groups)
     shard_file_format = (
         "model-{:05d}-of-{:05d}.safetensors"
         if shards_count > 1
         else "model.safetensors"
     )
 
-    total_size = sum(v.nbytes for v in weights.values())
     index_data = {
         "metadata": {
             "total_size": total_size,
@@ -789,23 +803,25 @@ def save_model(
     if donate_model:
         model.update(tree_map(lambda _: mx.array([]), model.parameters()))
 
-    # Write the weights and make sure no references are kept other than the
-    # necessary ones
-    weights.clear()
-    del weights
+    # Save shards one at a time, popping weights to free references as we go.
+    # This bounds peak memory to roughly one shard worth of evaluated arrays
+    # plus any co-evaluated arrays from shared computation graphs.
+    for i, keys in enumerate(shard_key_groups):
+        shard = {k: weights.pop(k) for k in keys}
+
+        mx.eval(*shard.values())
 
-    for i in range(len(shards)):
-        shard = shards[i]
-        shards[i] = None
         shard_name = shard_file_format.format(i + 1, shards_count)
         shard_path = save_path / shard_name
 
         mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})
 
-        for weight_name in shard.keys():
+        for weight_name in shard:
             index_data["weight_map"][weight_name] = shard_name
         del shard
 
+    del weights
+
     index_data["weight_map"] = {
         k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
     }
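
Not part of the patch: a minimal sketch for checking the peak-memory claim
locally. It assumes a recent MLX where the peak-memory helpers live at the
top level (older releases expose them under mx.metal), and both paths below
are placeholders:

    from pathlib import Path

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.utils import save_model

    # Placeholder path: any quantized checkpoint large enough to shard.
    model, tokenizer = load("path/to/quantized-model")

    mx.reset_peak_memory()
    save_model(Path("/tmp/streamed-save"), model, donate_model=True)

    # With streamed sharding, peak memory during save should sit near one
    # shard (MAX_FILE_SIZE_GB) rather than the whole materialized model.
    print(f"Peak memory during save: {mx.get_peak_memory() / 2**30:.1f} GiB")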