diff --git a/mlx_lm/models/deepseek_v4.py b/mlx_lm/models/deepseek_v4.py
index acea50f6b..c41731bf7 100644
--- a/mlx_lm/models/deepseek_v4.py
+++ b/mlx_lm/models/deepseek_v4.py
@@ -313,7 +313,15 @@ def _sparse_pooled_attention(
     q_scaled = q * scale
     local_scores = q_scaled @ local_kv.swapaxes(-1, -2)
     local_scores = _apply_score_mask(local_scores, local_mask)
-    pooled_scores = (q_scaled[:, :, :, None] * pooled).sum(axis=-1)
+
+    # Pooled scores via matmul instead of broadcast multiply + sum.
+    # The element-wise path creates a (B, H, L, topk, D) intermediate, which
+    # at 4k context with H=64, topk=512, D=512 is ~137 GB in 16-bit precision.
+    # Matmul (B, L, H, D) @ (B, L, D, topk) → (B, L, H, topk) uses ~0.25 GB.
+    pooled_sq = pooled.squeeze(1)  # (B, L, topk, D)
+    q_bl = q_scaled.transpose(0, 2, 1, 3)  # (B, L, H, D)
+    pooled_scores = q_bl @ pooled_sq.swapaxes(-1, -2)  # (B, L, H, topk)
+    pooled_scores = pooled_scores.transpose(0, 2, 1, 3)  # (B, H, L, topk)
     pooled_scores = _apply_score_mask(pooled_scores, pooled_mask)
     scores = mx.concatenate([local_scores, pooled_scores], axis=-1)
 
@@ -329,7 +337,9 @@ def _sparse_pooled_attention(
 
     pooled_weights = weights[..., sink_offset + local_len :]
 
     out = local_weights @ local_kv
-    out = out + (pooled_weights[..., None] * pooled).sum(axis=-2)
+    # Same matmul trick for the weighted sum: (B, L, H, topk) @ (B, L, topk, D).
+    pw_bl = pooled_weights.transpose(0, 2, 1, 3)  # (B, L, H, topk)
+    out = out + (pw_bl @ pooled_sq).transpose(0, 2, 1, 3)  # (B, H, L, D)
 
     return out.astype(q.dtype)
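
Quick equivalence check (not part of the patch): a minimal MLX sketch comparing the two formulations on small random tensors. Variable names mirror the diff; the toy shape values and the (B, 1, L, topk, D) layout of pooled (implied by pooled.squeeze(1) above) are assumptions for illustration.

import mlx.core as mx

B, H, L, topk, D = 2, 4, 8, 6, 16
q_scaled = mx.random.normal(shape=(B, H, L, D))
pooled = mx.random.normal(shape=(B, 1, L, topk, D))

# Old score path: broadcast to (B, H, L, topk, D), then reduce over D.
ref_scores = (q_scaled[:, :, :, None] * pooled).sum(axis=-1)

# New score path: batched matmul over the leading (B, L) axes, no 5-D intermediate.
pooled_sq = pooled.squeeze(1)          # (B, L, topk, D)
q_bl = q_scaled.transpose(0, 2, 1, 3)  # (B, L, H, D)
new_scores = (q_bl @ pooled_sq.swapaxes(-1, -2)).transpose(0, 2, 1, 3)

# Weighted-sum side of the same trick.
weights = mx.random.normal(shape=(B, H, L, topk))
ref_out = (weights[..., None] * pooled).sum(axis=-2)
new_out = (weights.transpose(0, 2, 1, 3) @ pooled_sq).transpose(0, 2, 1, 3)

print(mx.allclose(ref_scores, new_scores, atol=1e-4).item())  # True
print(mx.allclose(ref_out, new_out, atol=1e-4).item())        # True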