Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
6f0e85b
add bitnet
Blaizzy Apr 17, 2025
2fa462f
update activation to relu2
Blaizzy Apr 17, 2025
0e9628c
working bitnet
Blaizzy Apr 18, 2025
3e7c1a9
remove artifacts
Blaizzy Apr 18, 2025
29151ed
remove logging
Blaizzy Apr 18, 2025
eea94a8
add custom post quant
Blaizzy Apr 18, 2025
ae98be8
fix dtype and add compile
Blaizzy Jun 8, 2025
c137f3c
fixed weight unpack
Blaizzy Jun 8, 2025
eb3846e
add custom kernel to avoid memory overhead
Blaizzy Jun 8, 2025
026b600
compile relu2
Blaizzy Jun 8, 2025
5ed1b1c
fix weight scale
Blaizzy Jun 8, 2025
058c792
remove unused
Blaizzy Jun 8, 2025
cd1783d
Merge branch 'ml-explore:main' into pc/add-bitnet
Blaizzy Jun 8, 2025
5a8e952
add tests and update tuner utils
Blaizzy Jun 8, 2025
e893d85
update acknowledgements
Blaizzy Jun 8, 2025
80e8ce5
add kernel caching
Blaizzy Jun 9, 2025
bd58d3f
add act_quant and set float16 as default dtype
Blaizzy Jun 9, 2025
c89491c
use mx.add and move scaling to kernel
Blaizzy Jun 9, 2025
5d816e8
remove act quant
Blaizzy Jun 9, 2025
1a076ea
move bitlinear layers to separate file
Blaizzy Jun 10, 2025
ec416eb
feat: add falcon-e and other bitnet support
younesbelkada Jun 10, 2025
f3b84e5
refactor: address comments
younesbelkada Jun 10, 2025
a9f257c
Merge pull request #1 from younesbelkada/add-falcon-e
Blaizzy Jun 10, 2025
fb40d51
add support for 1.58bit N-bit quants
Blaizzy Jun 10, 2025
3aaba20
43.85% speedup in generation performance (M3 max)
Blaizzy Jun 12, 2025
3a5e4f9
refactor utils
Blaizzy Jun 12, 2025
3a8136f
remove masking (2% gen speed improvement)
Blaizzy Jun 12, 2025
ae5a6a0
add quantization config
Blaizzy Jun 12, 2025
be68e46
test llama bitnet
Blaizzy Jun 12, 2025
3d5422b
refactor apply_hf_quant
Blaizzy Jun 12, 2025
feae07b
default threadgroup: 64 -> 32
Blaizzy Jun 12, 2025
3812c27
add comment
Blaizzy Jun 12, 2025
9207a39
fix prompt processing perf
Blaizzy Jun 12, 2025
a0b4026
remove modulo
Blaizzy Jun 12, 2025
4fab6fc
compile kernel in the constructor
Blaizzy Jun 13, 2025
239072e
add fused kernel
Blaizzy Jun 15, 2025
8f20baa
rename
Blaizzy Jun 15, 2025
9a8b9e8
refactor
Blaizzy Jun 15, 2025
0391d22
Increase lanes from 4 to 8
Blaizzy Jun 15, 2025
c007c1f
feat: add fused QKV for other bitnet models
younesbelkada Jun 15, 2025
2180426
refactor compiled kernel
Blaizzy Jun 15, 2025
f21ec19
address all comments
younesbelkada Jun 15, 2025
02ee4f4
Merge pull request #6 from younesbelkada/add-fused-falcon-e
Blaizzy Jun 21, 2025
594f4eb
Improve the bitnet kernel
angeloskath Jun 26, 2025
4e7a8e0
remove benchmark
Blaizzy Jun 26, 2025
5700315
refactor bitlinear swap
Blaizzy Jun 26, 2025
2ffcb79
format
Blaizzy Jun 26, 2025
c31f737
Merge branch 'pc/add-bitnet' of https://github.com/Blaizzy/mlx-lm int…
Blaizzy Jul 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ACKNOWLEDGMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ with a short description of your contribution(s) below. For example:
MLX LM was developed with contributions from the following individuals:

- Shunta Saito: Added support for PLaMo models.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Prince Canuma: Helped add support for the following model architectures: HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen (2, 3 and MoE)`, Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama (3 and 4)`, Google DeepMind's `Gemma 3`, and InterLM's `InternLM 2.5`.
- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's`Mamba v1`, Z.ai & THUKEG's `GLM4`, and Allenai's `OLMoE`; Added support for the following training algorithms: `full-fine-tuning`; Added support for the following other features: `Multiple Optimizers to choose for training`, and `reporting training metrics to WandB (Weights & Biases)`.
175 changes: 175 additions & 0 deletions mlx_lm/models/bitlinear_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from typing import Optional, Any

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.quantized import QuantizedLinear

from .base import scaled_dot_product_attention
from .rope_utils import initialize_rope

class BitLinear(nn.Module):
"""
BitLinear module with memory-efficient weight handling.
"""

