diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index b9686e733..c724256c2 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -8,5 +8,5 @@ with a short description of your contribution(s) below. For example: MLX LM was developed with contributions from the following individuals: - Shunta Saito: Added support for PLaMo models. -- Prince Canuma: Helped add support for `Starcoder2` models. +- Prince Canuma: Helped add support for the following model architectures: HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen (2, 3 and MoE)`, Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama (3 and 4)`, Google DeepMind's `Gemma 3`, and InterLM's `InternLM 2.5`. - Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's`Mamba v1`, Z.ai & THUKEG's `GLM4`, and Allenai's `OLMoE`; Added support for the following training algorithms: `full-fine-tuning`; Added support for the following other features: `Multiple Optimizers to choose for training`, and `reporting training metrics to WandB (Weights & Biases)`. diff --git a/mlx_lm/models/bitlinear_layers.py b/mlx_lm/models/bitlinear_layers.py new file mode 100644 index 000000000..9d2f902fa --- /dev/null +++ b/mlx_lm/models/bitlinear_layers.py @@ -0,0 +1,274 @@ +import mlx.core as mx +import mlx.nn as nn + + +class BitLinear(nn.Module): + """ + BitLinear module with memory-efficient weight handling. + """ + def __init__(self, in_features, out_features, bias=True, dtype=mx.float16, invert_weight_scales = False, fuse_qkv = False): + super().__init__() + self.dtype = dtype + self.in_features = in_features + self.out_features = out_features + self.fuse_qkv = fuse_qkv + + # Calculate packed dimensions - the first dimension gets packed 4:1 + # The weights are ternary so can be represented with 2 bits, + # and they are packed in uint8 tensors, hence the number of values per item is 4 + packed_out_features = (out_features + 3) // 4 + self.weight = mx.zeros((packed_out_features, in_features), dtype=mx.uint8) + + self.invert_weight_scales = invert_weight_scales + + if fuse_qkv: + self.weight_scale = mx.ones((3,), dtype=dtype) + else: + self.weight_scale = mx.array([1.0], dtype=dtype) + + if bias: + self.bias = mx.zeros((out_features,), dtype=dtype) + else: + self.bias = None + + # Add kernel caches + self._compiled_kernel = None + self._compiled_qkv_kernel = None + + def bitlinear_kernel(self, x, packed_weights, out_features=None, scale=None): + """ + Custom Metal kernel that performs matrix multiplication directly on packed weights and scales the output. + This eliminates the need to store unpacked weights in memory. + """ + source = """ + uint tid = thread_position_in_grid.x; + uint total_elements = batch_size * out_features; + + if (tid >= total_elements) return; + + uint batch_idx = tid / out_features; + uint out_idx = tid % out_features; + + float sum = 0.0; + + // Calculate packed dimensions + uint packed_rows = out_features / 4; // Each packed row contains 4 output rows + + for (uint i = 0; i < in_features; i++) { + // Get input value + float x_val = x[batch_idx * in_features + i]; + + // Determine which packed row and which bit position within that packed value + uint which_slice = out_idx / packed_rows; // Which of the 4 slices (0, 1, 2, 3) + uint row_in_slice = out_idx % packed_rows; // Which row within that slice + + // Get the packed weight value + uint packed_idx = row_in_slice * in_features + i; + uint8_t packed_val = packed_weights[packed_idx]; + + // Extract the 2-bit value for this slice + uint8_t mask = 3 << (2 * which_slice); // 0b11 shifted to the right position + uint8_t weight_bits = (packed_val & mask) >> (2 * which_slice); + + // Convert from {0,1,2} back to {-1,0,1} + float weight_val = float(weight_bits) - 1.0; + + sum += x_val * weight_val; + } + + // Apply weight scaling by diving them or multiplying them + if (invert_weight_scales) { + out[tid] = sum / weight_scale[0]; + } else { + out[tid] = sum * weight_scale[0]; + } + """ + + # Handle multi-dimensional inputs by flattening all but the last dimension + original_shape = x.shape + if len(original_shape) > 2: + # Flatten to (total_batch_elements, in_features) + x_flattened = x.reshape(-1, original_shape[-1]) + total_batch_elements = x_flattened.shape[0] + in_features = x_flattened.shape[1] + else: + x_flattened = x + total_batch_elements, in_features = x_flattened.shape + + out_features = out_features if self.fuse_qkv else self.out_features + + # Compile kernel once and cache it + if self._compiled_kernel is None: + self._compiled_kernel = mx.fast.metal_kernel( + name="bitlinear_matmul", + input_names=["x", "packed_weights", "weight_scale", "invert_weight_scales"], + output_names=["out"], + source=source, + ) + + outputs = self._compiled_kernel( + inputs=[x_flattened.astype(self.dtype), packed_weights, scale, self.invert_weight_scales], + template=[("batch_size", total_batch_elements), ("in_features", in_features), ("out_features", out_features)], + grid=(total_batch_elements * out_features, 1, 1), + threadgroup=(min(32, total_batch_elements * out_features), 1, 1), + output_shapes=[(total_batch_elements, out_features)], + output_dtypes=[self.dtype], + ) + + # Reshape output back to match input shape but with out_features as last dimension + if len(original_shape) > 2: + output_shape = original_shape[:-1] + (out_features,) + return outputs[0].reshape(output_shape) + else: + return outputs[0] + + def bitlinear_fused_qkv_kernel(self, x, packed_weights, scales): + """ + Custom Metal kernel that performs fused QKV computation in parallel. + Handles Q (2560), K (640), V (640) outputs with their respective scales. + """ + source = """ + uint tid = thread_position_in_grid.x; + uint total_elements = batch_size * total_out_features; + + if (tid >= total_elements) return; + + uint batch_idx = tid / total_out_features; + uint out_idx = tid % total_out_features; + + // Determine which component (Q, K, or V) this thread is computing + uint component; + uint local_out_idx; + uint component_start_idx; + uint component_out_features; + uint weight_slice_start; + + if (out_idx < q_features) { + // Q component + component = 0; + local_out_idx = out_idx; + component_start_idx = 0; + component_out_features = q_features; + weight_slice_start = 0; + } else if (out_idx < q_features + k_features) { + // K component + component = 1; + local_out_idx = out_idx - q_features; + component_start_idx = q_features; + component_out_features = k_features; + weight_slice_start = q_packed_rows; + } else { + // V component + component = 2; + local_out_idx = out_idx - q_features - k_features; + component_start_idx = q_features + k_features; + component_out_features = v_features; + weight_slice_start = q_packed_rows + k_packed_rows; + } + + float sum = 0.0; + + // Calculate packed dimensions for this component + uint component_packed_rows = (component_out_features + 3) / 4; + + for (uint i = 0; i < in_features; i++) { + // Get input value + float x_val = x[batch_idx * in_features + i]; + + // Determine which packed row and bit position within that packed value + uint which_slice = local_out_idx / component_packed_rows; + uint row_in_slice = local_out_idx % component_packed_rows; + + // Calculate the actual packed weight index for this component + uint packed_idx = (weight_slice_start + row_in_slice) * in_features + i; + uint8_t packed_val = packed_weights[packed_idx]; + + // Extract the 2-bit value for this slice + uint8_t mask = 3 << (2 * which_slice); // 0b11 shifted to the right position + uint8_t weight_bits = (packed_val & mask) >> (2 * which_slice); + + // Convert from {0,1,2} back to {-1,0,1} + float weight_val = float(weight_bits) - 1.0; + + sum += x_val * weight_val; + } + + // Apply component-specific weight scaling + if (invert_weight_scales) { + out[tid] = sum / scales[component]; + } else { + out[tid] = sum * scales[component]; + } + """ + + # Handle multi-dimensional inputs by flattening all but the last dimension + original_shape = x.shape + if len(original_shape) > 2: + x_flattened = x.reshape(-1, original_shape[-1]) + total_batch_elements = x_flattened.shape[0] + in_features = x_flattened.shape[1] + else: + x_flattened = x + total_batch_elements, in_features = x_flattened.shape + + # QKV dimensions (based on your split points [640, 800]) + q_features = 2560 + k_features = 640 + v_features = 640 + total_out_features = q_features + k_features + v_features + + # Calculate packed row counts for weight indexing + q_packed_rows = (q_features + 3) // 4 + k_packed_rows = (k_features + 3) // 4 + + # Compile kernel once and cache it + if self._compiled_qkv_kernel is None: + self._compiled_qkv_kernel = mx.fast.metal_kernel( + name="bitlinear_fused_qkv", + input_names=["x", "packed_weights", "scales", "invert_weight_scales"], + output_names=["out"], + source=source, + ) + + outputs = self._compiled_qkv_kernel( + inputs=[x_flattened.astype(self.dtype), packed_weights, scales, self.invert_weight_scales], + template=[ + ("batch_size", total_batch_elements), + ("in_features", in_features), + ("total_out_features", total_out_features), + ("q_features", q_features), + ("k_features", k_features), + ("v_features", v_features), + ("q_packed_rows", q_packed_rows), + ("k_packed_rows", k_packed_rows) + ], + grid=(total_batch_elements * total_out_features, 1, 1), + threadgroup=(min(32, total_batch_elements * total_out_features), 1, 1), + output_shapes=[(total_batch_elements, total_out_features)], + output_dtypes=[self.dtype], + ) + + # Reshape output back to match input shape but with total_out_features as last dimension + if len(original_shape) > 2: + output_shape = original_shape[:-1] + (total_out_features,) + return outputs[0].reshape(output_shape) + else: + return outputs[0] + + def __call__(self, x): + """ + Forward pass with weight scaling applied correctly. + """ + org_dtype = x.dtype + + # Choose the appropriate kernel based on whether this is fused QKV + if self.fuse_qkv: + y = self.bitlinear_fused_qkv_kernel(x, self.weight, self.weight_scale) + else: + y = self.bitlinear_kernel(x, self.weight, scale=self.weight_scale) + + # Add bias if present + if self.bias is not None: + y = mx.add(y, self.bias) + + return y.astype(org_dtype) \ No newline at end of file diff --git a/mlx_lm/models/bitnet.py b/mlx_lm/models/bitnet.py new file mode 100644 index 000000000..779ba8d41 --- /dev/null +++ b/mlx_lm/models/bitnet.py @@ -0,0 +1,270 @@ +# Copyright 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn +from functools import partial +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope +from .bitlinear_layers import BitLinear + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.n_rep = n_heads // n_kv_heads + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + attention_bias = getattr(args, "attention_bias", False) + + # Single QKV projection + self.qkv_proj = BitLinear( + dim, + (n_heads + 2 * n_kv_heads) * head_dim, + bias=attention_bias, + fuse_qkv=True + ) + self.o_proj = BitLinear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + self.attn_sub_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + # Fused QKV projection + qkv = self.qkv_proj(x) + query_pos = self.n_heads * self.head_dim + queries, keys, values = mx.split( + qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1 + ) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + output = self.attn_sub_norm(output) + output = self.o_proj(output) + + return output + +@partial(mx.compile, shapeless=True) +def relu2(x): + return mx.square(nn.relu(x)) + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = BitLinear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = BitLinear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = BitLinear(dim, hidden_dim, bias=mlp_bias) + self.ffn_sub_norm = nn.RMSNorm(args.intermediate_size, eps=args.rms_norm_eps) + + def __call__(self, x) -> mx.array: + x = relu2(self.gate_proj(x)) * self.up_proj(x) + x = self.ffn_sub_norm(x) + x = self.down_proj(x) + return x + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.embed_tokens(inputs) + + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.model(inputs, mask, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + # Remove unused precomputed rotary freqs and handle QKV fusion + sanitized_weights = {} + processed_layers = set() # Track which layers we've already processed + + for k, v in weights.items(): + if "self_attn.rotary_emb.inv_freq" in k: + continue + + if "self_attn" in k and ("o_proj" not in k and "attn_sub_norm" not in k): + # Extract layer prefix + prefix = k.split("self_attn")[0] + + # Only process each layer once + if prefix not in processed_layers: + processed_layers.add(prefix) + + # Handle QKV fusion for weights + if f"{prefix}self_attn.q_proj.weight" in weights: + q = weights[f"{prefix}self_attn.q_proj.weight"] + k = weights[f"{prefix}self_attn.k_proj.weight"] + v = weights[f"{prefix}self_attn.v_proj.weight"] + qkv = mx.concatenate([q, k, v], axis=0) + sanitized_weights[f"{prefix}self_attn.qkv_proj.weight"] = qkv + + # Handle weight scales if they exist + if f"{prefix}self_attn.q_proj.weight_scale" in weights: + q_scale = weights[f"{prefix}self_attn.q_proj.weight_scale"] + k_scale = weights[f"{prefix}self_attn.k_proj.weight_scale"] + v_scale = weights[f"{prefix}self_attn.v_proj.weight_scale"] + + sanitized_weights[f"{prefix}self_attn.qkv_proj.weight_scale"] = mx.concatenate([q_scale, k_scale, v_scale], axis=0) + + + # Handle biases if they exist + if f"{prefix}self_attn.q_proj.bias" in weights: + q_bias = weights[f"{prefix}self_attn.q_proj.bias"] + k_bias = weights[f"{prefix}self_attn.k_proj.bias"] + v_bias = weights[f"{prefix}self_attn.v_proj.bias"] + qkv_bias = mx.concatenate([q_bias, k_bias, v_bias], axis=0) + sanitized_weights[f"{prefix}self_attn.qkv_proj.bias"] = qkv_bias + + # Skip the individual q/k/v components since we've fused them + continue + else: + sanitized_weights[k] = v + + if self.args.tie_word_embeddings: + sanitized_weights.pop("lm_head.weight", None) + return sanitized_weights + @property + def layers(self): + return self.model.layers diff --git a/mlx_lm/quant/utils.py b/mlx_lm/quant/utils.py index dc3801f06..4acbad783 100644 --- a/mlx_lm/quant/utils.py +++ b/mlx_lm/quant/utils.py @@ -3,7 +3,14 @@ from pathlib import Path import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from ..models.bitlinear_layers import BitLinear + +QUANT_LINEAR_MAPPING = { + 'bitnet': BitLinear, +} def load_data(tokenizer, num_samples: int, sequence_length: int) -> mx.array: save_dir = Path.home() / ".cache/mlx-lm/calibration_v5.txt" @@ -24,3 +31,23 @@ def load_data(tokenizer, num_samples: int, sequence_length: int) -> mx.array: if num_samples > 0: segments = segments[:num_samples] return tokens[segments] + +def replace_linear_with_quant_linear(model, quant_method = "bitnet", modules_to_not_convert=None): + quantize_layers = [] + for name, module in model.named_modules(): + if modules_to_not_convert is None: + modules_to_not_convert = [] + + # Replace nn.Linear layers, but skip 'lm_head' + if name not in modules_to_not_convert and isinstance(module, nn.Linear): + old_weight = module.weight + out_features, in_features = old_weight.shape + bias = "bias" in module + # Create a new instance of the custom linear layer + new_layer = QUANT_LINEAR_MAPPING[quant_method](in_features, out_features, bias=bias, invert_weight_scales=True) + + # Replace the layer in the model + quantize_layers.append((name, new_layer)) + if len(quantize_layers) > 0: + model.update_modules(tree_unflatten(quantize_layers)) + return model \ No newline at end of file diff --git a/mlx_lm/tuner/utils.py b/mlx_lm/tuner/utils.py index 1f6d10de6..a6d76abbb 100644 --- a/mlx_lm/tuner/utils.py +++ b/mlx_lm/tuner/utils.py @@ -13,7 +13,6 @@ from .dora import DoRAEmbedding, DoRALinear from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear - def build_schedule(schedule_config: Dict): """ Build a learning rate schedule from the given config. @@ -81,6 +80,7 @@ def to_lora(layer): "mistral", "mistral3", "llama", + "bitnet", "phi", "mixtral", "nemotron", @@ -285,4 +285,4 @@ def print_trainable_parameters(model): print( f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% " f"({trainable_p:.3f}M/{total_p:.3f}M)" - ) + ) \ No newline at end of file diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py index 690e8d7f4..50f1283a8 100644 --- a/mlx_lm/utils.py +++ b/mlx_lm/utils.py @@ -41,6 +41,9 @@ from .tuner.utils import dequantize as dequantize_model from .tuner.utils import get_total_parameters, load_adapters +# Quant imports +from .quant.utils import replace_linear_with_quant_linear + # Constants MODEL_REMAPPING = { "mistral": "llama", # mistral is compatible with llama @@ -50,6 +53,9 @@ MAX_FILE_SIZE_GB = 5 +SUPPORTED_HF_QUANTIZATIONS = [ + "bitnet" +] def _get_classes(config: dict): """ @@ -159,6 +165,7 @@ def load_model( config = load_config(model_path) config.update(model_config) + weight_files = glob.glob(str(model_path / "model*.safetensors")) if not weight_files: @@ -181,8 +188,8 @@ def load_model( if hasattr(model, "sanitize"): weights = model.sanitize(weights) + # This handles the case where we use MLX-related quantizations if (quantization := config.get("quantization", None)) is not None: - def class_predicate(p, m): # Handle custom per layer quantizations if p in config["quantization"]: @@ -199,6 +206,20 @@ def class_predicate(p, m): class_predicate=class_predicate, ) + # We can also handle HF-related quant models such as bitnet + if config.get("quantization_config", None) is not None: + quantization_config = config["quantization_config"] + quant_method = quantization_config.get("quant_method", None) + modules_to_not_convert = quantization_config.get("modules_to_not_convert", None) + + if quant_method is not None and quant_method in SUPPORTED_HF_QUANTIZATIONS: + # Replace linear layers with quantized versions + model = replace_linear_with_quant_linear( + model, + quant_method=quant_method, + modules_to_not_convert=modules_to_not_convert + ) + model.load_weights(list(weights.items()), strict=strict) if not lazy: @@ -453,7 +474,8 @@ def quantize_model( if "quantization" in config: raise ValueError("Cannot quantize already quantized model") quantized_config = copy.deepcopy(config) - quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} + quant_method = quantized_config.get("quantization_config", {}) + quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits, **quant_method} # Add any custom quantization parameters to the config as we go def _class_predicate(p, m): diff --git a/tests/test_models.py b/tests/test_models.py index 80c4d7add..937861715 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -247,6 +247,23 @@ def test_llama(self): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_bitnet(self): + from mlx_lm.models import bitnet + + args = bitnet.ModelArgs( + model_type="bitnet", + hidden_size=1024, + num_hidden_layers=4, + intermediate_size=2048, + num_attention_heads=4, + rms_norm_eps=1e-5, + vocab_size=10_000, + ) + model = bitnet.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_phi2(self): from mlx_lm.models import phi