diff --git a/mlx_lm/models/deepseek_v4.py b/mlx_lm/models/deepseek_v4.py index 9a88ca79c..310c80294 100644 --- a/mlx_lm/models/deepseek_v4.py +++ b/mlx_lm/models/deepseek_v4.py @@ -648,16 +648,10 @@ def _hc_expand_op( return y.astype(block_out.dtype) -@mx.compile -def _rms_rsqrt(flat: mx.array, eps: float) -> mx.array: - return mx.rsqrt((flat * flat).mean(axis=-1, keepdims=True) + eps) - - @mx.compile def _hc_mixes(flat: mx.array, fn_T: mx.array, norm_eps: float) -> mx.array: - """Fused RMS-rsqrt + matmul + scale into single compiled graph.""" - rsqrt = mx.rsqrt((flat * flat).mean(axis=-1, keepdims=True) + norm_eps) - return (flat @ fn_T) * rsqrt + """Fused RMS-norm + matmul: rms_norm(flat) @ fn_T.""" + return mx.fast.rms_norm(flat, None, eps=norm_eps) @ fn_T class HyperConnection(nn.Module): @@ -679,11 +673,7 @@ def compute_weights(self, x: mx.array): flat = x.reshape(B, L, H * D).astype(mx.float32) if self._fn_T is None: self._fn_T = self.fn.T - if self.training: - rsqrt = _rms_rsqrt(flat, self.norm_eps) - mixes = (flat @ self._fn_T) * rsqrt - else: - mixes = _hc_mixes(flat, self._fn_T, self.norm_eps) + mixes = _hc_mixes(flat, self._fn_T, self.norm_eps) split_sinkhorn = _hc_split_sinkhorn_ops if self.training else hc_split_sinkhorn return split_sinkhorn( mixes, @@ -753,8 +743,7 @@ def _hyper_head_op( """Fused HyperHead: RMS-rsqrt + matmul + sigmoid + weighted sum.""" B, L, H, D = x.shape flat = x.reshape(B, L, H * D).astype(mx.float32) - rsqrt = mx.rsqrt((flat * flat).mean(axis=-1, keepdims=True) + norm_eps) - mixes = (flat @ fn.T) * rsqrt + mixes = mx.fast.rms_norm(flat, None, eps=norm_eps) @ fn.T pre = mx.sigmoid(mixes * scale[0] + base) + hc_eps return (pre[..., None] * x.astype(mx.float32)).sum(axis=2).astype(x.dtype) @@ -778,8 +767,7 @@ def __call__(self, x: mx.array): ) B, L, H, D = x.shape flat = x.reshape(B, L, H * D).astype(mx.float32) - rsqrt = _rms_rsqrt(flat, self.norm_eps) - mixes = (flat @ self.fn.T) * rsqrt + mixes = mx.fast.rms_norm(flat, None, eps=self.norm_eps) @ self.fn.T pre = mx.sigmoid(mixes * self.scale[0] + self.base) + self.hc_eps return (pre[..., None] * x.astype(mx.float32)).sum(axis=2).astype(x.dtype)