diff --git a/README.md b/README.md
index 87ce1ba..edf6a71 100644
--- a/README.md
+++ b/README.md
@@ -58,6 +58,10 @@ Apple M5 Max, 64 GB unified memory, MLX 0.31.1. Protocol: stock `mlx_lm.stream_g
 | Qwen3.6-35B-A3B-4bit | 2048 | 139.03 tok/s | 252.93 tok/s | 1.82x | 89.60% |
 | Qwen3.6-35B-A3B-4bit | 4096 | 134.50 tok/s | 208.40 tok/s | 1.56x | 88.43% |
 | Qwen3.6-35B-A3B-4bit | 8192 | 133.20 tok/s | 177.45 tok/s | 1.33x | 87.01% |
+| LFM2.5-1.2B-Instruct\* | 1024 | 141.94 tok/s | 338.98 tok/s | 2.39x | 86.82% |
+| LFM2.5-1.2B-Instruct\* | 2048 | 141.03 tok/s | 209.59 tok/s | 1.49x | 78.76% |
+
+\* Measured on Apple M4 Max, 36 GB unified memory, MLX 0.31.2. The bundled `nathanrchn/LFM2.5-1.2B-Instruct-DFlash` draft was trained on 2k-token sequences, so acceptance and speedup degrade past that horizon; 4k/8k rows are omitted.
 
 Per-run JSON: [`benchmark/results/`](benchmark/results/). Reproduce on your hardware with `dflash benchmark`.
 
diff --git a/dflash_mlx/engine/target_lfm2.py b/dflash_mlx/engine/target_lfm2.py
new file mode 100644
index 0000000..b0f26e0
--- /dev/null
+++ b/dflash_mlx/engine/target_lfm2.py
@@ -0,0 +1,308 @@
+# Copyright 2026 bstnxbt
+# Licensed under the Apache License, Version 2.0 - see LICENSE file
+# Based on DFlash (arXiv:2602.06036)
+
+from __future__ import annotations
+
+import time
+from typing import Any, Optional
+
+import mlx.core as mx
+from mlx_lm.models import cache as cache_mod
+from mlx_lm.models.base import (
+    create_attention_mask,
+    create_ssm_mask,
+)
+
+from dflash_mlx.engine.target_ops import TargetCapabilities
+from dflash_mlx.short_conv_rollback_cache import ShortConvRollbackCache
+
+
+def _install_short_conv_rollback_hook(conv_module: Any) -> None:
+    cls = type(conv_module)
+    if getattr(cls, "_dflash_short_conv_rollback_installed", False):
+        return
+
+    original_call = cls.__call__
+
+    def patched_call(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        if not isinstance(cache, ShortConvRollbackCache):
+            return original_call(self, x, mask=mask, cache=cache)
+
+        BCx = self.in_proj(x)
+        B_split, C, x_val = mx.split(BCx, 3, axis=-1)
+        Bx = B_split * x_val
+        if mask is not None:
+            Bx = mx.where(mask[..., None], Bx, 0)
+
+        if cache[0] is None:
+            state = mx.zeros(
+                (Bx.shape[0], self.L_cache - 1, self.args.hidden_size),
+                dtype=Bx.dtype,
+            )
+        else:
+            state = cache[0]
+        Bx = mx.concatenate([state, Bx], axis=1)
+
+        cache.record_tape_if_armed(Bx)
+
+        n_keep = self.L_cache - 1
+        t = x_val.shape[1]
+        if cache.lengths is not None:
+            ends = mx.clip(cache.lengths, 0, t)
+            positions = (ends[:, None] + mx.arange(n_keep))[..., None]
+            cache[0] = mx.take_along_axis(Bx, positions, axis=1)
+        else:
+            cache[0] = mx.contiguous(Bx[:, -n_keep:, :])
+        cache.advance(t)
+
+        conv_out = self.conv(Bx)
+        y = C * conv_out
+        return self.out_proj(y)
+
+    cls.__call__ = patched_call
+    cls._dflash_short_conv_rollback_installed = True
+
+
+class Lfm2TargetOps:
+    backend_name = "lfm2"
+
+    def model_type(self, target_model: Any) -> str:
+        args = getattr(target_model, "args", None)
+        if args is None and hasattr(target_model, "language_model"):
+            args = getattr(target_model.language_model, "args", None)
+        value = getattr(args, "model_type", None)
+        if value is not None:
+            return str(value).lower()
+        config = getattr(target_model, "config", None)
+        if isinstance(config, dict):
+            return str(config.get("model_type", "")).lower()
+        return ""
+
+    def supports_model(self, target_model: Any) -> bool:
+        if "lfm" not in self.model_type(target_model):
+            return False
+        return self._has_lfm_text_shape(target_model)
+
+    def _has_lfm_text_shape(self, target_model: Any) -> bool:
+        try:
+            inner = self.text_model(target_model)
+        except AttributeError:
+            return False
+        if not (hasattr(inner, "layers") and hasattr(inner, "embed_tokens")):
+            return False
+        return any(
+            hasattr(layer, "conv") and not getattr(layer, "is_attention_layer", True)
+            for layer in inner.layers
+        )
+
+    def text_wrapper(self, target_model: Any) -> Any:
+        if hasattr(target_model, "model"):
+            return target_model
+        if hasattr(target_model, "language_model"):
+            return target_model.language_model
+        raise AttributeError(f"Unsupported target model wrapper: {type(target_model)!r}")
+
+    def text_model(self, target_model: Any) -> Any:
+        wrapper = self.text_wrapper(target_model)
+        if hasattr(wrapper, "model"):
+            return wrapper.model
+        raise AttributeError(f"Unsupported target text model: {type(wrapper)!r}")
+
+    def embed_tokens(self, target_model: Any) -> Any:
+        return self.text_model(target_model).embed_tokens
+
+    def logits_from_hidden(self, target_model: Any, hidden_states: mx.array) -> mx.array:
+        return self.text_model(target_model).embed_tokens.as_linear(hidden_states)
+
+    def family(self, target_model: Any) -> str:
+        return "hybrid_short_conv"
+
+    def capabilities_for(self, target_model: Any) -> TargetCapabilities:
+        return TargetCapabilities(
+            supports_dflash=True,
+            supports_recurrent_rollback=True,
+            supports_kv_trim=True,
+            supports_prefix_snapshot=True,
+            supports_rotating_cache_snapshot=False,
+            supports_shared_kv=False,
+            supports_target_hidden_capture=True,
+        )
+
+    def extract_context_feature(
+        self,
+        captured_dict: dict[int, mx.array],
+        target_layer_ids: list[int],
+    ) -> mx.array:
+        selected = [captured_dict[layer_id + 1] for layer_id in target_layer_ids]
+        return mx.concatenate(selected, axis=-1)
+
+    def forward_with_hidden_capture(
+        self,
+        target_model: Any,
+        *,
+        input_ids: Optional[mx.array] = None,
+        cache: Optional[list[Any]] = None,
+        input_embeddings: Optional[mx.array] = None,
+        capture_layer_ids: Optional[set[int]] = None,
+    ) -> tuple[mx.array, list[mx.array] | dict[int, mx.array]]:
+        inner = self.text_model(target_model)
+        hidden_states = (
+            input_embeddings if input_embeddings is not None else inner.embed_tokens(input_ids)
+        )
+        if cache is None:
+            cache = [None] * len(inner.layers)
+        capture_all = capture_layer_ids is None
+        if capture_all:
+            captured: list[mx.array] | dict[int, mx.array] = [hidden_states]
+        else:
+            capture_layer_ids = set(capture_layer_ids)
+            captured = {0: hidden_states} if 0 in capture_layer_ids else {}
+        h = hidden_states
+
+        fa_mask = create_attention_mask(hidden_states, cache[inner.fa_idx])
+        conv_mask = create_ssm_mask(hidden_states, cache[inner.conv_idx])
+        for layer_index, (layer, layer_cache) in enumerate(zip(inner.layers, cache, strict=True)):
+            mask = fa_mask if getattr(layer, "is_attention_layer", False) else conv_mask
+            h = layer(h, mask=mask, cache=layer_cache)
+            capture_key = layer_index + 1
+            if capture_all:
+                captured.append(h)
+            elif capture_layer_ids is not None and capture_key in capture_layer_ids:
+                captured[capture_key] = h
+        normalized = inner.embedding_norm(h)
+        logits = self.logits_from_hidden(target_model, normalized)
+        return logits, captured
+
+    def verify_block(
+        self,
+        *,
+        target_model: Any,
+        verify_ids: mx.array,
+        target_cache: list[Any],
+        capture_layer_ids: Optional[set[int]] = None,
+    ) -> tuple[mx.array, list[mx.array] | dict[int, mx.array]]:
+        if int(verify_ids.shape[1]) <= 0:
+            raise ValueError("verify block must contain at least one token")
+        return self.forward_with_hidden_capture(
+            target_model,
+            input_ids=verify_ids,
+            cache=target_cache,
+            capture_layer_ids=capture_layer_ids,
+        )
+
+    def install_speculative_hooks(self, target_model: Any) -> None:
+        text_model = self.text_model(target_model)
+        if getattr(text_model, "_dflash_speculative_hooks_installed", False):
+            return
+        for layer in text_model.layers:
+            if hasattr(layer, "conv") and not getattr(layer, "is_attention_layer", True):
+                _install_short_conv_rollback_hook(layer.conv)
+        text_model._dflash_speculative_hooks_installed = True
+
+    def configure_full_attention_split(
+        self,
+        target_model: Any,
+        *,
+        enabled: bool,
+        chunk_size: int = 8,
+    ) -> None:
+        return
+
+    def make_cache(
+        self,
+        target_model: Any,
+        *,
+        enable_speculative_linear_cache: bool,
+        quantize_kv_cache: bool = False,
+        target_fa_window: Optional[int] = None,
+    ) -> list[Any]:
+        fa_window = 0 if target_fa_window is None else int(target_fa_window)
+        if fa_window < 0:
+            raise ValueError("target_fa_window must be >= 0")
+        if fa_window > 0 and quantize_kv_cache:
+            raise ValueError(
+                "target_fa_window does not support quantized target KV cache"
+            )
+        text_model = self.text_model(target_model)
+        caches: list[Any] = []
+        for layer in text_model.layers:
+            if hasattr(layer, "conv") and not getattr(layer, "is_attention_layer", True):
+                if enable_speculative_linear_cache:
+                    self.install_speculative_hooks(target_model)
+                    kernel_size = int(getattr(layer.conv, "L_cache", 4))
+                    caches.append(ShortConvRollbackCache(kernel_size=kernel_size))
+                else:
+                    caches.append(cache_mod.ArraysCache(size=1))
+            else:
+                if fa_window > 0:
+                    caches.append(cache_mod.RotatingKVCache(max_size=fa_window))
+                elif quantize_kv_cache:
+                    caches.append(cache_mod.QuantizedKVCache(group_size=64, bits=8))
+                else:
+                    caches.append(cache_mod.KVCache())
+        return caches
+
+    def arm_rollback(self, cache_entries: list[Any], *, prefix_len: int) -> None:
+        for cache_entry in cache_entries:
+            if hasattr(cache_entry, "arm_rollback"):
+                cache_entry.arm_rollback(prefix_len=int(prefix_len))
+
+    def clear_rollback_state(self, cache_entry: Any) -> None:
+        if hasattr(cache_entry, "clear_transients"):
+            cache_entry.clear_transients()
+            return
+        if hasattr(cache_entry, "_armed"):
+            cache_entry._armed = False
+        if hasattr(cache_entry, "_tape"):
+            cache_entry._tape = None
+        if hasattr(cache_entry, "_snapshot"):
+            cache_entry._snapshot = None
+
+    def restore_after_acceptance(
+        self,
+        cache_entries: list[Any],
+        *,
+        target_len: int,
+        acceptance_length: int,
+        drafted_tokens: int = 0,
+    ) -> int:
+        replay_ns_total = 0
+        fully_accepted = acceptance_length == drafted_tokens
+        for cache_entry in cache_entries:
+            if hasattr(cache_entry, "rollback"):
+                if fully_accepted:
+                    self.clear_rollback_state(cache_entry)
+                    continue
+                replay_start_ns = time.perf_counter_ns()
+                cache_entry.rollback(acceptance_length)
+                replay_ns_total += time.perf_counter_ns() - replay_start_ns
+            elif hasattr(cache_entry, "trim"):
+                offset = int(getattr(cache_entry, "offset", 0) or 0)
+                if offset > target_len:
+                    replay_start_ns = time.perf_counter_ns()
+                    cache_entry.trim(offset - target_len)
+                    replay_ns_total += time.perf_counter_ns() - replay_start_ns
+            elif hasattr(cache_entry, "offset"):
+                offset = int(getattr(cache_entry, "offset", 0) or 0)
+                if offset > target_len:
+                    cache_entry.offset = target_len
+            elif hasattr(cache_entry, "crop"):
+                cache_entry.crop(target_len)
+        return replay_ns_total
+
+    def cleanup_generation_caches(
+        self,
+        target_cache: list[Any],
+        draft_cache: list[Any],
+    ) -> None:
+        for cache_entry in target_cache:
+            if hasattr(cache_entry, "clear_transients"):
+                cache_entry.clear_transients()
+        draft_cache.clear()
+        target_cache.clear()
diff --git a/dflash_mlx/engine/target_ops.py b/dflash_mlx/engine/target_ops.py
index 7bd9455..658b1a1 100644
--- a/dflash_mlx/engine/target_ops.py
+++ b/dflash_mlx/engine/target_ops.py
@@ -102,6 +102,7 @@ def cleanup_generation_caches(
 TARGET_BACKENDS = [
     "dflash_mlx.engine.target_qwen_gdn:QwenGdnTargetOps",
     "dflash_mlx.engine.target_gemma4:Gemma4TargetOps",
+    "dflash_mlx.engine.target_lfm2:Lfm2TargetOps",
 ]
 
 def _load_backend_class(path: str) -> type[TargetOps]:
diff --git a/dflash_mlx/model.py b/dflash_mlx/model.py
index c3e6816..a87ca68 100644
--- a/dflash_mlx/model.py
+++ b/dflash_mlx/model.py
@@ -245,6 +245,20 @@ class DFlashDraftModelArgs:
     @classmethod
     def from_dict(cls, params: dict[str, Any]) -> "DFlashDraftModelArgs":
         data = dict(params)
+        rope_parameters = data.pop("rope_parameters", None) or {}
+        if "rope_theta" not in data and "rope_theta" in rope_parameters:
+            data["rope_theta"] = rope_parameters["rope_theta"]
+        if "rope_scaling" not in data:
+            scaling = {
+                key: value
+                for key, value in rope_parameters.items()
+                if key not in ("rope_theta", "rope_type")
+            }
+            rope_type = rope_parameters.get("rope_type")
+            if rope_type and rope_type != "default":
+                scaling["type"] = rope_type
+            if scaling:
+                data["rope_scaling"] = scaling
         layer_types = tuple(data.get("layer_types") or ())
         if not layer_types and "num_hidden_layers" in data:
             layer_types = _default_draft_layer_types(
diff --git a/dflash_mlx/short_conv_rollback_cache.py b/dflash_mlx/short_conv_rollback_cache.py
new file mode 100644
index 0000000..5c34e6b
--- /dev/null
+++ b/dflash_mlx/short_conv_rollback_cache.py
@@ -0,0 +1,123 @@
+# Copyright 2026 bstnxbt
+# Licensed under the Apache License, Version 2.0 - see LICENSE file
+# Based on DFlash (arXiv:2602.06036)
+
+from __future__ import annotations
+
+from typing import Any
+
+import mlx.core as mx
+from mlx_lm.models.cache import _BaseCache
+
+
+class ShortConvRollbackCache(_BaseCache):
+    def __new__(cls, *args, **kwargs):
+        instance = super().__new__(cls)
+        instance.left_padding = None
+        instance.lengths = None
+        instance._armed = False
+        instance._tape = None
+        instance._snapshot = None
+        return instance
+
+    def __init__(self, kernel_size: int):
+        self.cache = [None]
+        self.kernel_size = int(kernel_size)
+
+    def __getitem__(self, idx: int):
+        return self.cache[idx]
+
+    def __setitem__(self, idx: int, value: Any) -> None:
+        self.cache[idx] = value
+
+    @property
+    def state(self):
+        return self.cache
+
+    @state.setter
+    def state(self, value) -> None:
+        self.cache = value
+
+    def filter(self, batch_indices):
+        self.cache = [c[batch_indices] if c is not None else None for c in self.cache]
+        if self.lengths is not None:
+            self.lengths = self.lengths[batch_indices]
+
+    def extend(self, other):
+        def cat(lhs, rhs):
+            if lhs is None:
+                return rhs
+            if rhs is None:
+                return lhs
+            return mx.concatenate([lhs, rhs])
+
+        self.cache = [cat(lhs, rhs) for lhs, rhs in zip(self.cache, other.cache, strict=True)]
+
+    def extract(self, idx):
+        cache = ShortConvRollbackCache(self.kernel_size)
+        cache.cache = [c[idx : idx + 1] if c is not None else None for c in self.cache]
+        return cache
+
+    def prepare(self, lengths=None, **kwargs):
+        self.lengths = None if lengths is None else mx.array(lengths)
+
+    def finalize(self):
+        self.lengths = None
+        self.left_padding = None
+        self.clear_transients()
+
+    def advance(self, n: int):
+        if self.lengths is not None:
+            self.lengths -= n
+        if self.left_padding is not None:
+            self.left_padding -= n
+
+    def make_mask(self, n: int):
+        if self.left_padding is not None:
+            pos = mx.arange(n)
+            return pos >= self.left_padding[:, None]
+        if self.lengths is not None:
+            pos = mx.arange(n)
+            return pos < self.lengths[:, None]
+        return None
+
+    def empty(self):
+        return self.cache[0] is None
+
+    @property
+    def nbytes(self):
+        return sum(c.nbytes for c in self.cache if c is not None)
+
+    def clear_transients(self) -> None:
+        self._armed = False
+        self._tape = None
+        self._snapshot = None
+
+    def arm_rollback(self, prefix_len: int = 0) -> None:
+        del prefix_len
+        self._armed = True
+        self._tape = None
+        self._snapshot = list(self.cache)
+
+    def record_tape_if_armed(self, bx_extended: mx.array) -> None:
+        if self._armed:
+            self._tape = mx.contiguous(bx_extended)
+
+    def rollback(self, n_accepted: int) -> None:
+        if self._snapshot is None:
+            self.clear_transients()
+            return
+        n_keep = self.kernel_size - 1
+        if self._tape is None or n_keep <= 0:
+            self.cache = list(self._snapshot)
+            self.clear_transients()
+            return
+        accepted_steps = int(n_accepted) + 1
+        start = accepted_steps
+        end = start + n_keep
+        tape_len = int(self._tape.shape[1])
+        if end > tape_len:
+            self.cache = list(self._snapshot)
+        else:
+            self.cache[0] = mx.contiguous(self._tape[:, start:end, :])
+        self.clear_transients()
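
As a quick sanity check of the rollback mechanism introduced above, here is a minimal sketch (not part of the diff) that drives `ShortConvRollbackCache` directly with dummy tensors instead of a real LFM2 `ShortConv` layer. The kernel size, shapes, and token counts are illustrative assumptions, not values taken from the change.

```python
# Illustrative only: exercise ShortConvRollbackCache in isolation to show the
# arm -> record -> rollback sequence used during speculative verification.
import mlx.core as mx

from dflash_mlx.short_conv_rollback_cache import ShortConvRollbackCache

kernel_size = 3  # hypothetical L_cache; the real value comes from layer.conv
cache = ShortConvRollbackCache(kernel_size=kernel_size)

# Simulate the prefill: the patched ShortConv keeps the last (kernel_size - 1)
# positions of the extended input as cache[0].
prefix = mx.arange(8, dtype=mx.float32).reshape(1, 8, 1)
cache[0] = prefix[:, -(kernel_size - 1):, :]

# Before verifying a speculative block, rollback is armed: snapshot the current
# conv state and start taping the next state-prefixed input.
cache.arm_rollback(prefix_len=8)

# The patched __call__ records the state-prefixed block ("tape"), updates the
# rolling state window, and advances by the block length.
drafted = mx.arange(8, 12, dtype=mx.float32).reshape(1, 4, 1)
extended = mx.concatenate([cache[0], drafted], axis=1)
cache.record_tape_if_armed(extended)
cache[0] = extended[:, -(kernel_size - 1):, :]
cache.advance(drafted.shape[1])

# Suppose only 2 of the 4 block positions were ultimately kept: rollback
# rebuilds cache[0] from the tape as if only the first (n_accepted + 1)
# positions of the block had been consumed.
cache.rollback(n_accepted=2)
print(cache[0])
```

In the engine itself this sequence is driven by `Lfm2TargetOps.arm_rollback`, the patched `ShortConv.__call__` installed by `_install_short_conv_rollback_hook`, and `restore_after_acceptance`.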