mlx_lm/models/bitlinear_layers.py (115 additions, 12 deletions)
@@ -1,12 +1,16 @@
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, scaled_dot_product_attention
from .rope_utils import initialize_rope

class BitLinear(nn.Module):
"""
BitLinear module with memory-efficient weight handling.
"""
def __init__(self, in_features, out_features, bias=True, dtype=mx.float16, invert_weight_scales = False):
def __init__(self, in_features, out_features, bias=True, dtype=mx.float16, invert_weight_scales = False, fused_shapes = None):
super().__init__()
self.dtype = dtype
self.in_features = in_features
@@ -18,7 +22,14 @@ def __init__(self, in_features, out_features, bias=True, dtype=mx.float16, inver
self.weight = mx.zeros((packed_out_features, in_features), dtype=mx.uint8)

self.invert_weight_scales = invert_weight_scales
self.weight_scale = mx.array([1.0], dtype=dtype)
self.fused_shapes = fused_shapes

if fused_shapes is None:
self.weight_scale = mx.array([1.0], dtype=dtype)
else:
self.weight_scale = mx.array([1.0] * (len(fused_shapes) + 1), dtype=dtype)

self.fused_layers = fused_shapes is not None

if bias:
self.bias = mx.zeros((out_features,), dtype=dtype)
@@ -47,14 +58,14 @@ def bitlinear_kernel(self, x, packed_weights):
// Calculate packed dimensions
uint packed_rows = out_features / 4; // Each packed row contains 4 output rows

// Determine which packed row and which bit position within that packed value
uint which_slice = out_idx / packed_rows; // Which of the 4 slices (0, 1, 2, 3)
uint row_in_slice = out_idx % packed_rows; // Which row within that slice

for (uint i = 0; i < in_features; i++) {
// Get input value
float x_val = x[batch_idx * in_features + i];

// Determine which packed row and which bit position within that packed value
uint which_slice = out_idx / packed_rows; // Which of the 4 slices (0, 1, 2, 3)
uint row_in_slice = out_idx % packed_rows; // Which row within that slice

// Get the packed weight value
uint packed_idx = row_in_slice * in_features + i;
uint8_t packed_val = packed_weights[packed_idx];
@@ -69,11 +80,23 @@
sum += x_val * weight_val;
}

uint weight_layer_idx = 0; // Single layer case without any fusing

// Apply the weight scale by dividing or multiplying
if (fused_length > 0) {
// Determine the scale index from the interval that out_idx falls into
for (uint i = 0; i < fused_length; i++) {
if (out_idx < fused_shapes[i]) {
weight_layer_idx = i;
break;
}
}
}

if (invert_weight_scales) {
out[tid] = sum / weight_scale[0];
out[tid] = sum / weight_scale[weight_layer_idx];
} else {
out[tid] = sum * weight_scale[0];
out[tid] = sum * weight_scale[weight_layer_idx];
}
"""

@@ -94,16 +117,22 @@ def bitlinear_kernel(self, x, packed_weights):
if self._compiled_kernel is None:
self._compiled_kernel = mx.fast.metal_kernel(
name="bitlinear_matmul",
input_names=["x", "packed_weights", "weight_scale", "invert_weight_scales"],
input_names=["x", "packed_weights", "weight_scale", "invert_weight_scales", "fused_shapes", "fused_length"],
output_names=["out"],
source=source,
)

if self.fused_layers:
# Build the boundary list on a copy; appending to self.fused_shapes here would grow it on every forward call
fused_shapes = self.fused_shapes + [out_features]
inputs = [x_flattened.astype(self.dtype), packed_weights, self.weight_scale, self.invert_weight_scales, mx.array(fused_shapes), len(fused_shapes)]
else:
inputs = [x_flattened.astype(self.dtype), packed_weights, self.weight_scale, self.invert_weight_scales, mx.array([]), 0]

outputs = self._compiled_kernel(
inputs=[x_flattened.astype(self.dtype), packed_weights, self.weight_scale, self.invert_weight_scales],
inputs=inputs,
template=[("batch_size", total_batch_elements), ("in_features", in_features), ("out_features", out_features)],
grid=(total_batch_elements * out_features, 1, 1),
threadgroup=(min(256, total_batch_elements * out_features), 1, 1),
threadgroup=(min(32, total_batch_elements * out_features), 1, 1),
output_shapes=[(total_batch_elements, out_features)],
output_dtypes=[self.dtype],
)
@@ -130,4 +159,78 @@ def __call__(self, x):
y = mx.add(y, self.bias)

return y.astype(org_dtype)
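
For readers following the kernel above, a minimal Python sketch of the addressing scheme it implements: each uint8 in the packed matrix holds four 2-bit weights, one per output slice. The slice/row arithmetic mirrors the kernel source; the 2-bit decode (codes {0, 1, 2} mapping to weights {-1, 0, 1}) is an assumption, since the decoding lines are collapsed in the hunk above.

import mlx.core as mx

def unpack_bitlinear_weights(packed, out_features, in_features):
    # Each packed row carries 4 output rows: output row s * packed_rows + r
    # lives in the 2-bit field for slice s of packed row r.
    packed_rows = out_features // 4
    assert packed.shape == (packed_rows, in_features)
    rows = []
    for out_idx in range(out_features):
        which_slice = out_idx // packed_rows   # which of the 4 slices (0, 1, 2, 3)
        row_in_slice = out_idx % packed_rows   # which row within that slice
        # Equivalent to (val >> (2 * which_slice)) & 3, without bitwise ops
        codes = (packed[row_in_slice] // (4 ** which_slice)) % 4
        rows.append(codes.astype(mx.float32) - 1.0)  # assumed map: {0, 1, 2} -> {-1, 0, 1}
    return mx.stack(rows)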



class BitLinearFusedAttention(nn.Module):
def __init__(self, args, invert_weight_scales: bool = True):
super().__init__()

dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads

self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads

self.scale = head_dim**-0.5

attention_bias = getattr(args, "attention_bias", False)

query_pos = n_heads * head_dim

self.qkv_proj = BitLinear(
dim,
(n_heads + 2 * n_kv_heads) * head_dim,
bias=attention_bias,
fused_shapes=[query_pos, query_pos + self.n_kv_heads * self.head_dim],
invert_weight_scales=invert_weight_scales
)
self.o_proj = BitLinear(
n_heads * head_dim,
dim,
bias=attention_bias,
invert_weight_scales=invert_weight_scales
)

self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape

qkv = self.qkv_proj(x)
query_pos = self.n_heads * self.head_dim
queries, keys, values = mx.split(
qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)

output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)

output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
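
As a sanity check on the fused_shapes convention, a Python mirror of the kernel's scale lookup, with hypothetical sizes n_heads=8, n_kv_heads=2, head_dim=64. The kernel appends out_features as the final boundary, so query rows pick the first scale, key rows the second, and value rows the third.

def scale_index(out_idx, boundaries):
    # Mirrors the kernel loop: return the first interval that out_idx falls below.
    for i, boundary in enumerate(boundaries):
        if out_idx < boundary:
            return i
    raise ValueError("out_idx must be below the final boundary (out_features)")

query_pos = 8 * 64                                  # n_heads * head_dim = 512
boundaries = [query_pos, query_pos + 2 * 64, 768]   # fused_shapes + [out_features]
assert scale_index(0, boundaries) == 0              # query rows -> q_proj scale
assert scale_index(512, boundaries) == 1            # key rows   -> k_proj scale
assert scale_index(767, boundaries) == 2            # value rows -> v_proj scale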
mlx_lm/models/bitnet.py (6 additions, 3 deletions)
@@ -48,11 +48,14 @@ def __init__(self, args: ModelArgs):
self.scale = head_dim**-0.5
attention_bias = getattr(args, "attention_bias", False)

query_pos = n_heads * head_dim

# Single QKV projection
self.qkv_proj = BitLinear(
dim,
(n_heads + 2 * n_kv_heads) * head_dim,
bias=attention_bias
bias=attention_bias,
fused_shapes=[query_pos, query_pos + self.n_kv_heads * self.head_dim],
)
self.o_proj = BitLinear(n_heads * head_dim, dim, bias=attention_bias)

@@ -244,8 +247,8 @@ def sanitize(self, weights):
q_scale = weights[f"{prefix}self_attn.q_proj.weight_scale"]
k_scale = weights[f"{prefix}self_attn.k_proj.weight_scale"]
v_scale = weights[f"{prefix}self_attn.v_proj.weight_scale"]
qkv_scale = mx.sqrt((q_scale**2 + k_scale**2 + v_scale**2) / 3) # Root mean square
sanitized[f"{prefix}self_attn.qkv_proj.weight_scale"] = qkv_scale
# qkv_scale = mx.sqrt((q_scale**2 + k_scale**2 + v_scale**2) / 3) # Root mean square
sanitized[f"{prefix}self_attn.qkv_proj.weight_scale"] = mx.concatenate([q_scale, k_scale, v_scale], axis=0)

# Handle biases if they exist
if f"{prefix}self_attn.q_proj.bias" in weights:
mlx_lm/models/llama.py (1 addition, 0 deletions)
@@ -78,6 +78,7 @@ def __call__(
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)


if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
mlx_lm/quant/utils.py (62 additions, 4 deletions)
@@ -32,12 +32,11 @@ def load_data(tokenizer, num_samples: int, sequence_length: int) -> mx.array:
segments = segments[:num_samples]
return tokens[segments]

def replace_linear_with_quant_linear(model, quant_method = "bitnet", modules_to_not_convert=None):
def replace_linear_with_quant_linear(model, quant_method = "bitnet", modules_to_not_convert=None, fuse_qkv=False):
quantize_layers = []
for name, module in model.named_modules():
if modules_to_not_convert is None:
modules_to_not_convert = []

# Replace nn.Linear layers, but skip 'lm_head'
if name not in modules_to_not_convert and isinstance(module, nn.Linear):
old_weight = module.weight
@@ -47,7 +46,66 @@ def replace_linear_with_quant_linear(model, quant_method = "bitnet", modules_to_
new_layer = QUANT_LINEAR_MAPPING[quant_method](in_features, out_features, bias=bias, invert_weight_scales=True)

# Replace the layer in the model
quantize_layers.append((name, new_layer))
if fuse_qkv and not any(name.endswith(suffix) for suffix in ["q_proj", "k_proj", "v_proj", "o_proj"]):
quantize_layers.append((name, new_layer))
elif not fuse_qkv:
quantize_layers.append((name, new_layer))
if fuse_qkv and name not in modules_to_not_convert and module.__class__.__name__ == "Attention":
# Replace Attention layers with BitLinearFusedAttention
from mlx_lm.models.bitlinear_layers import BitLinearFusedAttention

new_module = BitLinearFusedAttention(model.args, invert_weight_scales=True)
quantize_layers.append((name, new_module))
if len(quantize_layers) > 0:
model.update_modules(tree_unflatten(quantize_layers))
return model
return model
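
A note on the replacement mechanics above: update_modules consumes a nested mapping, so the flat (dotted-name, module) pairs are unflattened first. A self-contained toy sketch of that mechanic (the Toy module is hypothetical):

import mlx.nn as nn
from mlx.utils import tree_unflatten

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

model = Toy()
# Dotted names become nested keys, which update_modules swaps in place.
model.update_modules(tree_unflatten([("proj", nn.Linear(4, 8))]))
print(model.proj)  # now the 4 -> 8 layer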

def bitnet_sanitze(model, weights):
# Remove unused precomputed rotary freqs and handle QKV fusion
sanitized = {}
processed_layers = set() # Track which layers we've already processed

for k, v in weights.items():
if "self_attn.rotary_emb.inv_freq" in k:
continue

if "self_attn" in k and ("o_proj" not in k and "attn_sub_norm" not in k):
# Extract layer prefix
prefix = k.split("self_attn")[0]

# Only process each layer once
if prefix not in processed_layers:
processed_layers.add(prefix)

# Handle QKV fusion for weights
if f"{prefix}self_attn.q_proj.weight" in weights:
q = weights[f"{prefix}self_attn.q_proj.weight"]
k = weights[f"{prefix}self_attn.k_proj.weight"]
v = weights[f"{prefix}self_attn.v_proj.weight"]
# import pdb; pdb.set_trace()
qkv = mx.concatenate([q, k, v], axis=0)
sanitized[f"{prefix}self_attn.qkv_proj.weight"] = qkv

# Handle weight scales if they exist
if f"{prefix}self_attn.q_proj.weight_scale" in weights:
q_scale = weights[f"{prefix}self_attn.q_proj.weight_scale"]
k_scale = weights[f"{prefix}self_attn.k_proj.weight_scale"]
v_scale = weights[f"{prefix}self_attn.v_proj.weight_scale"]
sanitized[f"{prefix}self_attn.qkv_proj.weight_scale"] = mx.concatenate([q_scale, k_scale, v_scale], axis=0)

# Handle biases if they exist
if f"{prefix}self_attn.q_proj.bias" in weights:
q_bias = weights[f"{prefix}self_attn.q_proj.bias"]
k_bias = weights[f"{prefix}self_attn.k_proj.bias"]
v_bias = weights[f"{prefix}self_attn.v_proj.bias"]
qkv_bias = mx.concatenate([q_bias, k_bias, v_bias], axis=0)
sanitized[f"{prefix}self_attn.qkv_proj.bias"] = qkv_bias

# Skip the individual q/k/v components since we've fused them
continue
else:
sanitized[k] = v

if model.args.tie_word_embeddings:
sanitized.pop("lm_head.weight", None)
return sanitized
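
To make the fused weight_scale layout concrete, a toy sketch (shapes hypothetical) of the concatenation bitnet_sanitze performs: rows stay in q, k, v order, so the three scales line up with the fused_shapes boundaries the kernel searches.

import mlx.core as mx

# Hypothetical tiny projections: 4 query rows, 2 key rows, 2 value rows, 8 inputs.
q_w, k_w, v_w = mx.zeros((4, 8)), mx.zeros((2, 8)), mx.zeros((2, 8))
q_s, k_s, v_s = mx.array([0.5]), mx.array([0.7]), mx.array([0.9])

qkv_w = mx.concatenate([q_w, k_w, v_w], axis=0)  # (8, 8): rows in q, k, v order
qkv_s = mx.concatenate([q_s, k_s, v_s], axis=0)  # (3,): one scale per segment

# With fused_shapes = [4, 6] (and out_features 8 appended at call time),
# rows 0-3 select qkv_s[0], rows 4-5 select qkv_s[1], rows 6-7 select qkv_s[2].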
mlx_lm/utils.py (6 additions, 2 deletions)
@@ -42,7 +42,7 @@
from .tuner.utils import get_total_parameters, load_adapters

# Quant imports
from .quant.utils import replace_linear_with_quant_linear
from .quant.utils import replace_linear_with_quant_linear, bitnet_sanitze

# Constants
MODEL_REMAPPING = {
@@ -213,12 +216,16 @@ def class_predicate(p, m):
modules_to_not_convert = quantization_config.get("modules_to_not_convert", None)

if quant_method is not None and quant_method in SUPPORTED_HF_QUANTIZATIONS:
fuse_qkv = True
# Replace linear layers with quantized versions
model = replace_linear_with_quant_linear(
model,
quant_method=quant_method,
modules_to_not_convert=modules_to_not_convert
modules_to_not_convert=modules_to_not_convert,
fuse_qkv=fuse_qkv
)
if fuse_qkv:
weights = bitnet_sanitze(model, weights)

model.load_weights(list(weights.items()), strict=strict)

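
End to end, the load path above means a BitNet checkpoint picks up the fused QKV attention automatically. A hedged usage sketch (the repo id is an example, not pinned by this diff):

from mlx_lm import load, generate

model, tokenizer = load("microsoft/bitnet-b1.58-2B-4T")  # example BitNet repo id
print(generate(model, tokenizer, prompt="The capital of France is", max_tokens=16))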