diff --git a/mlx_lm/models/bitlinear_layers.py b/mlx_lm/models/bitlinear_layers.py
index b5b562d3f..7af496ff7 100644
--- a/mlx_lm/models/bitlinear_layers.py
+++ b/mlx_lm/models/bitlinear_layers.py
@@ -1,12 +1,16 @@
+from typing import Any, Optional
+
 import mlx.core as mx
 import mlx.nn as nn
+from .base import BaseModelArgs, scaled_dot_product_attention
+from .rope_utils import initialize_rope
 
 
 class BitLinear(nn.Module):
     """
     BitLinear module with memory-efficient weight handling.
     """
-    def __init__(self, in_features, out_features, bias=True, dtype=mx.float16, invert_weight_scales = False):
+    def __init__(self, in_features, out_features, bias=True, dtype=mx.float16, invert_weight_scales = False, fused_shapes = None):
         super().__init__()
         self.dtype = dtype
         self.in_features = in_features
@@ -18,7 +22,14 @@ def __init__(self, in_features, out_features, bias=True, dtype=mx.float16, inver
         self.weight = mx.zeros((packed_out_features, in_features), dtype=mx.uint8)
 
         self.invert_weight_scales = invert_weight_scales
-        self.weight_scale = mx.array([1.0], dtype=dtype)
+        self.fused_shapes = fused_shapes
+
+        if fused_shapes is None:
+            self.weight_scale = mx.array([1.0], dtype=dtype)
+        else:
+            self.weight_scale = mx.array([1.0] * (len(fused_shapes) + 1), dtype=dtype)
+
+        self.fused_layers = fused_shapes is not None
 
         if bias:
             self.bias = mx.zeros((out_features,), dtype=dtype)
@@ -47,14 +58,14 @@ def bitlinear_kernel(self, x, packed_weights):
         // Calculate packed dimensions
         uint packed_rows = out_features / 4; // Each packed row contains 4 output rows
 
+        // Determine which packed row and which bit position within that packed value
+        uint which_slice = out_idx / packed_rows; // Which of the 4 slices (0, 1, 2, 3)
+        uint row_in_slice = out_idx % packed_rows; // Which row within that slice
+
         for (uint i = 0; i < in_features; i++) {
             // Get input value
             float x_val = x[batch_idx * in_features + i];
 
-            // Determine which packed row and which bit position within that packed value
-            uint which_slice = out_idx / packed_rows; // Which of the 4 slices (0, 1, 2, 3)
-            uint row_in_slice = out_idx % packed_rows; // Which row within that slice
-
             // Get the packed weight value
             uint packed_idx = row_in_slice * in_features + i;
             uint8_t packed_val = packed_weights[packed_idx];
@@ -69,11 +80,23 @@ def bitlinear_kernel(self, x, packed_weights):
             sum += x_val * weight_val;
         }
 
+        uint weight_layer_idx = 0; // Default: single layer without any fusing
+        // Pick the weight scale for this output; it is applied by dividing or multiplying below
+        if (fused_length > 0) {
+            // Determine the scale index from the interval that out_idx falls into
+            for (uint i = 0; i < fused_length; i++) {
+                if (out_idx < fused_shapes[i]) {
+                    weight_layer_idx = i;
+                    break;
+                }
+            }
+        }
+
         if (invert_weight_scales) {
-            out[tid] = sum / weight_scale[0];
+            out[tid] = sum / weight_scale[weight_layer_idx];
         } else {
-            out[tid] = sum * weight_scale[0];
+            out[tid] = sum * weight_scale[weight_layer_idx];
         }
 
        """
 
@@ -94,16 +117,22 @@ def bitlinear_kernel(self, x, packed_weights):
         if self._compiled_kernel is None:
             self._compiled_kernel = mx.fast.metal_kernel(
                 name="bitlinear_matmul",
-                input_names=["x", "packed_weights", "weight_scale", "invert_weight_scales"],
+                input_names=["x", "packed_weights", "weight_scale", "invert_weight_scales", "fused_shapes", "fused_length"],
                 output_names=["out"],
                 source=source,
             )
 
+        if self.fused_layers:
+            fused_shapes = self.fused_shapes + [out_features]  # add the final boundary without mutating self.fused_shapes on every call
+            inputs = [x_flattened.astype(self.dtype), packed_weights, self.weight_scale, self.invert_weight_scales, mx.array(fused_shapes), len(fused_shapes)]
+        else:
+            inputs = [x_flattened.astype(self.dtype), packed_weights, self.weight_scale, self.invert_weight_scales, mx.array([]), 0]
+
         outputs = self._compiled_kernel(
-            inputs=[x_flattened.astype(self.dtype), packed_weights, self.weight_scale, self.invert_weight_scales],
+            inputs=inputs,
             template=[("batch_size", total_batch_elements), ("in_features", in_features), ("out_features", out_features)],
             grid=(total_batch_elements * out_features, 1, 1),
-            threadgroup=(min(256, total_batch_elements * out_features), 1, 1),
+            threadgroup=(min(32, total_batch_elements * out_features), 1, 1),
             output_shapes=[(total_batch_elements, out_features)],
             output_dtypes=[self.dtype],
         )
@@ -130,4 +159,78 @@ def __call__(self, x):
             y = mx.add(y, self.bias)
 
         return y.astype(org_dtype)
-    
\ No newline at end of file
+
+
+class BitLinearFusedAttention(nn.Module):
+    def __init__(self, args, invert_weight_scales: bool = True):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+
+        self.scale = head_dim**-0.5
+
+        if hasattr(args, "attention_bias"):
+            attention_bias = args.attention_bias
+        else:
+            attention_bias = False
+
+        query_pos = n_heads * head_dim
+
+        self.qkv_proj = BitLinear(
+            dim,
+            (n_heads + 2 * n_kv_heads) * head_dim,
+            bias=attention_bias,
+            fused_shapes=[query_pos, query_pos + self.n_kv_heads * self.head_dim],
+            invert_weight_scales=invert_weight_scales
+        )
+        self.o_proj = BitLinear(
+            n_heads * head_dim,
+            dim,
+            bias=attention_bias,
+            invert_weight_scales=invert_weight_scales
+        )
+
+        self.rope = initialize_rope(
+            self.head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        qkv = self.qkv_proj(x)
+        query_pos = self.n_heads * self.head_dim
+        queries, keys, values = mx.split(
+            qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
+        )
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
diff --git a/mlx_lm/models/bitnet.py b/mlx_lm/models/bitnet.py
index a0ada396e..51e67f2e9 100644
--- a/mlx_lm/models/bitnet.py
+++ b/mlx_lm/models/bitnet.py
@@ -48,11 +48,14 @@ def __init__(self, args: ModelArgs):
         self.scale = head_dim**-0.5
 
         attention_bias = getattr(args, "attention_bias", False)
 
+        query_pos = n_heads * head_dim
+
         # Single QKV projection
         self.qkv_proj = BitLinear(
             dim,
             (n_heads + 2 * n_kv_heads) * head_dim,
-            bias=attention_bias
+            bias=attention_bias,
+            fused_shapes=[query_pos, query_pos + self.n_kv_heads * self.head_dim],
         )
         self.o_proj = BitLinear(n_heads * head_dim, dim, bias=attention_bias)
@@ -244,8 +247,8 @@ def sanitize(self, weights):
                 q_scale = weights[f"{prefix}self_attn.q_proj.weight_scale"]
                 k_scale = weights[f"{prefix}self_attn.k_proj.weight_scale"]
                 v_scale = weights[f"{prefix}self_attn.v_proj.weight_scale"]
-                qkv_scale = mx.sqrt((q_scale**2 + k_scale**2 + v_scale**2) / 3)  # Root mean square
-                sanitized[f"{prefix}self_attn.qkv_proj.weight_scale"] = qkv_scale
+                # qkv_scale = mx.sqrt((q_scale**2 + k_scale**2 + v_scale**2) / 3)  # Root mean square
+                sanitized[f"{prefix}self_attn.qkv_proj.weight_scale"] = mx.concatenate([q_scale, k_scale, v_scale], axis=0)
 
                 # Handle biases if they exist
                 if f"{prefix}self_attn.q_proj.bias" in weights:
diff --git a/mlx_lm/models/llama.py b/mlx_lm/models/llama.py
index 39f550c1e..469607375 100644
--- a/mlx_lm/models/llama.py
+++ b/mlx_lm/models/llama.py
@@ -78,6 +78,7 @@ def __call__(
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
+
         if cache is not None:
             queries = self.rope(queries, offset=cache.offset)
             keys = self.rope(keys, offset=cache.offset)
diff --git a/mlx_lm/quant/utils.py b/mlx_lm/quant/utils.py
index 4acbad783..cf91830bf 100644
--- a/mlx_lm/quant/utils.py
+++ b/mlx_lm/quant/utils.py
@@ -32,12 +32,11 @@ def load_data(tokenizer, num_samples: int, sequence_length: int) -> mx.array:
     segments = segments[:num_samples]
     return tokens[segments]
 
-def replace_linear_with_quant_linear(model, quant_method = "bitnet", modules_to_not_convert=None):
+def replace_linear_with_quant_linear(model, quant_method = "bitnet", modules_to_not_convert=None, fuse_qkv=False):
    quantize_layers = []
    for name, module in model.named_modules():
        if modules_to_not_convert is None:
            modules_to_not_convert = []
 
-        # Replace nn.Linear layers, but skip 'lm_head'
         if name not in modules_to_not_convert and isinstance(module, nn.Linear):
             old_weight = module.weight
@@ -47,7 +46,65 @@ def replace_linear_with_quant_linear(model, quant_method = "bitnet", modules_to_
             new_layer = QUANT_LINEAR_MAPPING[quant_method](in_features, out_features, bias=bias, invert_weight_scales=True)
             # Replace the layer in the model
-            quantize_layers.append((name, new_layer))
+            if fuse_qkv and not any(name.endswith(suffix) for suffix in ["q_proj", "k_proj", "v_proj", "o_proj"]):
+                quantize_layers.append((name, new_layer))
+            elif not fuse_qkv:
+                quantize_layers.append((name, new_layer))
+
+        if fuse_qkv and name not in modules_to_not_convert and module.__class__.__name__ == "Attention":
+            # Replace Attention layers with BitLinearFusedAttention
+            from mlx_lm.models.bitlinear_layers import BitLinearFusedAttention
+
+            new_module = BitLinearFusedAttention(model.args, invert_weight_scales=True)
+            quantize_layers.append((name, new_module))
 
     if len(quantize_layers) > 0:
         model.update_modules(tree_unflatten(quantize_layers))
-    return model
\ No newline at end of file
+    return model
+
+def bitnet_sanitze(model, weights):
+    # Remove unused precomputed rotary freqs and handle QKV fusion
+    sanitized = {}
+    processed_layers = set()  # Track which layers we've already processed
+
+    for k, v in weights.items():
+        if "self_attn.rotary_emb.inv_freq" in k:
+            continue
+
+        if "self_attn" in k and ("o_proj" not in k and "attn_sub_norm" not in k):
+            # Extract layer prefix
+            prefix = k.split("self_attn")[0]
+
+            # Only process each layer once
+            if prefix not in processed_layers:
+                processed_layers.add(prefix)
+
+                # Handle QKV fusion for weights
+                if f"{prefix}self_attn.q_proj.weight" in weights:
+                    q = weights[f"{prefix}self_attn.q_proj.weight"]
weights[f"{prefix}self_attn.k_proj.weight"] + v = weights[f"{prefix}self_attn.v_proj.weight"] + # import pdb; pdb.set_trace() + qkv = mx.concatenate([q, k, v], axis=0) + sanitized[f"{prefix}self_attn.qkv_proj.weight"] = qkv + + # Handle weight scales if they exist + if f"{prefix}self_attn.q_proj.weight_scale" in weights: + q_scale = weights[f"{prefix}self_attn.q_proj.weight_scale"] + k_scale = weights[f"{prefix}self_attn.k_proj.weight_scale"] + v_scale = weights[f"{prefix}self_attn.v_proj.weight_scale"] + sanitized[f"{prefix}self_attn.qkv_proj.weight_scale"] = mx.concatenate([q_scale, k_scale, v_scale], axis=0) + + # Handle biases if they exist + if f"{prefix}self_attn.q_proj.bias" in weights: + q_bias = weights[f"{prefix}self_attn.q_proj.bias"] + k_bias = weights[f"{prefix}self_attn.k_proj.bias"] + v_bias = weights[f"{prefix}self_attn.v_proj.bias"] + qkv_bias = mx.concatenate([q_bias, k_bias, v_bias], axis=0) + sanitized[f"{prefix}self_attn.qkv_proj.bias"] = qkv_bias + + # Skip the individual q/k/v components since we've fused them + continue + else: + sanitized[k] = v + + if model.args.tie_word_embeddings: + sanitized.pop("lm_head.weight", None) + return sanitized \ No newline at end of file diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py index 50f1283a8..aa38ef9bf 100644 --- a/mlx_lm/utils.py +++ b/mlx_lm/utils.py @@ -42,7 +42,7 @@ from .tuner.utils import get_total_parameters, load_adapters # Quant imports -from .quant.utils import replace_linear_with_quant_linear +from .quant.utils import replace_linear_with_quant_linear, bitnet_sanitze # Constants MODEL_REMAPPING = { @@ -213,12 +213,16 @@ def class_predicate(p, m): modules_to_not_convert = quantization_config.get("modules_to_not_convert", None) if quant_method is not None and quant_method in SUPPORTED_HF_QUANTIZATIONS: + fuse_qkv = True # Replace linear layers with quantized versions model = replace_linear_with_quant_linear( model, quant_method=quant_method, - modules_to_not_convert=modules_to_not_convert + modules_to_not_convert=modules_to_not_convert, + fuse_qkv=fuse_qkv ) + if fuse_qkv: + weights = bitnet_sanitze(model, weights) model.load_weights(list(weights.items()), strict=strict)