def __init__(
self,
in_features,
out_features,
bias=True,
dtype=mx.float16,
invert_weight_scales=False,
):
super().__init__()
self.dtype = dtype
self.in_features = in_features
self.out_features = out_features
# Calculate packed dimensions - the first dimension gets packed 4:1
# The weights are ternary so can be represented with 2 bits,
# and they are packed in uint8 tensors, hence the number of values per item is 4
packed_out_features = (out_features + 3) // 4
self.weight = mx.zeros((packed_out_features, in_features), dtype=mx.uint8)

self.invert_weight_scales = invert_weight_scales
self.fused_qkv = fused_qkv

shape = [1.0, 1.0, 1.0] if fused_qkv else [1.0]
self.weight_scale = mx.array(shape, dtype=dtype)

if bias:
self.bias = mx.zeros((out_features,), dtype=dtype)
else:
self.bias = None

self._compiled_kernel = self._compile_qkv_kernel() if fused_qkv else self._compile_matmul_kernel()


def _compile_matmul_kernel(self):
"""
Custom Metal kernel that performs matrix multiplication directly on packed weights and scales the output.
This eliminates the need to store unpacked weights in memory.
"""
source = """
constexpr int M = 4;
constexpr int BLOCK = 32;

uint tid = thread_position_in_grid.y;
uint in_offset = thread_position_in_grid.x;

uint batch_idx = tid / (out_features / 4);
uint row_idx = tid % (out_features / 4);

float sum[4] = {0.0};

for (uint i = in_offset * M; i < in_features; i += BLOCK * M) {
float v[M];
for (int j=0; j<M; j++) {
v[j] = x[batch_idx * in_features + i + j];
}

for (int j=0; j<M; j++) {
uint8_t w = packed_weights[row_idx * in_features + i + j];
sum[0] += v[j] * ((w & 3) - 1);
sum[1] += v[j] * (((w >> 2) & 3) - 1);
sum[2] += v[j] * (((w >> 4) & 3) - 1);
sum[3] += v[j] * (((w >> 6) & 3) - 1);
}
}

for (int j=0; j<4; j++) {
sum[j] = simd_sum(sum[j]);
}


// Apply weight scaling by diving them or multiplying them
if (in_offset == 0) {
float scale = invert_weight_scales ? 1 / weight_scale[0] : weight_scale[0];
for (int i=0; i<4; i++) {
out[batch_idx * out_features + row_idx + i * (out_features/4)] = sum[i] * scale;
}
}
"""
return mx.fast.metal_kernel(
name="bitlinear_matmul_8unroll",
input_names=["x", "packed_weights", "weight_scale", "invert_weight_scales"],
output_names=["out"],
source=source,
)

def execute_matmul_kernel(self, x, packed_weights):
# Handle multi-dimensional inputs by flattening all but the last dimension
original_shape = x.shape
if len(original_shape) > 2:
# Flatten to (total_batch_elements, in_features)
x_flattened = x.reshape(-1, original_shape[-1])
total_batch_elements = x_flattened.shape[0]
in_features = x_flattened.shape[1]
else:
x_flattened = x
total_batch_elements, in_features = x_flattened.shape

out_features = self.out_features

outputs = self._compiled_kernel(
inputs=[
x_flattened.astype(self.dtype),
packed_weights,
self.weight_scale,
self.invert_weight_scales,
],
template=[
("batch_size", total_batch_elements),
("in_features", in_features),
("out_features", out_features),
],
grid=(32, total_batch_elements * out_features // 4, 1),
threadgroup=(32, 1, 1), # SIMD width is 32 threads
output_shapes=[(total_batch_elements, out_features)],
output_dtypes=[self.dtype],
)

# Reshape output back to match input shape but with out_features as last dimension
if len(original_shape) > 2:
output_shape = original_shape[:-1] + (out_features,)
return outputs[0].reshape(output_shape)
else:
return outputs[0]

def __call__(self, x):
"""
Forward pass with weight scaling applied correctly.
"""
org_dtype = x.dtype

# Use custom kernel for matrix multiplication directly on packed weights
y = self.execute_matmul_kernel(x, self.weight)

# Add bias if present
if self.bias is not None:
y = mx.add(y, self.bias)
return y.astype(org_dtype)


class QuantAndBitLinear(nn.Linear):
"""
A Linear layer that can be converted to a quantized and bitlinear version.
"""

def to_quantized(
self, method: str = None, group_size: int = 64, bits: int = 4, **kwargs
):

if method is None or group_size is None or bits is None:
return QuantizedLinear.from_linear(self, group_size, bits)

if method == "bitnet":
bitlinear = BitLinear(
in_features=self.weight.shape[1],
out_features=self.weight.shape[0],
bias=getattr(self, "bias", None) is not None,
invert_weight_scales=True,
**kwargs,
)
return bitlinear
else:
raise ValueError(f"Unknown quantization method: {method}")
Loading