diff --git a/mlx_lm/models/deepseek_v4.py b/mlx_lm/models/deepseek_v4.py index 12e4add48..67efa7831 100644 --- a/mlx_lm/models/deepseek_v4.py +++ b/mlx_lm/models/deepseek_v4.py @@ -11,7 +11,7 @@ from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import BatchRotatingKVCache, RotatingKVCache from .pipeline import PipelineMixin -from .switch_layers import SwitchGLU, _gather_sort, _scatter_unsort +from .switch_layers import SwitchGLU, _scatter_unsort @dataclass @@ -86,6 +86,25 @@ def _score_func(scores: mx.array, func: str) -> mx.array: raise ValueError(f"Unsupported DeepSeek-V4 scoring function: {func}") +@mx.compile +def _expert_select( + logits: mx.array, + e_score_correction_bias: mx.array, + top_k: int, + routed_scaling_factor: float, + norm_topk_prob: bool, + scoring_func: str, +) -> Tuple[mx.array, mx.array]: + scores = _score_func(logits, scoring_func) + biased = scores + e_score_correction_bias + inds = mx.argpartition(-biased, kth=top_k - 1, axis=-1)[..., :top_k] + weights = mx.take_along_axis(scores, inds, axis=-1) + if scoring_func != "softmax" and norm_topk_prob: + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + weights = weights * routed_scaling_factor + return inds, weights + + @mx.compile def _limited_swiglu(gate: mx.array, up: mx.array, limit: float) -> mx.array: if limit and limit > 0: @@ -106,28 +125,33 @@ def __call__(self, x, gate): class DeepseekV4SwitchGLU(SwitchGLU): sort_threshold = 8 - def __call__(self, x, indices) -> mx.array: + def __call__(self, x, indices, scores) -> mx.array: + out_shape = x.shape + route_shape = indices.shape x = mx.expand_dims(x, (-2, -3)) do_sort = indices.size >= self.sort_threshold - idx = indices inv_order = None if do_sort: - x, idx, inv_order = _gather_sort(x, indices) + flat_indices = indices.flatten() + order = mx.argsort(flat_indices) + inv_order = mx.argsort(order) + x = x.flatten(0, -3)[order // route_shape[-1]] + indices = flat_indices[order] + scores = scores.flatten()[order] if self.training: - idx = mx.stop_gradient(idx) - x_up = self.up_proj(x, idx, sorted_indices=do_sort) - x_gate = self.gate_proj(x, idx, sorted_indices=do_sort) - x = self.down_proj( - self.activation(x_up, x_gate), - idx, - sorted_indices=do_sort, - ) + indices = mx.stop_gradient(indices) + x_up = self.up_proj(x, indices, sorted_indices=do_sort) + x_gate = self.gate_proj(x, indices, sorted_indices=do_sort) + x = self.activation(x_up, x_gate) + x = x * scores.astype(x.dtype)[..., None, None] + x = self.down_proj(x, indices, sorted_indices=do_sort) if do_sort: - x = _scatter_unsort(x, inv_order, indices.shape) + x = _scatter_unsort(x, inv_order, route_shape) - return x.squeeze(-2) + x = x.squeeze(-2) + return x.sum(axis=-2).astype(x.dtype).reshape(out_shape) class DeepseekV4RoPE(nn.Module): @@ -239,6 +263,7 @@ def _rope_full( return pe_out + def _apply_partial_rope( x: mx.array, rope: DeepseekV4RoPE, @@ -295,92 +320,75 @@ def _make_hc_split_sinkhorn_kernel(): source = """ uint idx = thread_position_in_grid.x; - constexpr int MIX = (2 + HC) * HC; - float epsv = static_cast(eps[0]); - - auto mix = mixes + idx * MIX; - auto pre_out = pre + idx * HC; - auto post_out = post + idx * HC; - auto comb_out = comb + idx * HC * HC; - - float pre_scale = static_cast(scale[0]); - float post_scale = static_cast(scale[1]); - float comb_scale = static_cast(scale[2]); - - for (int i = 0; i < HC; ++i) { - float z = static_cast(mix[i]) * pre_scale - + static_cast(base[i]); - pre_out[i] = 1.0f / (1.0f + 
metal::fast::exp(-z)) + epsv; - } - for (int i = 0; i < HC; ++i) { - int off = HC + i; - float z = static_cast(mix[off]) * post_scale - + static_cast(base[off]); - post_out[i] = 2.0f / (1.0f + metal::fast::exp(-z)); + constexpr int MIX = (2 + HC) * HC; + constexpr int BASE = 2 * HC; + + const device float* mix = (const device float*)mixes + idx * MIX; + device float* pre_out = (device float*)pre + idx * HC; + device float* post_out = (device float*)post + idx * HC; + device float* comb_out = (device float*)comb + idx * HC * HC; + + const float pre_scale = scale[0]; + const float post_scale = scale[1]; + const float comb_scale = scale[2]; + const float epsv = eps[0]; + + // Pre-sigmoid + { + float4 z = *(const device float4*)mix * pre_scale + + *(const device float4*)base; + *(device float4*)pre_out = 1.0f / (1.0f + metal::fast::exp(-z)) + epsv; } - float c[HC * HC]; - for (int i = 0; i < HC; ++i) { - float row_max = -INFINITY; - for (int j = 0; j < HC; ++j) { - int cidx = i * HC + j; - int off = 2 * HC + cidx; - float v = static_cast(mix[off]) * comb_scale - + static_cast(base[off]); - c[cidx] = v; - row_max = metal::max(row_max, v); - } - float row_sum = 0.0f; - for (int j = 0; j < HC; ++j) { - int cidx = i * HC + j; - float v = metal::fast::exp(c[cidx] - row_max); - c[cidx] = v; - row_sum += v; - } - float inv_sum = 1.0f / row_sum; - for (int j = 0; j < HC; ++j) { - int cidx = i * HC + j; - c[cidx] = c[cidx] * inv_sum + epsv; - } - } - - for (int j = 0; j < HC; ++j) { - float col_sum = 0.0f; - for (int i = 0; i < HC; ++i) { - col_sum += c[i * HC + j]; - } - float inv_denom = 1.0f / (col_sum + epsv); - for (int i = 0; i < HC; ++i) { - c[i * HC + j] *= inv_denom; - } + // Post-sigmoid + { + float4 z = *(const device float4*)(mix + HC) * post_scale + + *(const device float4*)(base + HC); + *(device float4*)post_out = 2.0f * 1.0f / (1.0f + metal::fast::exp(-z)); } + // Comb: four float4 loads — all independent, GPU issues in parallel + float4 v0 = *(const device float4*)(mix + BASE ) * comb_scale + *(const device float4*)(base + BASE ); + float4 v1 = *(const device float4*)(mix + BASE + 4) * comb_scale + *(const device float4*)(base + BASE + 4); + float4 v2 = *(const device float4*)(mix + BASE + 8) * comb_scale + *(const device float4*)(base + BASE + 8); + float4 v3 = *(const device float4*)(mix + BASE + 12) * comb_scale + *(const device float4*)(base + BASE + 12); + + // Per-row stable softmax: compute all maxes before any exp + float m0 = metal::max(metal::max(v0.x, v0.y), metal::max(v0.z, v0.w)); + float m1 = metal::max(metal::max(v1.x, v1.y), metal::max(v1.z, v1.w)); + float m2 = metal::max(metal::max(v2.x, v2.y), metal::max(v2.z, v2.w)); + float m3 = metal::max(metal::max(v3.x, v3.y), metal::max(v3.z, v3.w)); + + float4 e0 = metal::fast::exp(v0 - m0); + float4 e1 = metal::fast::exp(v1 - m1); + float4 e2 = metal::fast::exp(v2 - m2); + float4 e3 = metal::fast::exp(v3 - m3); + + // Explicit adds instead of dot(e, 1) — avoids unnecessary fmul + float4 r0 = e0 * 1.0f / (e0.x + e0.y + e0.z + e0.w) + epsv; + float4 r1 = e1 * 1.0f / (e1.x + e1.y + e1.z + e1.w) + epsv; + float4 r2 = e2 * 1.0f / (e2.x + e2.y + e2.z + e2.w) + epsv; + float4 r3 = e3 * 1.0f / (e3.x + e3.y + e3.z + e3.w) + epsv; + + // Initial column normalization + float4 col = 1.0f / (r0 + r1 + r2 + r3 + epsv); + r0 *= col; r1 *= col; r2 *= col; r3 *= col; + + // Sinkhorn iterations for (int iter = 1; iter < ITERS; ++iter) { - for (int i = 0; i < HC; ++i) { - float row_sum = 0.0f; - for (int j = 0; j < HC; ++j) { - row_sum += c[i * 
HC + j]; - } - float inv_denom = 1.0f / (row_sum + epsv); - for (int j = 0; j < HC; ++j) { - c[i * HC + j] *= inv_denom; - } - } - for (int j = 0; j < HC; ++j) { - float col_sum = 0.0f; - for (int i = 0; i < HC; ++i) { - col_sum += c[i * HC + j]; - } - float inv_denom = 1.0f / (col_sum + epsv); - for (int i = 0; i < HC; ++i) { - c[i * HC + j] *= inv_denom; - } - } + r0 *= 1.0f / (r0.x + r0.y + r0.z + r0.w + epsv); + r1 *= 1.0f / (r1.x + r1.y + r1.z + r1.w + epsv); + r2 *= 1.0f / (r2.x + r2.y + r2.z + r2.w + epsv); + r3 *= 1.0f / (r3.x + r3.y + r3.z + r3.w + epsv); + col = 1.0f / (r0 + r1 + r2 + r3 + epsv); + r0 *= col; r1 *= col; r2 *= col; r3 *= col; } - for (int i = 0; i < HC * HC; ++i) { - comb_out[i] = c[i]; - } + // Write comb output (four aligned 128-bit stores) + *(device float4*)(comb_out) = r0; + *(device float4*)(comb_out + 4) = r1; + *(device float4*)(comb_out + 8) = r2; + *(device float4*)(comb_out + 12) = r3; """ return mx.fast.metal_kernel( @@ -402,15 +410,16 @@ def hc_split_sinkhorn( sinkhorn_iters: int, eps: float, ) -> Tuple[mx.array, mx.array, mx.array]: - if _hc_split_sinkhorn_kernel is None: + if _hc_split_sinkhorn_kernel is None or hc_mult != 4: return _hc_split_sinkhorn_ops(mixes, scale, base, hc_mult, sinkhorn_iters, eps) if not isinstance(eps, mx.array): eps = mx.array([eps], dtype=mx.float32) + n_rows = mixes.size // ((2 + hc_mult) * hc_mult) return _hc_split_sinkhorn_kernel( inputs=[mixes, scale, base, eps], template=[("HC", hc_mult), ("ITERS", sinkhorn_iters)], - grid=(mixes.size // ((2 + hc_mult) * hc_mult), 1, 1), + grid=(n_rows, 1, 1), threadgroup=(256, 1, 1), output_shapes=[ (*mixes.shape[:-1], hc_mult), @@ -426,6 +435,159 @@ def _hc_collapse_op(pre: mx.array, x: mx.array) -> mx.array: return (pre[..., None] * x.astype(mx.float32)).sum(axis=2).astype(x.dtype) +def _make_hc_sinkhorn_collapse_kernel(): + """Fused sinkhorn + collapse: eliminates one dispatch per HC cycle. + + 1. BRANCHLESS SINKHORN: all 32 lanes in simd group 0 execute identical + instructions. Lanes >= HC use multiplicative mask (active=0) instead + of divergent branches — eliminates SIMD serialization. + 2. PARALLEL SINKHORN: lanes 0-3 each own one comb row. Column norm + via simd_sum() — free SIMD shuffle. + 3. NATIVE bfloat4 LOADS: single 64-bit load yields 4 bfloat16 values; + cast to float4 is a free hardware conversion. + 4. FMA CHAINS: collapse uses fused multiply-add for 3 of 4 terms. + """ + if mx.default_device() != mx.gpu or not mx.metal.is_available(): + return None + + source = """ + uint tid = thread_position_in_threadgroup.x; + uint row = threadgroup_position_in_grid.x; + uint lane = tid % 32; + uint sg = tid / 32; + + constexpr int MIX = (2 + HC) * HC; + constexpr int BASE_OFF = 2 * HC; + + const device float* mix = (const device float*)mixes + row * MIX; + device float* post_out = (device float*)post + row * HC; + device float* comb_out = (device float*)comb + row * HC * HC; + + threadgroup float pre_shared[HC]; + + // ================================================================ + // PHASE 1: Branchless sinkhorn on simd group 0 + // All 32 lanes execute identical instructions. Lanes >= HC + // compute on clamped indices but multiply by active=0, so they + // contribute zero to simd_sum. No divergent branches in the loop. 
+ // ================================================================ + if (sg == 0) { + const float pre_scale = scale[0]; + const float post_scale = scale[1]; + const float comb_scale = scale[2]; + const float epsv = eps[0]; + + const float active = (lane < (uint)HC) ? 1.0f : 0.0f; + const uint llane = metal::min(lane, (uint)(HC - 1)); + + // Pre/post sigmoids: all lanes compute, only active lanes write + float pre_z = mix[llane] * pre_scale + base[llane]; + float post_z = mix[HC + llane] * post_scale + base[HC + llane]; + float pre_v = 1.0f / (1.0f + metal::fast::exp(-pre_z)) + epsv; + float post_v = 2.0f / (1.0f + metal::fast::exp(-post_z)); + + if (lane < (uint)HC) { + pre_shared[lane] = pre_v; + post_out[lane] = post_v; + } + + // Comb softmax: load + mask. Inactive lanes load row 0 (safe) + // but multiply by active=0 so they hold zeros. + float4 v = (*(const device float4*)(mix + BASE_OFF + llane * HC) + * comb_scale + + *(const device float4*)(base + BASE_OFF + llane * HC)) + * active; + + float row_max = metal::max(metal::max(v.x, v.y), + metal::max(v.z, v.w)); + float4 e = metal::fast::exp(v - row_max) * active; + float4 r = e * (1.0f / (e.x + e.y + e.z + e.w + epsv)) + + epsv * active; + + // Initial column normalization + float4 col_inv = 1.0f / (float4( + simd_sum(r.x), simd_sum(r.y), + simd_sum(r.z), simd_sum(r.w) + ) + epsv); + r *= col_inv; + + // Sinkhorn iterations: zero branches in the loop body + for (int iter = 1; iter < ITERS; ++iter) { + // Row norm + re-clamp inactive lanes + r *= (1.0f / (r.x + r.y + r.z + r.w + epsv)) * active; + + // Col norm via simd_sum + col_inv = 1.0f / (float4( + simd_sum(r.x), simd_sum(r.y), + simd_sum(r.z), simd_sum(r.w) + ) + epsv); + r *= col_inv; + } + + if (lane < (uint)HC) { + *(device float4*)(comb_out + lane * HC) = r; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ================================================================ + // PHASE 2: Collapse — all 256 threads, native bfloat4 vectorized + // ================================================================ + const float p0 = pre_shared[0]; + const float p1 = pre_shared[1]; + const float p2 = pre_shared[2]; + const float p3 = pre_shared[3]; + + const device bfloat16_t* x_row = (const device bfloat16_t*)x_in + + row * (HC * D); + device bfloat16_t* out_row = (device bfloat16_t*)collapsed + + row * D; + + // Native bfloat4 pointers: single 64-bit load per vector + using bf4 = vec; + const device bf4* x_row0 = (const device bf4*)(x_row + 0*D); + const device bf4* x_row1 = (const device bf4*)(x_row + 1*D); + const device bf4* x_row2 = (const device bf4*)(x_row + 2*D); + const device bf4* x_row3 = (const device bf4*)(x_row + 3*D); + device bf4* out4 = (device bf4*)out_row; + + constexpr uint D4 = (uint)D / 4; + + for (uint d4 = tid; d4 < D4; d4 += 256) { + float4 x0 = float4(x_row0[d4]); + float4 x1 = float4(x_row1[d4]); + float4 x2 = float4(x_row2[d4]); + float4 x3 = float4(x_row3[d4]); + + float4 result = fma(float4(p0), x0, + fma(float4(p1), x1, + fma(float4(p2), x2, float4(p3) * x3))); + + out4[d4] = bf4(result); + } + + // Scalar tail for D not divisible by 4 + #if (D % 4) != 0 + for (uint d = D4 * 4 + tid; d < (uint)D; d += 256) { + float val = p0*(float)x_row[0*D+d] + p1*(float)x_row[1*D+d] + + p2*(float)x_row[2*D+d] + p3*(float)x_row[3*D+d]; + out_row[d] = (bfloat16_t)val; + } + #endif + """ + + return mx.fast.metal_kernel( + name="deepseek_v4_hc_sinkhorn_collapse", + input_names=["mixes", "scale", "base", "eps", "x_in"], + output_names=["post", 
"comb", "collapsed"], + source=source, + ) + + +_hc_sinkhorn_collapse_kernel = _make_hc_sinkhorn_collapse_kernel() + + @mx.compile def _hc_expand_op( post: mx.array, @@ -443,6 +605,13 @@ def _rms_rsqrt(flat: mx.array, eps: float) -> mx.array: return mx.rsqrt((flat * flat).mean(axis=-1, keepdims=True) + eps) +@mx.compile +def _hc_mixes(flat: mx.array, fn_T: mx.array, norm_eps: float) -> mx.array: + """Fused RMS-rsqrt + matmul + scale into single compiled graph.""" + rsqrt = mx.rsqrt((flat * flat).mean(axis=-1, keepdims=True) + norm_eps) + return (flat @ fn_T) * rsqrt + + class HyperConnection(nn.Module): def __init__(self, config: ModelArgs): super().__init__() @@ -455,12 +624,18 @@ def __init__(self, config: ModelArgs): self.fn = mx.zeros((mix, self.hc_mult * config.hidden_size), dtype=mx.float32) self.base = mx.zeros((mix,), dtype=mx.float32) self.scale = mx.ones((3,), dtype=mx.float32) + self._fn_T = None def compute_weights(self, x: mx.array): B, L, H, D = x.shape flat = x.reshape(B, L, H * D).astype(mx.float32) - rsqrt = _rms_rsqrt(flat, self.norm_eps) - mixes = (flat @ self.fn.T) * rsqrt + if self._fn_T is None: + self._fn_T = self.fn.T + if self.training: + rsqrt = _rms_rsqrt(flat, self.norm_eps) + mixes = (flat @ self._fn_T) * rsqrt + else: + mixes = _hc_mixes(flat, self._fn_T, self.norm_eps) split_sinkhorn = _hc_split_sinkhorn_ops if self.training else hc_split_sinkhorn return split_sinkhorn( mixes, @@ -472,9 +647,42 @@ def compute_weights(self, x: mx.array): ) def collapse(self, x: mx.array): + if ( + not self.training + and _hc_sinkhorn_collapse_kernel is not None + and self.hc_mult == 4 + and x.dtype == mx.bfloat16 + ): + return self._fused_collapse(x) pre, post, comb = self.compute_weights(x) return _hc_collapse_op(pre, x), post, comb + def _fused_collapse(self, x: mx.array): + """Fused sinkhorn + collapse in a single Metal kernel dispatch.""" + B, L, H, D = x.shape + flat = x.reshape(B, L, H * D).astype(mx.float32) + if self._fn_T is None: + self._fn_T = self.fn.T + mixes = _hc_mixes(flat, self._fn_T, self.norm_eps) + + eps = self._hc_eps[0] + n_rows = B * L + x_flat = mx.contiguous(x.reshape(n_rows, H, D)) + + post, comb, collapsed = _hc_sinkhorn_collapse_kernel( + inputs=[mixes, self.scale, self.base, eps, x_flat], + template=[("HC", self.hc_mult), ("ITERS", self.sinkhorn_iters), ("D", D)], + grid=(n_rows * 256, 1, 1), + threadgroup=(256, 1, 1), + output_shapes=[ + (*mixes.shape[:-1], self.hc_mult), + (*mixes.shape[:-1], self.hc_mult, self.hc_mult), + (B, L, D), + ], + output_dtypes=[mx.float32, mx.float32, x.dtype], + ) + return collapsed, post, comb + def expand( self, block_out: mx.array, @@ -485,6 +693,24 @@ def expand( return _hc_expand_op(post, block_out, comb, residual) +@mx.compile +def _hyper_head_op( + x: mx.array, + fn: mx.array, + scale: mx.array, + base: mx.array, + norm_eps: float, + hc_eps: float, +) -> mx.array: + """Fused HyperHead: RMS-rsqrt + matmul + sigmoid + weighted sum.""" + B, L, H, D = x.shape + flat = x.reshape(B, L, H * D).astype(mx.float32) + rsqrt = mx.rsqrt((flat * flat).mean(axis=-1, keepdims=True) + norm_eps) + mixes = (flat @ fn.T) * rsqrt + pre = mx.sigmoid(mixes * scale[0] + base) + hc_eps + return (pre[..., None] * x.astype(mx.float32)).sum(axis=2).astype(x.dtype) + + class HyperHead(nn.Module): def __init__(self, config: ModelArgs): super().__init__() @@ -498,6 +724,10 @@ def __init__(self, config: ModelArgs): self.scale = mx.ones((1,), dtype=mx.float32) def __call__(self, x: mx.array): + if not self.training: + return _hyper_head_op( + 
x, self.fn, self.scale, self.base, self.norm_eps, self.hc_eps + ) B, L, H, D = x.shape flat = x.reshape(B, L, H * D).astype(mx.float32) rsqrt = _rms_rsqrt(flat, self.norm_eps) @@ -517,6 +747,7 @@ def __init__(self, config: ModelArgs, layer_idx: int): self.routed_scaling_factor = config.routed_scaling_factor self.norm_topk_prob = config.norm_topk_prob self.weight = mx.zeros((self.num_experts, self.hidden_dim)) + self._weight_T_f32 = None if self.hash: self.tid2eid = mx.zeros((config.vocab_size, self.top_k), dtype=mx.int32) else: @@ -526,23 +757,29 @@ def __init__(self, config: ModelArgs, layer_idx: int): def __call__(self, x: mx.array, input_ids: Optional[mx.array] = None): flat = x.reshape(-1, self.hidden_dim) - logits = flat.astype(mx.float32) @ self.weight.T.astype(mx.float32) - scores = _score_func(logits, self.scoring_func) + if self._weight_T_f32 is None: + self._weight_T_f32 = self.weight.T.astype(mx.float32) + logits = flat.astype(mx.float32) @ self._weight_T_f32 if self.hash: if input_ids is None: raise ValueError("DeepSeek-V4 hash routing requires input_ids.") + scores = _score_func(logits, self.scoring_func) inds = self.tid2eid[input_ids.reshape(-1)].astype(mx.int32) + weights = mx.take_along_axis(scores, inds, axis=-1) + if self.scoring_func != "softmax" and self.norm_topk_prob: + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + weights = weights * self.routed_scaling_factor else: - biased = scores + self.e_score_correction_bias - inds = mx.argpartition(-biased, kth=self.top_k - 1, axis=-1)[ - ..., : self.top_k - ] + inds, weights = _expert_select( + logits, + self.e_score_correction_bias, + self.top_k, + self.routed_scaling_factor, + self.norm_topk_prob, + self.scoring_func, + ) - weights = mx.take_along_axis(scores, inds, axis=-1) - if self.scoring_func != "softmax" and self.norm_topk_prob: - weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) - weights = weights * self.routed_scaling_factor route_shape = (*x.shape[:-1], self.top_k) inds = inds.reshape(route_shape) weights = weights.reshape(route_shape) @@ -592,8 +829,7 @@ def __call__(self, x: mx.array, input_ids: mx.array) -> mx.array: x = sum_gradients(self.sharding_group)(x) inds, scores = self.gate(x, input_ids) - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype).reshape(x.shape) + y = self.switch_mlp(x, inds, scores) y = y + self.shared_experts(x) if self.sharding_group is not None: @@ -1195,6 +1431,7 @@ def _merge_batch_state(values: List[Optional[mx.array]]): class Compressor(nn.Module): + def __init__(self, config: ModelArgs, compress_ratio: int, head_dim: int): super().__init__() self.compress_ratio = compress_ratio @@ -1209,10 +1446,12 @@ def __init__(self, config: ModelArgs, compress_ratio: int, head_dim: int): def _overlap_transform(self, x: mx.array, fill_value: float): B, W, R, _ = x.shape - out = mx.full((B, W, 2 * R, self.head_dim), fill_value, dtype=x.dtype) - out[:, :, R:] = x[:, :, :, self.head_dim :] - out[:, 1:, :R] = x[:, :-1, :, : self.head_dim] - return out + second_half = x[:, :, :, self.head_dim :] # (B, W, R, head_dim) + fill_row = mx.full((B, 1, R, self.head_dim), fill_value, dtype=x.dtype) + prev_first = mx.concatenate( + [fill_row, x[:, :-1, :, : self.head_dim]], axis=1 + ) # (B, W, R, head_dim) + return mx.concatenate([prev_first, second_half], axis=2) # (B, W, 2R, head_dim) def __call__( self, @@ -1311,6 +1550,7 @@ def __call__( return mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] + class V4Attention(nn.Module): def 
__init__(self, config: ModelArgs, layer_idx: int): super().__init__() @@ -1345,6 +1585,7 @@ def __init__(self, config: ModelArgs, layer_idx: int): ) self.attn_sink = mx.zeros((self.n_heads,), dtype=mx.float32) self._q_l2_norm_weight = (mx.ones((self.head_dim,)),) + self._cached_dtype = None rope_theta = ( config.compress_rope_theta if self.compress_ratio else config.rope_theta @@ -1362,6 +1603,32 @@ def __init__(self, config: ModelArgs, layer_idx: int): if self.compress_ratio == 4: self.indexer = Indexer(config, self.compress_ratio) + def _ensure_cached(self, dtype): + if self._cached_dtype is not None and self._cached_dtype == dtype: + return + self._cached_dtype = dtype + self._attn_sink_cached = self.attn_sink.astype(dtype) + self._q_norm_weight_cached = self._q_l2_norm_weight[0].astype(dtype) + if isinstance(self.wo_a, nn.QuantizedLinear): + self._wo_a_weight = self.wo_a.weight.reshape( + self.o_groups, self.o_lora_rank, -1 + )[:, None] + self._wo_a_scales = self.wo_a.scales.reshape( + self.o_groups, self.o_lora_rank, -1 + )[:, None] + self._wo_a_biases = ( + None + if self.wo_a.biases is None + else self.wo_a.biases.reshape( + self.o_groups, self.o_lora_rank, -1 + )[:, None] + ) + else: + group_feat = (self.n_heads * self.head_dim) // self.o_groups + self._wo_a_weight_reshaped = self.wo_a.weight.reshape( + self.o_groups, self.o_lora_rank, group_feat + ) + def _grouped_output_projection(self, out: mx.array) -> mx.array: B, L = out.shape[:2] group_feat = (self.n_heads * self.head_dim) // self.o_groups @@ -1369,24 +1636,11 @@ def _grouped_output_projection(self, out: mx.array) -> mx.array: if isinstance(self.wo_a, nn.QuantizedLinear): out = out.transpose(2, 0, 1, 3) - weight = self.wo_a.weight.reshape(self.o_groups, self.o_lora_rank, -1)[ - :, None - ] - scales = self.wo_a.scales.reshape(self.o_groups, self.o_lora_rank, -1)[ - :, None - ] - biases = ( - None - if self.wo_a.biases is None - else self.wo_a.biases.reshape(self.o_groups, self.o_lora_rank, -1)[ - :, None - ] - ) out = mx.quantized_matmul( out, - weight, - scales=scales, - biases=biases, + self._wo_a_weight, + scales=self._wo_a_scales, + biases=self._wo_a_biases, transpose=True, group_size=self.wo_a.group_size, bits=self.wo_a.bits, @@ -1399,8 +1653,7 @@ def _grouped_output_projection(self, out: mx.array) -> mx.array: out = out + self.wo_a.bias return out - weight = self.wo_a.weight.reshape(self.o_groups, self.o_lora_rank, group_feat) - out = mx.einsum("bsgd,grd->bsgr", out, weight) + out = mx.einsum("bsgd,grd->bsgr", out, self._wo_a_weight_reshaped) out = out.reshape(B, L, self.o_groups * self.o_lora_rank) if "bias" in self.wo_a: out = out + self.wo_a.bias @@ -1422,8 +1675,9 @@ def __call__( offset = offset + 0 q_residual = self.q_norm(self.wq_a(x)) q = self.wq_b(q_residual).reshape(B, L, self.n_heads, self.head_dim) + self._ensure_cached(q.dtype) q = mx.fast.rms_norm( - q, self._q_l2_norm_weight[0].astype(q.dtype), self.config.rms_norm_eps + q, self._q_norm_weight_cached, self.config.rms_norm_eps ) q = q.transpose(0, 2, 1, 3) kv = self.kv_norm(self.wkv(x)).reshape(B, L, 1, self.head_dim) @@ -1442,50 +1696,51 @@ def __call__( if self.compress_ratio: v4_cache = cache if isinstance(cache, DeepseekV4Cache) else None pooled = self.compressor(x, self.compress_rope, v4_cache, offset) - lengths = ( - v4_cache.pooled_lengths("compressor_state") - if v4_cache is not None - else None - ) - use_indexer = hasattr(self, "indexer") and pooled.shape[1] > 0 - select_all = ( - use_indexer - and lengths is None - and pooled.shape[1] <= 
self.indexer.index_topk - ) - if select_all: - pooled = pooled[:, None] - pooled_bias = math.log(L) - elif use_indexer: - topk = self.indexer( - x, q_residual, self.compress_rope, self.rope, v4_cache, offset + if pooled.shape[1] > 0: + lengths = ( + v4_cache.pooled_lengths("compressor_state") + if v4_cache is not None + else None ) - if topk is not None: - if lengths is not None: - lengths = mx.array(lengths) - pooled_mask = (topk < lengths[:, None, None]).reshape( - B, 1, 1, -1 - ) - expanded = mx.broadcast_to( - pooled[:, None, None, :, :], - (B, 1, L, pooled.shape[1], self.head_dim), + use_indexer = hasattr(self, "indexer") + select_all = ( + use_indexer + and lengths is None + and pooled.shape[1] <= self.indexer.index_topk + ) + if select_all: + pooled = pooled[:, None] + pooled_bias = math.log(L) + elif use_indexer: + topk = self.indexer( + x, q_residual, self.compress_rope, self.rope, v4_cache, offset ) - idx = topk[:, None, :, :, None] - pooled = mx.take_along_axis( - expanded, - mx.broadcast_to(idx, idx.shape[:-1] + (self.head_dim,)), - axis=3, - ).reshape(B, 1, -1, self.head_dim) + if topk is not None: + if lengths is not None: + lengths = mx.array(lengths) + pooled_mask = (topk < lengths[:, None, None]).reshape( + B, 1, 1, -1 + ) + expanded = mx.broadcast_to( + pooled[:, None, None, :, :], + (B, 1, L, pooled.shape[1], self.head_dim), + ) + idx = topk[:, None, :, :, None] + pooled = mx.take_along_axis( + expanded, + mx.broadcast_to(idx, idx.shape[:-1] + (self.head_dim,)), + axis=3, + ).reshape(B, 1, -1, self.head_dim) + else: + pooled = pooled[:, None] else: + if lengths is not None: + lengths = mx.array(lengths) + pooled_mask = ( + mx.arange(pooled.shape[1]) < lengths[:, None] + ).reshape(B, 1, 1, -1) pooled = pooled[:, None] - else: - if lengths is not None: - lengths = mx.array(lengths) - pooled_mask = ( - mx.arange(pooled.shape[1]) < lengths[:, None] - ).reshape(B, 1, 1, -1) - pooled = pooled[:, None] - full_kv = mx.concatenate([full_kv, pooled], axis=2) + full_kv = mx.concatenate([full_kv, pooled], axis=2) if mask is not None and mask.shape[-1] > local_kv_len: mask = mask[..., -local_kv_len:] @@ -1517,7 +1772,6 @@ def __call__( mx.full(pad_shape, mx.finfo(mask.dtype).min, dtype=mask.dtype), ) mask = mx.concatenate([mask, pad], axis=-1) - out = scaled_dot_product_attention( q, full_kv, @@ -1525,7 +1779,7 @@ def __call__( cache=local_cache, scale=self.scale, mask=mask, - sinks=self.attn_sink.astype(q.dtype), + sinks=self._attn_sink_cached, ) out = _apply_partial_rope(out, self.rope, offset, inverse=True) out = out.transpose(0, 2, 1, 3).reshape(B, L, self.n_heads * self.head_dim)
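
For reference, the alternating row/column normalization that both the scalar and the fused float4 Metal kernels above compute on the 4x4 comb matrix (softmax per row plus eps, then repeated row and column renormalization) can be expressed with plain MLX ops roughly as below. This is an illustrative sketch only, not part of the patch; the helper name sinkhorn_comb is hypothetical, and it mirrors the hc_mult == 4 path rather than the general _hc_split_sinkhorn_ops fallback.

    import mlx.core as mx

    def sinkhorn_comb(comb_logits: mx.array, iters: int, eps: float) -> mx.array:
        # comb_logits: (..., HC, HC) mixing logits after the scale/base affine.
        c = mx.softmax(comb_logits, axis=-1) + eps         # row-stochastic start (+ eps floor)
        c = c / (c.sum(axis=-2, keepdims=True) + eps)      # initial column normalization
        for _ in range(1, iters):
            c = c / (c.sum(axis=-1, keepdims=True) + eps)  # renormalize rows
            c = c / (c.sum(axis=-2, keepdims=True) + eps)  # renormalize columns
        return c

The fused kernel keeps the four rows of this matrix in float4 registers and replaces the axis reductions with simd_sum / explicit component adds, but the arithmetic per iteration is the same as in this sketch.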