From d3fd3eda13d857f851d7745259dcab4faea6252c Mon Sep 17 00:00:00 2001 From: Hunter-Wrynn <2792843553@qq.com> Date: Sat, 7 Feb 2026 13:10:57 -0500 Subject: [PATCH] add sdar --- .env.example | 3 +- configs/gen_args.py | 8 + configs/generation/sdar.yaml | 22 + configs/model/sdar-8b-chat.yaml | 10 + configs/model/sdar-common.yaml | 7 + eval.py | 4 +- src/generation/sdar.py | 325 +++++ src/models/__init__.py | 15 +- src/models/sdar/__init__.py | 11 + src/models/sdar/configuration_sdar.py | 212 +++ src/models/sdar/eval_model.py | 77 + .../fused_linear_diffusion_cross_entropy.py | 682 +++++++++ src/models/sdar/modeling_sdar.py | 1251 +++++++++++++++++ src/utils/models.py | 16 +- 14 files changed, 2638 insertions(+), 5 deletions(-) create mode 100644 configs/generation/sdar.yaml create mode 100644 configs/model/sdar-8b-chat.yaml create mode 100644 configs/model/sdar-common.yaml create mode 100644 src/generation/sdar.py create mode 100644 src/models/sdar/__init__.py create mode 100644 src/models/sdar/configuration_sdar.py create mode 100644 src/models/sdar/eval_model.py create mode 100644 src/models/sdar/fused_linear_diffusion_cross_entropy.py create mode 100644 src/models/sdar/modeling_sdar.py diff --git a/.env.example b/.env.example index 44ebcfc..39b8a90 100644 --- a/.env.example +++ b/.env.example @@ -5,4 +5,5 @@ LLADA_BASE_PATH=GSAI-ML/LLaDA-8B-Base LLADA_INST_PATH=GSAI-ML/LLaDA-8B-Instruct LLADA_1_5_PATH=GSAI-ML/LLaDA-1.5 DREAM_BASE_PATH=Dream-org/Dream-v0-Base-7B -DREAM_INST_PATH=Dream-org/Dream-v0-Instruct-7B \ No newline at end of file +DREAM_INST_PATH=Dream-org/Dream-v0-Instruct-7B +SDAR_8B_CHAT_PATH=JetLM/SDAR-8B-Chat diff --git a/configs/gen_args.py b/configs/gen_args.py index 199848b..7095a54 100644 --- a/configs/gen_args.py +++ b/configs/gen_args.py @@ -149,6 +149,14 @@ def get_generation_args(task: str, model: str, cache: str | None = None): match model: case "dream-base" | "dream-inst": top_p = 0.9 + case model if model.startswith("sdar"): + # SDAR block 
diffusion defaults (see SDAR repo `generate.py`) + block_length = 4 + # keep `steps=gen_length` so that per-block denoising steps can be derived as: + # denoising_steps = steps // (gen_length // block_length) == block_length + temperature = 1.0 + top_p = 0.95 + top_k = 50 return GenerationArgs( gen_length=gen_length, diff --git a/configs/generation/sdar.yaml b/configs/generation/sdar.yaml new file mode 100644 index 0000000..e3462f0 --- /dev/null +++ b/configs/generation/sdar.yaml @@ -0,0 +1,22 @@ +# SDAR block-diffusion decoding. +# Most defaults (gen_length/block_length/steps/top_p, etc.) are filled by `configs/gen_args.py` +strategy: sdar + +# SDAR remasking / unmasking strategies (from SDAR repo `generate.py`) +remasking_strategy: low_confidence_dynamic +confidence_threshold: 0.85 +eb_threshold: 0.35 + +alg: "maskgit_plus" +gen_length: null +block_length: null +steps: null +temperature: 1.0 +top_p: 0.95 +top_k: 50 + +# Stop when the first EOS is generated; remaining masks (if any) are replaced with EOS. +stop_until_eot: true + +output_probs: false + diff --git a/configs/model/sdar-8b-chat.yaml b/configs/model/sdar-8b-chat.yaml new file mode 100644 index 0000000..96b34e9 --- /dev/null +++ b/configs/model/sdar-8b-chat.yaml @@ -0,0 +1,10 @@ +defaults: + - sdar-common + - _self_ + +name: sdar-8b-chat +path: ${oc.env:SDAR_8B_CHAT_PATH} + +# SDAR is a chat model; let lm-eval apply the chat template. 
+apply_chat_template: true + diff --git a/configs/model/sdar-common.yaml b/configs/model/sdar-common.yaml new file mode 100644 index 0000000..7033c73 --- /dev/null +++ b/configs/model/sdar-common.yaml @@ -0,0 +1,7 @@ +generation: + # JetLM/SDAR-8B-Chat config.json: + # bos_token_id/eos_token_id: 151643, mask_token_id: 151669 + mask_token_id: 151669 + eot_token_id: 151643 + pad_token_id: 151643 + diff --git a/eval.py b/eval.py index a077742..be0a6cb 100644 --- a/eval.py +++ b/eval.py @@ -63,7 +63,9 @@ def main(cfg: DictConfig) -> None: use_cache=( os.path.join(output_dir, "response") if cfg.use_eval_cache else None ), - apply_chat_template=cfg.model.name.endswith("inst"), + apply_chat_template=cfg.model.get( + "apply_chat_template", cfg.model.name.endswith("inst") + ), **overwrite_eval_task(cfg), ) diff --git a/src/generation/sdar.py b/src/generation/sdar.py new file mode 100644 index 0000000..d5c10a9 --- /dev/null +++ b/src/generation/sdar.py @@ -0,0 +1,325 @@ +import math +import torch +import torch.nn.functional as F + +from transformers.cache_utils import DynamicCache + +from src.frame import DecodeRecord, Frame, FrameDelta +from src.utils import register + + +def _top_k_logits(logits: torch.Tensor, k: int) -> torch.Tensor: + if k <= 0: + return logits + values, _ = torch.topk(logits, k) + min_values = values[..., -1, None] + return torch.where( + logits < min_values, torch.full_like(logits, float("-inf")), logits + ) + + +def _top_p_logits(logits: torch.Tensor, p: float) -> torch.Tensor: + if p >= 1.0: + return logits + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_mask = cumulative_probs > p + sorted_mask[..., 1:] = sorted_mask[..., :-1].clone() + sorted_mask[..., 0] = False + mask_indices = torch.scatter( + torch.full_like(logits, False, dtype=torch.bool), -1, sorted_indices, sorted_mask + ) + return logits.masked_fill(mask_indices, float("-inf")) + + 
+def _sample_tokens( + logits: torch.Tensor, + *, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + logits: (B, L, V) + Returns: + tokens: (B, L) + token_probs: (B, L) probability of selected tokens + probs: (B, L, V) full distribution (float32) + """ + if temperature == 0.0: + tokens = torch.argmax(logits, dim=-1) + probs = torch.softmax(logits.to(torch.float32), dim=-1) + token_probs = torch.gather(probs, -1, tokens.unsqueeze(-1)).squeeze(-1) + return tokens, token_probs, probs + + scaled = logits / temperature + if top_k > 0: + scaled = _top_k_logits(scaled, top_k) + if top_p < 1.0: + scaled = _top_p_logits(scaled, top_p) + + probs = torch.softmax(scaled.to(torch.float32), dim=-1) + tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1).view( + probs.size(0), probs.size(1) + ) + token_probs = torch.gather(probs, -1, tokens.unsqueeze(-1)).squeeze(-1) + return tokens, token_probs, probs + + +def _get_num_transfer_tokens(block_length: int, denoising_steps: int) -> torch.Tensor: + base = block_length // denoising_steps + remainder = block_length % denoising_steps + num_transfer_tokens = torch.full((denoising_steps,), base, dtype=torch.long) + num_transfer_tokens[:remainder] += 1 + return num_transfer_tokens + + +@torch.no_grad() +def _block_diffusion_generate_one( + model, + *, + prompt_ids: torch.Tensor, + mask_id: int, + eos_token_id: int, + gen_length: int, + block_length: int, + denoising_steps: int, + temperature: float, + top_k: int, + top_p: float, + remasking_strategy: str, + confidence_threshold: float, + eb_threshold: float | None, + stop_until_eot: bool, +) -> torch.Tensor: + device = prompt_ids.device + prompt_length = int(prompt_ids.numel()) + num_blocks = math.ceil((prompt_length + gen_length) / block_length) + total_length = num_blocks * block_length + + block_mask = torch.tril( + torch.ones((num_blocks, num_blocks), device=device, 
@torch.no_grad()
def _block_diffusion_generate_one(
    model,
    *,
    prompt_ids: torch.Tensor,
    mask_id: int,
    eos_token_id: int,
    gen_length: int,
    block_length: int,
    denoising_steps: int,
    temperature: float,
    top_k: int,
    top_p: float,
    remasking_strategy: str,
    confidence_threshold: float,
    eb_threshold: float | None,
    stop_until_eot: bool,
) -> torch.Tensor:
    """Decode ``gen_length`` tokens for one prompt with SDAR block diffusion.

    The prompt followed by ``gen_length`` mask tokens is padded up to a multiple
    of ``block_length`` and processed block by block. Full prompt blocks are
    prefetched into the KV cache in one forward pass; every remaining block is
    denoised for at most ``denoising_steps`` steps, each step unmasking a subset
    of the still-masked positions chosen by ``remasking_strategy``. Once a block
    has no masks left it is run once more with ``store_kv=True`` (presumably so
    the model commits the block's keys/values to the cache -- model-specific
    kwarg, confirm against the SDAR modeling code).

    Args:
        model: Callable returning an object with ``.logits`` and accepting
            ``attention_mask``, ``position_ids``, ``past_key_values``,
            ``use_cache`` and ``store_kv`` keyword arguments.
        prompt_ids: 1-D tensor of unpadded prompt token ids.
        mask_id: Token id marking not-yet-decoded positions.
        eos_token_id: Token id used for early stopping and to fill leftover masks.
        gen_length: Number of tokens to return.
        block_length: Block size of the block-causal attention pattern.
        denoising_steps: Maximum denoising steps per block; must be positive.
        temperature, top_k, top_p: Sampling controls, see ``_sample_tokens``.
        remasking_strategy: One of ``"sequential"``, ``"low_confidence_static"``,
            ``"low_confidence_dynamic"`` or ``"entropy_bounded"``.
        confidence_threshold: Probability above which ``low_confidence_dynamic``
            accepts a token outright.
        eb_threshold: Cumulative-entropy budget for ``entropy_bounded``.
        stop_until_eot: Stop after the first block that produced ``eos_token_id``.

    Returns:
        A ``(gen_length,)`` tensor of token ids; positions still masked when
        decoding stopped are replaced with ``eos_token_id``.

    Raises:
        ValueError: On a non-positive ``denoising_steps``, an unknown
            ``remasking_strategy``, or a missing ``eb_threshold``.
        RuntimeError: If a block still has masks after ``denoising_steps`` steps.
    """
    # Validate everything up front so a bad config fails before any forward pass.
    if denoising_steps <= 0:
        raise ValueError(f"{denoising_steps=} must be > 0 for SDAR decoding.")
    if remasking_strategy not in (
        "sequential",
        "low_confidence_static",
        "low_confidence_dynamic",
        "entropy_bounded",
    ):
        raise ValueError(f"Unknown remasking_strategy: {remasking_strategy!r}")
    if remasking_strategy == "entropy_bounded" and eb_threshold is None:
        raise ValueError(
            "eb_threshold must be provided when remasking_strategy='entropy_bounded'."
        )

    device = prompt_ids.device
    prompt_length = int(prompt_ids.numel())
    num_blocks = math.ceil((prompt_length + gen_length) / block_length)
    total_length = num_blocks * block_length

    # Block-causal mask: a position sees every position of its own block and of
    # all earlier blocks.
    block_mask = torch.tril(
        torch.ones((num_blocks, num_blocks), device=device, dtype=torch.float32)
    )
    # (1, total_length, total_length)
    attention_mask = (
        block_mask.repeat_interleave(block_length, dim=0)
        .repeat_interleave(block_length, dim=1)
        .unsqueeze(0)
    )
    position_ids = torch.arange(total_length, device=device).unsqueeze(0)

    x = torch.full((1, total_length), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_length] = prompt_ids.unsqueeze(0)

    past_key_values = DynamicCache()

    # Only blocks made up entirely of prompt tokens can be prefilled; a trailing
    # partial prompt block is denoised together with its masked positions.
    prefill_blocks = prompt_length // block_length
    prefill_length = prefill_blocks * block_length

    if prefill_length > 0:
        model(
            x[:, :prefill_length],
            attention_mask=attention_mask[:, :prefill_length, :prefill_length],
            position_ids=position_ids[:, :prefill_length],
            past_key_values=past_key_values,
            use_cache=True,
            store_kv=True,
        )

    # Plain Python ints: avoids a device sync (`.item()`) on every step.
    transfer_schedule = _get_num_transfer_tokens(block_length, denoising_steps).tolist()

    for num_block in range(prefill_blocks, num_blocks):
        block_start = num_block * block_length
        block_end = (num_block + 1) * block_length

        cur_x = x[:, block_start:block_end].clone()
        cur_attn_mask = attention_mask[:, block_start:block_end, :block_end]
        cur_position_ids = position_ids[:, block_start:block_end]

        # One extra iteration so a block finished on the last denoising step
        # still gets its final `store_kv=True` pass.
        for step in range(denoising_steps + 1):
            mask_index = cur_x == mask_id
            if not bool(mask_index.any()):
                model(
                    cur_x,
                    attention_mask=cur_attn_mask,
                    position_ids=cur_position_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                    store_kv=True,
                )
                break

            # Out of budget: raise before spending another forward pass.
            if step >= denoising_steps:
                raise RuntimeError(
                    "SDAR decoding did not finish within denoising_steps; "
                    "consider increasing `steps` or reducing `gen_length`."
                )
            k = transfer_schedule[step]

            logits = model(
                cur_x,
                attention_mask=cur_attn_mask,
                position_ids=cur_position_ids,
                past_key_values=past_key_values,
                use_cache=True,
                store_kv=False,
            ).logits

            sampled, sampled_p, probs = _sample_tokens(
                logits,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )

            transfer_index = torch.zeros_like(sampled, dtype=torch.bool)

            if remasking_strategy == "sequential":
                # Unmask the next k positions starting at the first mask.
                for j in range(cur_x.size(0)):
                    if not mask_index[j].any():
                        continue
                    first = int(mask_index[j].nonzero(as_tuple=True)[0].min().item())
                    transfer_index[j, first : first + k] = True

            elif remasking_strategy == "low_confidence_static":
                # Unmask the k most confident masked positions.
                confidence = torch.where(mask_index, sampled_p, -torch.inf)
                _, idx = torch.topk(confidence, k, dim=1)
                transfer_index.scatter_(1, idx, True)

            elif remasking_strategy == "low_confidence_dynamic":
                # Accept every masked position above the threshold; fall back to
                # the k most confident ones when fewer than k qualify.
                confidence = torch.where(mask_index, sampled_p, -torch.inf)
                for j in range(confidence.size(0)):
                    high_conf_mask = confidence[j] > confidence_threshold
                    if int(high_conf_mask.sum()) >= k:
                        transfer_index[j] = high_conf_mask
                    else:
                        _, idx = torch.topk(confidence[j], k)
                        transfer_index[j, idx] = True

            else:  # "entropy_bounded" (strategy name validated above)
                # Unmask lowest-entropy positions while their cumulative entropy
                # stays below eb_threshold, always at least one.
                eps = 1e-12
                clamped = probs.clamp_min(eps)
                entropies = -(clamped * clamped.log()).sum(dim=-1)
                entropies = torch.where(mask_index, entropies, torch.inf)
                ent_sorted, order = torch.sort(entropies, dim=1, descending=False)
                cumsum = torch.cumsum(ent_sorted, dim=1)
                threshold = torch.tensor(eb_threshold, device=device, dtype=cumsum.dtype)
                for j in range(cur_x.size(0)):
                    kk = int(torch.searchsorted(cumsum[j], threshold, right=False).item())
                    kk = max(1, min(kk, int(mask_index[j].sum().item())))
                    transfer_index[j, order[j, :kk]] = True

            # Never touch positions that are already decided. When fewer than k
            # masks remain (dynamic acceptance, or a first block that still
            # holds prompt tokens) top-k also returns `-inf` entries, which
            # would otherwise overwrite prompt / previously decoded tokens.
            transfer_index &= mask_index
            cur_x[transfer_index] = sampled[transfer_index]

        x[:, block_start:block_end] = cur_x

        if stop_until_eot and (x[:, prompt_length:] == eos_token_id).any():
            break

    out = x[:, prompt_length : prompt_length + gen_length].clone()
    out[out == mask_id] = eos_token_id
    return out.squeeze(0)
@register.gen_strategy("sdar")
@torch.no_grad()
def sdar_generate(
    model,
    input_ids: torch.Tensor,
    *,
    attention_mask: torch.Tensor | None = None,
    gen_length: int,
    block_length: int,
    steps: int,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    remasking_strategy: str = "low_confidence_dynamic",
    confidence_threshold: float = 0.85,
    eb_threshold: float | None = 0.35,
    stop_until_eot: bool = True,
    mask_token_id: int | None = None,
    eot_token_id: int | None = None,
    **_: object,
) -> DecodeRecord:
    """
    SDAR block-diffusion decoding (ported from SDAR repo `generate.py`).

    Each batch row is decoded independently with `_block_diffusion_generate_one`;
    the result is wrapped in a `DecodeRecord` holding one `FrameDelta` that
    covers all `gen_length` positions.

    Notes:
      - Padded `input_ids` are supported: when `attention_mask` is given, only the
        positions it marks as non-zero are used as the prompt (left or right padding).
      - `steps` is the *total* denoising steps across blocks; per-block denoising steps are derived as
        `denoising_steps = steps // (gen_length // block_length)` (matching the constraints in `configs/gen_args.py`).
      - `mask_token_id` / `eot_token_id` fall back to `model.config.mask_token_id` /
        `model.config.eos_token_id` when not given.
      - Unrecognised keyword arguments are accepted and ignored.

    Raises:
        ValueError: If `input_ids` is not 2-D, a length is non-positive, or the
            divisibility constraints between `gen_length`, `block_length` and
            `steps` are violated.
    """
    if input_ids.dim() != 2:
        raise ValueError(f"Expected input_ids shape (B, L), got {tuple(input_ids.shape)}")
    batch_size = input_ids.size(0)

    if gen_length <= 0 or block_length <= 0 or steps <= 0:
        raise ValueError(f"Invalid lengths: {gen_length=}, {block_length=}, {steps=}")
    if gen_length % block_length != 0:
        raise ValueError(f"{gen_length=} must be divisible by {block_length=}")
    num_gen_blocks = gen_length // block_length
    if steps % num_gen_blocks != 0:
        raise ValueError(f"{steps=} must be divisible by num_gen_blocks={num_gen_blocks}")
    denoising_steps = steps // num_gen_blocks

    mask_id = int(mask_token_id if mask_token_id is not None else model.config.mask_token_id)
    eos_id = int(eot_token_id if eot_token_id is not None else model.config.eos_token_id)

    outputs: list[torch.Tensor] = []
    for i in range(batch_size):
        if attention_mask is not None:
            # Select by mask rather than slicing `input_ids[i, -prompt_len:]`:
            # the slice returns the whole padded row when prompt_len == 0
            # (`-0:` is `0:`) and is only correct for left padding.
            prompt_ids = input_ids[i][attention_mask[i].to(torch.bool)]
        else:
            prompt_ids = input_ids[i]

        outputs.append(
            _block_diffusion_generate_one(
                model,
                prompt_ids=prompt_ids,
                mask_id=mask_id,
                eos_token_id=eos_id,
                gen_length=gen_length,
                block_length=block_length,
                denoising_steps=denoising_steps,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                remasking_strategy=remasking_strategy,
                confidence_threshold=confidence_threshold,
                eb_threshold=eb_threshold,
                stop_until_eot=stop_until_eot,
            )
        )

    decoded_tokens = torch.stack(outputs, dim=0)
    initial_frame = Frame.create_initial_frame(
        input_ids, gen_length=gen_length, mask_token_id=mask_id
    )
    # Decoding is reported as a single delta that fills every generated position.
    transfer_index = tuple(
        torch.arange(gen_length, device=input_ids.device) for _ in range(batch_size)
    )
    delta = FrameDelta(transfer_index=transfer_index, decoded_tokens=decoded_tokens)
    record = DecodeRecord(initial_frame=initial_frame, deltas=[delta], block_length=block_length)
    return record
b/src/models/__init__.py index 450cef1..0170db7 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,2 +1,15 @@ from .dream import DreamModel, DreamConfig, DreamEval -from .llada import LLaDAModelLM, LLaDAConfig, LLaDAEval \ No newline at end of file +from .llada import LLaDAModelLM, LLaDAConfig, LLaDAEval +from .sdar import SDARConfig, SDARForCausalLM, SDAREval + +__all__ = [ + "DreamModel", + "DreamConfig", + "DreamEval", + "LLaDAModelLM", + "LLaDAConfig", + "LLaDAEval", + "SDARConfig", + "SDARForCausalLM", + "SDAREval", +] diff --git a/src/models/sdar/__init__.py b/src/models/sdar/__init__.py new file mode 100644 index 0000000..0c5c46e --- /dev/null +++ b/src/models/sdar/__init__.py @@ -0,0 +1,11 @@ +from .configuration_sdar import SDARConfig +from .modeling_sdar import SDARForCausalLM, SDARModel, SDARPreTrainedModel +from .eval_model import SDAREval + +__all__ = [ + "SDARConfig", + "SDARPreTrainedModel", + "SDARModel", + "SDARForCausalLM", + "SDAREval", +] diff --git a/src/models/sdar/configuration_sdar.py b/src/models/sdar/configuration_sdar.py new file mode 100644 index 0000000..726be41 --- /dev/null +++ b/src/models/sdar/configuration_sdar.py @@ -0,0 +1,212 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class SDARConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SDARModel`]. It is used to instantiate a
    SDAR model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    SDAR-1.7B [DiffuOpen/SDAR-1.7B-Chat](https://huggingface.co/DiffuOpen/SDAR-1.7B-Chat/).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the SDAR model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`SDARModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
            Passing `None` falls back to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly. A legacy `type` key is copied to `rope_type` in place on the passed dictionary.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from src.models.sdar import SDARModel, SDARConfig

    >>> # Initializing a SDAR style configuration
    >>> configuration = SDARConfig()

    >>> # Initializing a model from the SDAR-8B style configuration
    >>> model = SDARModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "sdar"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `SDAR`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: module name -> (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window  # we check `use_sliding_window` in the modeling code
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        # NOTE: `self.rope_scaling` is the caller's dict, so the key is added to it in place.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        # All remaining keyword arguments are forwarded to `PretrainedConfig`.
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class SDAREval(EvalMDLM):
    """
    lm-eval wrapper around SDAR (Synergy of Diffusion and AutoRegression).

    - Text generation is delegated to the "sdar" strategy in `src/generation/sdar.py`.
    - `loglikelihood` scores continuations with a single left-to-right
      next-token pass, which lets loglikelihood-based lm-eval tasks run.
    """

    def __init__(self, cfg: DictConfig, **kwargs):
        super().__init__(cfg, **kwargs)

        # Frame decoding utilities read `tokenizer.mask_token_id`; fall back to
        # the configured id when the tokenizer does not define one.
        if getattr(self.tokenizer, "mask_token_id", None) is None:  # type: ignore[attr-defined]
            self.tokenizer.mask_token_id = cfg.generation.mask_token_id

    def _encode_pair(self, context: str, continuation: str) -> tuple[list[int], list[int]]:
        """Tokenize ``context`` and ``continuation`` separately.

        Trailing whitespace of the context is moved to the front of the
        continuation before tokenizing. An empty context encoding is replaced by
        ``[prefix_token_id]`` so that the first continuation token has a
        preceding position to be predicted from.
        """
        stripped = context.rstrip()
        if len(stripped) != len(context):
            continuation = context[len(stripped):] + continuation
            context = stripped

        context_enc: list[int] = self.tokenizer(context).input_ids
        continuation_enc: list[int] = self.tokenizer(continuation).input_ids

        if not context_enc:
            context_enc = [int(self.prefix_token_id)]

        return context_enc, continuation_enc

    @torch.no_grad()
    def loglikelihood(self, requests, disable_tqdm: bool = False):
        """Return ``(sum of continuation log-probs, is_greedy)`` for each request.

        Requests are scored one at a time. Inputs longer than
        ``cfg.max_length`` are truncated from the left.
        """
        results = []
        progress = tqdm(
            requests,
            desc="Computing likelihood...",
            disable=disable_tqdm or not self.accelerator.is_main_process,
        )
        for request in progress:
            ctx_ids, cont_ids = self._encode_pair(*request.args)
            input_ids = torch.tensor([ctx_ids + cont_ids], device=self.device)

            if input_ids.size(1) <= self.cfg.max_length:
                context_len = len(ctx_ids)
            else:
                # Keep the tail, which holds the continuation. If the cut lands
                # inside the continuation the exact split is lost, so everything
                # except the first token is scored.
                input_ids = input_ids[:, -self.cfg.max_length :]
                context_len = max(1, input_ids.size(1) - len(cont_ids))

            logits = self.model(input_ids).logits  # (1, L, V)
            # Position t predicts token t + 1.
            log_probs = F.log_softmax(logits[:, :-1].to(torch.float32), dim=-1)

            first_scored = max(context_len - 1, 0)
            scored_log_probs = log_probs[:, first_scored:]
            scored_targets = input_ids[:, 1 + first_scored :]

            picked = torch.gather(
                scored_log_probs, -1, scored_targets.unsqueeze(-1)
            ).squeeze(-1)
            total_logprob = picked.sum().item()

            greedy_ids = scored_log_probs.argmax(dim=-1)
            matches_greedy = bool((greedy_ids == scored_targets).all().item())
            results.append((total_logprob, matches_greedy))

        return results
# `HAS_SCALE` is derived from the `scale` argument so the kernel can skip the
# multiply when no scale is given; `num_warps` is autotuned per value of `D`.
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,
    z,
    scale,
    D: tl.constexpr,
    B: tl.constexpr,
    HAS_SCALE: tl.constexpr
):
    # Program (i_n, i_d) reduces columns [i_d * B, (i_d + 1) * B) of row i_n of `x`
    # (row stride D) and writes one partial logsumexp to z[i_n, i_d]
    # (row stride cdiv(D, B)).
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # Out-of-range columns are loaded as -inf so they contribute exp(-inf) = 0.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Subtract the block maximum before exponentiating, then add it back.
    b_m = tl.max(b_x, 0)
    b_z = tl.log(tl.sum(tl.exp(b_x - b_m), 0)) + b_m
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)


def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    The last dimension is split into blocks of at most 65536 elements; the kernel
    produces one partial logsumexp per block, and the partials are combined with
    `Tensor.logsumexp`.

    Args:
        x (Tensor):
            The input tensor of any shape.
        scale (Optional[float]):
            The scale applied to the input tensor. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor. Default: `None`.
    Returns:
        Tensor: The logsumexp of the input tensor, with the shape of `x` minus its
        last dimension. The result is float32 unless `dtype` requests otherwise.
    """

    shape = x.shape
    # Flatten all leading dimensions into rows.
    x = x.view(-1, shape[-1])
    N, D = x.shape
    # Block size: next power of two >= D, capped at 64 * 1024.
    B = min(triton.next_power_of_2(D), 64 * 1024)
    ND = triton.cdiv(D, B)

    # One float32 partial result per (row, block).
    z = x.new_empty(N, ND, dtype=torch.float)
    logsumexp_fwd_kernel[(N, ND)](
        x=x,
        z=z,
        scale=scale,
        D=D,
        B=B
    )
    # Combine the per-block partials and restore the leading dimensions.
    z = z.logsumexp(-1).view(*shape[:-1])
    if dtype is not None and dtype != torch.float:
        z = z.to(dtype)
    return z
@triton.jit
def cross_entropy_kernel(
    logits,
    lse,
    target,
    p_mask,
    loss,
    total,
    ignore_index,
    label_smoothing: tl.constexpr,
    logit_scale: tl.constexpr,
    reduction: tl.constexpr,
    V: tl.constexpr,
    BV: tl.constexpr
):
    """
    This kernel computes both cross entropy loss and the gradient of the input.
    We only consider hard label + mean reduction for now.
    Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

    Diffusion variant: both the per-row loss and the per-row gradient are divided
    by that row's `p_mask` value. The gradient is written in place over `logits`.

    Args:
        logits:
            Pointer to logits tensor. Overwritten with the gradient w.r.t. the logits.
        lse:
            Pointer to logsumexp tensor (one value per row).
        target: Pointer to target tensor.
        p_mask:
            Pointer to the per-row divisor applied to loss and gradient.
        loss:
            Pointer to tensor to store the loss.
        V (int):
            The number of columns in the input tensor.
        total (int):
            The number of non-ignored classes.
        ignore_index (int):
            The index to ignore in the target.
        label_smoothing (float):
            The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        logit_scale (float):
            Scale multiplied into the logits before the softmax.
        reduction (str):
            The string for the reduction to apply
        BV (int):
            The block size for vocab.
    """

    # https://github.com/triton-lang/triton/issues/1058
    # If B*T*V is too large, i_n * stride will overflow out of int32, so we convert to int64
    i_n = tl.program_id(0).to(tl.int64)
    NV = tl.cdiv(V, BV)

    # 1. Load target first because if the target is ignore_index, we can return right away
    b_y = tl.load(target + i_n)
    # load the per-row diffusion divisor
    b_p_mask = tl.load(p_mask + i_n)

    # 2. locate the start index
    logits += i_n * V

    if b_y == ignore_index:
        # Ignored rows contribute no gradient: zero the whole row and leave the
        # loss slot untouched.
        for i in range(0, V, BV):
            o_v = i + tl.arange(0, BV)
            tl.store(logits + o_v, 0.0, mask=o_v < V)
        return

    # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
    # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

    # 3. [Online softmax] first pass: compute logsumexp
    # we did this in an outer kernel
    b_l = tl.load(logits + b_y) * logit_scale
    b_lse = tl.load(lse + i_n)

    # 4. Calculate the loss
    # loss = lse - logits_l
    # celoss = -log(q_y) = -log(softmax(x_y))
    b_loss = (b_lse - b_l) / b_p_mask  # Diffusion Scaled '1/t'

    # Label smoothing is a general case of normal cross entropy
    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
    b_z = 0.0
    eps = label_smoothing / V

    # We need tl.debug_barrier() as mentioned in
    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
    tl.debug_barrier()

    # 5. [Online Softmax] Second pass: compute gradients
    # For 'mean' reduction, gradients are normalized by number of non-ignored elements
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # For label smoothing:
    # dx_i = (softmax(x_y) - label_smoothing / V) / N, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    #      = dx_i - (1 - label_smoothing) / N
    for iv in range(0, NV):
        o_v = iv * BV + tl.arange(0, BV)
        b_logits = tl.load(logits + o_v, mask=o_v < V, other=float('-inf')) * logit_scale
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            b_z += tl.sum(tl.where(o_v < V, -eps * b_logits, 0.0))
        b_p = (tl.exp(b_logits - b_lse) - eps) * logit_scale
        b_p /= b_p_mask  # modified for diffusion: gradient is divided by p_mask as well
        if reduction == "mean":
            b_p = b_p / total
        tl.store(logits + o_v, b_p, mask=o_v < V)

    tl.debug_barrier()

    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
    # Refer to H(q', p) in section 7 of the paper:
    # https://arxiv.org/pdf/1512.00567
    # pytorch:
    # https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
    # NOTE(review): the smoothing term added here is not divided by b_p_mask,
    # whereas the gradient written above is -- confirm this is intended when
    # label_smoothing > 0.
    if label_smoothing > 0:
        b_loss = b_loss * (1 - label_smoothing) + (b_z + label_smoothing * b_lse)

    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
    b_l = tl.load(logits + b_y)

    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
    if reduction == 'mean':
        b_loss = b_loss / total
        # b_l += (label_smoothing - 1) / total * logit_scale
        # b_l has already been divided by b_p_mask and total
        b_l += (label_smoothing - 1) / b_p_mask / total * logit_scale
    else:
        # b_l += (label_smoothing - 1) * logit_scale
        b_l += (label_smoothing - 1) / b_p_mask * logit_scale

    tl.store(loss + i_n, b_loss)
    tl.store(logits + b_y, b_l)
@triton.jit
def elementwise_mul_kernel(
    x,
    g,
    N: tl.constexpr,
    B: tl.constexpr
):
    """
    Multiply every element of the buffer pointed to by x by the scalar pointed to by g.
    The multiplication is performed in-place on the buffer pointed to by x.

    Parameters:
        x:
            Pointer to the buffer to scale; it is addressed as a flat array.
        g:
            Pointer to the gradient output value (a single scalar is loaded).
        N (int):
            Upper bound on the flat offsets that are touched. The callers in this
            file pass the total number of elements of the buffer.
        B (int):
            The block size for Triton operations.
    """

    # Get the program ID and convert it to int64 to avoid overflow
    i_x = tl.program_id(0).to(tl.int64)
    o_x = i_x * B + tl.arange(0, B)

    # Load the gradient output value
    b_g = tl.load(g)
    # Offsets past N are masked out on both the load and the store.
    b_x = tl.load(x + o_x, mask=o_x < N)
    tl.store(x + o_x, b_x * b_g, mask=o_x < N)
def fused_linear_cross_entropy_forward(
    x: torch.Tensor,
    target: torch.LongTensor,
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    p_mask: torch.Tensor = None,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    num_chunks: int = 8,
    reduction: str = "mean"
):
    """
    Compute the diffusion-scaled cross entropy of ``x @ weight.T + bias`` and its gradients.

    The tokens are processed in chunks so that the full ``[N, V]`` logits tensor
    is never materialized. For each chunk the Triton kernel overwrites the chunk
    logits with their gradient, which is then folded into ``dx``, ``dw`` and ``db``.

    Args:
        x: ``[N, H]`` hidden states.
        target: ``[N]`` class indices; rows equal to ``ignore_index`` are skipped.
        weight: ``[V, H]`` projection weight.
        bias: optional ``[V]`` projection bias.
        p_mask: optional ``[N]`` per-token divisor applied to loss and gradient.
            ``None`` means "no scaling" (a divisor of 1 for every token).
        ignore_index: target value that marks ignored rows.
        label_smoothing: amount of label smoothing.
        logit_scale: scale applied to the logits.
        num_chunks: upper bound on the number of token chunks.
        reduction: ``"mean"`` or ``"sum"``.

    Returns:
        ``(loss, dx, dw, db)`` where ``loss`` is a scalar, ``dx`` matches ``x``,
        ``dw`` matches ``weight`` and ``db`` matches ``bias`` (``None`` without bias).
    """
    device = x.device
    # inputs have shape: [N, H]
    # materialized activations will have shape: [N, V]
    # the increase in memory = [N, V]
    # reduction can be achieved by partitioning the number of tokens N into smaller chunks.

    # ideally, we would like to achieve the same memory consumption as [N, H],
    # so the expected chunk size should be:
    # NC = ceil(V / H)
    # C = ceil(N / NC)
    # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048
    N, H, V = *x.shape, weight.shape[0]
    if p_mask is None:
        # The signature advertises p_mask as optional; a divisor of 1 keeps the
        # plain cross entropy instead of failing on `None[start:end]`.
        p_mask = torch.ones(N, device=device, dtype=torch.float)
    BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # TODO: in real cases, we may need to limit the number of chunks NC to
    # ensure the precisions of accumulated gradients
    NC = min(num_chunks, triton.cdiv(V, H))
    # max(1, ...) keeps the chunk size non-zero for an empty batch (N == 0),
    # which would otherwise divide by zero on the next line.
    C = max(1, triton.next_power_of_2(triton.cdiv(N, NC)))
    NC = triton.cdiv(N, C)

    # [N, H]
    dx = torch.zeros_like(x, device=device)
    # [V, H], accumulated in fp32
    dw = torch.zeros_like(weight, device=device, dtype=torch.float)
    # [V], accumulated in fp32
    db = torch.zeros_like(bias, device=device, dtype=torch.float) if bias is not None else None
    # [N]
    loss = torch.zeros(N, device=device, dtype=torch.float)

    # Number of rows that actually contribute; used for the "mean" reduction.
    total = target.ne(ignore_index).sum().item()

    for ic in range(NC):
        start, end = ic * C, min((ic + 1) * C, N)
        # [C, H]
        c_x = x[start:end]
        # when doing matmul, use the original precision
        # [C, V]
        c_logits = F.linear(c_x, weight, bias)
        c_target = target[start:end]
        c_p_mask = p_mask[start:end]
        # [C]
        # keep lse in fp32 to maintain precision
        c_lse = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float)

        # unreduced loss (a view into `loss`, filled in by the kernel)
        c_loss = loss[start:end]

        # Here we calculate the gradient of c_logits in place so we can save memory.
        cross_entropy_kernel[(c_logits.shape[0],)](
            logits=c_logits,
            lse=c_lse,
            target=c_target,
            p_mask=c_p_mask,
            loss=c_loss,
            total=total,
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
            logit_scale=logit_scale,
            reduction=reduction,
            V=V,
            BV=BV,
            num_warps=32
        )

        # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V
        # thus dx should be of shape: C x H
        dx[start:end] = torch.mm(c_logits, weight)

        # keep dw in fp32 to maintain precision
        dw += c_logits.t() @ c_x

        if bias is not None:
            torch.add(input=db, other=c_logits.sum(0), out=db)

    loss = loss.sum()
    # Cast the fp32 accumulators back to the parameter dtypes.
    dw = dw.to(weight)
    if db is not None:
        db = db.to(bias)
    return loss, dx, dw, db
def fused_linear_cross_entropy_backward(
    do: torch.Tensor,
    dx: torch.Tensor,
    dw: torch.Tensor,
    db: torch.Tensor
):
    """
    Scale the gradients precomputed in the forward pass by the upstream gradient.

    Args:
        do: upstream gradient of the scalar loss.
        dx: ``[N, H]`` gradient w.r.t. the hidden states, scaled in place.
        dw: ``[V, H]`` gradient w.r.t. the weight (or ``None``), scaled in place.
        db: ``[V]`` gradient w.r.t. the bias (or ``None``), scaled in place.

    Returns:
        The same ``(dx, dw, db)`` objects.
    """
    # If cross entropy is the last layer, do is 1.0. Skip the mul to save time
    if not torch.ne(do, torch.tensor(1.0, device=do.device)):
        return dx, dw, db

    # A Triton kernel is used instead of a PyTorch op because modifying inputs
    # in-place for gradient storage and running backward multiple times causes
    # anomalies with PyTorch but not with Triton.
    n_rows, hidden = dx.shape
    block = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden))

    # (buffer, number of elements) for every gradient that has to be scaled.
    pending = [(dx, n_rows * hidden)]
    if dw is not None:
        vocab, hidden = dw.shape
        pending.append((dw, vocab * hidden))
    if db is not None:
        pending.append((db, db.shape[0]))

    for buffer, numel in pending:
        elementwise_mul_kernel[(triton.cdiv(numel, block),)](
            x=buffer,
            g=do,
            N=numel,
            B=block,
            num_warps=32,
        )
    return dx, dw, db
class FusedLinearCrossEntropyFunction(torch.autograd.Function):
    """
    Autograd wrapper that fuses the final linear layer with the cross entropy loss.

    Reference: https://github.com/mgmalek/efficient_cross_entropy

    Because the loss is the last layer, all gradients are computed during the
    forward pass and only rescaled by the upstream gradient in backward. Neither
    ``x`` nor ``target`` has to be kept for the backward pass.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        target: torch.LongTensor,
        weight: torch.Tensor,
        bias: torch.Tensor = None,
        p_mask: torch.Tensor = None,
        ignore_index: int = -100,
        label_smoothing: float = 0.0,
        logit_scale: float = 1.0,
        num_chunks: int = 8,
        reduction: str = "mean"
    ):
        """
        Args:
            x: ``[batch_size * seq_len, hidden_size]``.
            target: ``[batch_size * seq_len]``, each value in ``[0, vocab_size)``.
            weight: ``[vocab_size, hidden_size]``.
            bias: optional ``[vocab_size]``.
            p_mask: ``[batch_size * seq_len]``, same shape as ``target``.
            ignore_index: the index to ignore in the target.
            label_smoothing: amount of smoothing; 0.0 means no smoothing.
            logit_scale: scaling factor applied to the logits. Default: 1.0
            num_chunks: number of chunks the tokens are split into. Default: 8
            reduction: ``'mean'`` | ``'sum'``. Default: ``'mean'``.

        Returns:
            The scalar loss.
        """
        loss, grad_x, grad_w, grad_b = fused_linear_cross_entropy_forward(
            x,
            target,
            weight,
            bias,
            p_mask,
            ignore_index,
            label_smoothing,
            logit_scale,
            num_chunks,
            reduction
        )
        # Keep only the precomputed gradients for backward.
        saved_w = None if weight is None else grad_w.detach()
        saved_b = None if bias is None else grad_b.detach()
        ctx.save_for_backward(grad_x.detach(), saved_w, saved_b)
        return loss

    @staticmethod
    def backward(ctx, do):
        grad_x, grad_w, grad_b = fused_linear_cross_entropy_backward(do, *ctx.saved_tensors)
        # One entry per forward argument (10 in total): only x, weight and bias
        # receive gradients; target, p_mask and the hyper-parameters do not.
        return (grad_x, None, grad_w, grad_b) + (None,) * 6
def fused_linear_cross_entropy_loss(
    x: torch.Tensor,
    target: torch.LongTensor,
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    p_mask: torch.Tensor = None,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    num_chunks: int = 8,
    reduction: str = "mean"
) -> torch.Tensor:
    """
    Functional entry point for the fused linear + diffusion cross entropy loss.

    Args:
        x (torch.Tensor): [batch_size * seq_len, hidden_size]
        target (torch.LongTensor): [batch_size * seq_len]
            where each value is in [0, vocab_size).
        weight (torch.Tensor): [vocab_size, hidden_size]
            where `vocab_size` is the number of classes.
        bias (Optional[torch.Tensor]): [vocab_size]
            where `vocab_size` is the number of classes.
        p_mask (torch.Tensor): [batch_size * seq_len]
            Its shape should be same as target.
        ignore_index: int.
            If target == ignore_index, the loss is set to 0.0.
        label_smoothing: float
        logit_scale: float
            A scaling factor applied to the logits. Default: 1.0
        num_chunks: int
            The number of chunks to split the input tensor into for processing.
            This can help optimize memory usage and computation speed.
            Default: 8
        reduction:
            Specifies the reduction to apply to the output: 'mean' | 'sum'.
            'mean': the weighted mean of the output is taken,
            'sum': the output will be summed.
            Default: 'mean'.
    Returns:
        torch.Tensor: a single scalar loss (the reduced sum over all tokens),
        not a tuple and not a per-sample vector.
    """
    # autograd.Function.apply only takes positional arguments.
    return FusedLinearCrossEntropyFunction.apply(
        x,
        target,
        weight,
        bias,
        p_mask,
        ignore_index,
        label_smoothing,
        logit_scale,
        num_chunks,
        reduction
    )
+ """ + super().__init__() + + assert reduction in ["mean", "sum"], f"reduction: {reduction} is not supported" + + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.logit_scale = logit_scale + self.num_chunks = num_chunks + self.reduction = reduction + + @torch.compiler.disable + def forward( + self, + x: torch.Tensor, + target: torch.LongTensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + p_mask: torch.Tensor = None + ): + """ + Args: + x (torch.Tensor): [batch_size, seq_len, hidden_size] + target (torch.LongTensor): [batch_size, seq_len] + where each value is in [0, V). + weight (torch.Tensor): [vocab_size, hidden_size] + where `vocab_size` is the number of classes. + bias (Optional[torch.Tensor]): [vocab_size] + where `vocab_size` is the number of classes. + p_mask(torch.Tensor): [batch_size, seq_len] + Its shape is same as target. + Shape: (1, packed_length) when varlen attn is used. + Returns: + loss + + TODO: + follow https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training + ```py + unreduced_loss /= p_mask + ``` + Scale the values of `unreduced_loss at different positions + """ + if p_mask is None: + p_mask = torch.ones_like(target, dtype=torch.float, device=x.device) + + x = x.contiguous().view(-1, x.shape[-1]) + target = target.contiguous().view(-1) + weight = weight.contiguous() + bias = bias.contiguous() if bias else None + p_mask = p_mask.contiguous().view(-1) + l, d = x.shape + assert l == target.shape[0] == p_mask.shape[0], f"{x.shape=}, {target.shape=}, {p_mask.shape=}" + + loss = fused_linear_cross_entropy_loss( + x, + target, + weight=weight, + bias=bias, + p_mask=p_mask, + ignore_index=self.ignore_index, + label_smoothing=self.label_smoothing, + logit_scale=self.logit_scale, + num_chunks=self.num_chunks, + reduction=self.reduction + ) + return loss + + +class LinearLossParallel(ParallelStyle): + def __init__( + self, + *, + sequence_dim: int = 1, + use_local_output: bool = False, + 
): + super().__init__() + + self.sequence_sharding = (Shard(sequence_dim),) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + x, target, weight, bias = inputs + + if not isinstance(x, DTensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + x = DTensor.from_local(x, device_mesh, sequence_sharding) + if x.placements != sequence_sharding: + x = x.redistribute(placements=sequence_sharding, async_op=True) + if not isinstance(target, DTensor): + target = DTensor.from_local(target, device_mesh, [Replicate()]) + if target.placements != sequence_sharding: + target = target.redistribute(placements=sequence_sharding, async_op=True) + + if not isinstance(weight, DTensor): + weight = DTensor.from_local(weight, device_mesh, [Replicate()]) + if weight.placements != [Replicate()]: + # we replicate the weight/bias in FLCE + weight = weight.redistribute(placements=[Replicate()], async_op=True) + + if bias is not None and not isinstance(bias, DTensor): + bias = DTensor.from_local(bias, device_mesh, [Replicate()]) + if bias is not None and bias.placements != [Replicate()]: + bias = bias.redistribute(placements=[Replicate()], async_op=True) + + return x.to_local(), target.to_local(), weight.to_local(), bias.to_local() if bias is not None else bias + + @staticmethod + def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=None, + input_fn=partial(self._prepare_input_fn, self.sequence_sharding), + output_fn=partial(self._prepare_output_fn, self.use_local_output) + ) diff --git a/src/models/sdar/modeling_sdar.py b/src/models/sdar/modeling_sdar.py new file mode 100644 index 0000000..9b59479 --- /dev/null +++ b/src/models/sdar/modeling_sdar.py @@ 
-0,0 +1,1251 @@ +# This file is modified based on https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3/modeling_qwen3.py. +# +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Tuple, Union, List + +import torch +from torch import nn +try: + from einops import rearrange +except ImportError as e: # pragma: no cover + raise ImportError( + "SDAR requires `einops`. Install it via `pip install einops`." 
+ ) from e + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from .configuration_sdar import SDARConfig +from .fused_linear_diffusion_cross_entropy import FusedLinearDiffusionCrossEntropyLoss + +try: + from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm +except Exception: # pragma: no cover + flash_rms_norm = None + +import torch.nn.functional as F +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +except: + pass + +try: + from liger_kernel.ops.swiglu import LigerSiLUMulFunction # noqa: F401 + liger_kernel_is_available = True +except ImportError: + liger_kernel_is_available = False + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + from transformers.integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +def modify_padded_position_ids_2d(position_ids: 
def modify_padded_position_ids_2d(position_ids: torch.LongTensor) -> torch.LongTensor:
    """
    Rewrite the trailing padding of a batch of packed position ids.

    Each row is handled independently with vectorized operations. Everything
    after the last non-zero entry of a row is treated as padding and replaced by
    ``0, 1, 2, ...``; a row without any non-zero entry is replaced entirely.

    Args:
        position_ids: 2D tensor of shape (batch_size, sequence_length).

    Returns:
        The modified position ids, shape (batch_size, sequence_length).

    Raises:
        ValueError: if ``position_ids`` is not 2D.
    """
    if position_ids.dim() != 2:
        raise ValueError(f"Input tensor must be 2D, but got {position_ids.dim()} dimensions.")

    n_rows, n_cols = position_ids.shape
    where = position_ids.device
    kind = position_ids.dtype

    # Column index of every element, broadcast over the batch.
    columns = torch.arange(n_cols, device=where, dtype=kind).expand(n_rows, -1)
    is_nonzero = position_ids != 0

    # Largest column holding a non-zero id (0 when the row has none).
    last_nonzero, _ = torch.max(columns * is_nonzero, dim=1)
    row_has_content = torch.any(is_nonzero, dim=1)
    # Padding starts right after the last non-zero id, or at 0 for empty rows.
    pad_start = torch.where(
        row_has_content,
        last_nonzero + 1,
        torch.tensor(0, device=where, dtype=kind),
    ).unsqueeze(1)

    # Inside the padding region, count up from zero; elsewhere keep the input.
    return torch.where(columns >= pad_start, columns - pad_start, position_ids)
def calculate_token_nums(position_ids: torch.Tensor):
    """
    Compute the length of every packed sequence in each row of a batch.

    A new sequence is assumed to start wherever a position id equals 0; the
    length of a sequence is the distance to the next 0 (or to the end of the row).

    Args:
        position_ids (torch.Tensor): 2D tensor of shape (batch_size, sequence_length),
            e.g. tensor([[0,1,2,3,4,0,1,2,3,4,5,0,1,2,3,0,0,0]]).

    Returns:
        A list with one 1D integer tensor per row holding the sequence lengths,
        e.g. [tensor([5, 6, 4, 1, 1, 1])].

    Raises:
        ValueError: if ``position_ids`` is not 2D.
    """
    if position_ids.dim() != 2:
        raise ValueError(f"输入必须是 2D Tensor,但得到了 {position_ids.dim()}D")

    lengths_per_row = []
    # Rows are ragged (different number of sequences), so iterate over the batch
    # and keep the per-row work vectorized.
    for row in position_ids:
        # Indices where a sequence starts.
        starts = torch.nonzero(row == 0).flatten()
        # The row length is the final boundary, needed for the last sequence.
        row_end = torch.tensor([row.shape[0]], device=row.device, dtype=starts.dtype)
        boundaries = torch.cat([starts, row_end])
        # Differences between consecutive boundaries are the sequence lengths.
        lengths_per_row.append(torch.diff(boundaries))
    return lengths_per_row
def forward_add_noise_packed(
    inputs_ids: torch.Tensor,
    num_tokens_list: List[torch.Tensor],
    prompt_mask: torch.Tensor,
    mask_id: int,
    eps: float = 1e-3,
    max_tries: int = 10,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Add masking noise to the token ids of a batch of packed sequences.

    Every logical sample (several of which are concatenated inside each batch
    item) draws its own independent noise rate. Tokens are randomly replaced by
    ``mask_id``; positions flagged in ``prompt_mask`` are never replaced.

    Args:
        inputs_ids (torch.Tensor):
            Input token ids, shape (bsz, total_tokens).
        num_tokens_list (List[torch.Tensor]):
            List of length bsz. Each tensor holds the lengths of the logical
            samples packed into the corresponding batch item, e.g.
            [tensor([len1, len2]), tensor([len3, len4, len5])].
        prompt_mask (torch.Tensor):
            Boolean tensor of shape (bsz, total_tokens); True marks prompt
            positions, which receive no noise.
        mask_id (int):
            Id of the mask token used as replacement.
        eps (float):
            Small value that keeps the noise rate t away from 0, so p_mask > 0.
        max_tries (int):
            Maximum number of sampling attempts per batch item to get at least
            one masked non-prompt token.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - noisy_input_ids (torch.Tensor):
                Token ids after masking, shape (bsz, total_tokens).
            - final_masked_indices (torch.Tensor):
                Boolean tensor marking the positions that were actually masked,
                shape (bsz, total_tokens).
            - p_masks (torch.Tensor):
                1D tensor with the noise rate of each masked token.
    """
    # 1. Validate the inputs and read the shapes
    bsz, total_tokens = inputs_ids.shape
    device = inputs_ids.device

    # Consistency checks on the inputs
    assert len(num_tokens_list) == bsz, f"num_tokens_list 的长度 ({len(num_tokens_list)}) 必须等于 bsz ({bsz})"
    assert prompt_mask.shape == (bsz, total_tokens), f"prompt_mask 形状不匹配, 期望 {(bsz, total_tokens)}, 得到 {prompt_mask.shape}"

    # Per-item result containers
    noisy_ids_list = []
    final_masked_indices_list = []
    p_masks_per_token_list = []

    # 2. Iterate over the batch dimension
    # This is the most direct way to handle a different packing layout per item
    for i in range(bsz):
        # Slice out the data of the current batch item
        current_ids = inputs_ids[i:i+1]  # shape: (1, total_tokens)
        current_num_tokens = num_tokens_list[i]
        current_prompt_mask = prompt_mask[i:i+1]  # shape: (1, total_tokens)

        num_samples_in_item = len(current_num_tokens)
        # The sample lengths of this item must add up to the row length
        assert total_tokens == torch.sum(current_num_tokens), \
            f"批次项 {i} 的 num_tokens 之和 ({torch.sum(current_num_tokens)}) 与 total_tokens ({total_tokens}) 不匹配"

        eligible_for_masking = ~current_prompt_mask

        # If no token may be masked, keep the input unchanged and set p_mask to eps
        if not eligible_for_masking.any():
            noisy_ids_list.append(current_ids)
            final_masked_indices_list.append(torch.zeros_like(current_prompt_mask, dtype=torch.bool))
            # p_mask_per_token must have shape (1, total_tokens) for the concatenation below
            p_masks_per_token_list.append(torch.full((1, total_tokens), eps, device=device, dtype=torch.float))
            continue

        # --- Sample a mask, retrying until at least one token is masked ---
        final_masked_indices_item = torch.zeros_like(current_prompt_mask, dtype=torch.bool)
        # NOTE(review): stays None if max_tries == 0, which would break the
        # torch.cat over p_masks_per_token_list below -- confirm max_tries >= 1.
        p_mask_per_token = None

        for _ in range(max_tries):
            # Draw an independent noise rate t for every logical sample
            t = torch.rand(num_samples_in_item, device=device)
            p_mask_per_sample = (1 - eps) * t + eps

            # Broadcast each sample's noise rate to all of its tokens
            p_mask_per_token_1d = torch.repeat_interleave(p_mask_per_sample, current_num_tokens)
            p_mask_per_token = p_mask_per_token_1d.unsqueeze(0)  # shape: (1, total_tokens)

            # Sample the random mask from the per-token noise rate
            masked_indices = torch.rand_like(p_mask_per_token) < p_mask_per_token
            # Apply the prompt mask so prompt tokens are never masked
            final_masked_indices_item = masked_indices & eligible_for_masking

            # Stop retrying as soon as at least one token is masked
            if final_masked_indices_item.any():
                break

        # If nothing was masked after max_tries (very unlikely), force-mask one eligible token
        if not final_masked_indices_item.any():
            eligible_indices = torch.nonzero(eligible_for_masking.squeeze(0), as_tuple=True)[0]
            if len(eligible_indices) > 0:
                # Pick one eligible position at random
                random_choice = torch.randint(0, len(eligible_indices), (1,)).item()
                force_mask_idx = eligible_indices[random_choice]
                final_masked_indices_item[0, force_mask_idx] = True


        # --- Build the noisy ids from the final mask ---
        noisy_ids_item = torch.where(
            final_masked_indices_item,
            mask_id,
            current_ids
        )

        # Store the results of this batch item
        noisy_ids_list.append(noisy_ids_item)
        final_masked_indices_list.append(final_masked_indices_item)
        p_masks_per_token_list.append(p_mask_per_token)

    # 3. Concatenate the per-item results into batch tensors
    noisy_input_ids = torch.cat(noisy_ids_list, dim=0)
    final_masked_indices = torch.cat(final_masked_indices_list, dim=0)
    p_mask_full = torch.cat(p_masks_per_token_list, dim=0)

    # 4. Keep only the noise rates of the positions that were masked
    p_masks = p_mask_full[final_masked_indices]

    return noisy_input_ids, final_masked_indices, p_masks
def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
    """
    Build the block-diffusion attention mask used for training.

    Indices below ``n`` address the noised copy (xt), indices from ``n`` upwards
    the clean copy (x0). The mask is the union of three parts:
      - Block Diagonal (M_BD): attention within the same block of the same copy.
      - Offset Block Causal (M_OBC): xt queries attend to strictly earlier x0 blocks.
      - Block Causal (M_BC): x0 queries attend to x0 blocks up to their own.

    Args:
        b, h: Batch and head indices (ignored for mask logic).
        q_idx, kv_idx: Query and key index tensors.
        block_size: Number of tokens per block.
        n: Length of one copy, i.e. the offset at which x0 starts.

    Returns:
        A boolean attention mask (True where attention is allowed).
    """
    # Which copy does each index belong to?
    q_in_x0 = q_idx >= n
    kv_in_x0 = kv_idx >= n

    # Block number inside the respective copy.
    q_block = torch.where(q_in_x0 == 1, (q_idx - n) // block_size, q_idx // block_size)
    kv_block = torch.where(kv_in_x0 == 1, (kv_idx - n) // block_size, kv_idx // block_size)

    same_copy_same_block = (q_block == kv_block) & (q_in_x0 == kv_in_x0)
    xt_to_earlier_x0 = (q_block > kv_block) & (kv_in_x0 == 1) & (q_in_x0 == 0)
    x0_block_causal = (q_block >= kv_block) & (kv_in_x0 == 1) & (q_in_x0 == 1)

    return same_copy_same_block | xt_to_earlier_x0 | x0_block_causal


def block_attn_mask(num_tokens, block_size, device):
    """
    Build dense block-diffusion masks for a batch of packed samples.

    Args:
        num_tokens: per batch item, an iterable with the length of each packed sample.
        block_size: number of tokens per diffusion block.
        device: device on which the index tensors are created.

    Returns:
        A stacked boolean tensor with one block-diagonal mask per batch item; each
        sample of length ``num`` contributes a ``2*num x 2*num`` sub-mask.
    """

    def _sample_mask(num):
        # Indices run over both copies (xt followed by x0), hence 2 * num.
        positions = torch.arange(num * 2, device=device)
        return block_diff_mask(
            b=None,
            h=None,
            q_idx=positions[:, None],
            kv_idx=positions[None, :],
            block_size=block_size,
            n=num,
        )

    per_item = []
    for sample_lengths in num_tokens:
        sample_masks = [_sample_mask(num) for num in sample_lengths]
        # Samples packed into the same row must not attend to each other.
        per_item.append(torch.block_diag(*sample_masks))
    return torch.stack(per_item, dim=0)
@use_kernel_forward_from_hub("RMSNorm")
class SDARRMSNorm(nn.Module):
    """RMS normalization with a learned per-channel scale (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: size of the normalized (last) dimension.
            eps: value added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # The fused flash-attn kernel is a Triton kernel; only use it for CUDA
        # tensors so the module keeps working on CPU even when flash-attn is
        # installed.
        if flash_rms_norm is not None and hidden_states.is_cuda:
            return flash_rms_norm(
                hidden_states, weight=self.weight, bias=None, eps=self.variance_epsilon
            )

        # Reference path: normalize in float32, then cast back.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class SDARMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # The Liger kernel hard-codes SiLU and runs on GPU, so it is only a valid
        # replacement when the configured activation is SiLU and x is on CUDA.
        use_liger = (
            liger_kernel_is_available
            and x.is_cuda
            and self.config.hidden_act in ("silu", "swish")
        )
        if use_liger:
            return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension.

    Same result as ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When ``n_rep == 1`` the input tensor itself is returned.
    """
    # Unpack first so a tensor that is not 4-D is rejected even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # into the head axis.
    widened = hidden_states.unsqueeze(2).expand(
        bsz, kv_heads, n_rep, seq_len, dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
def forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run self-attention over ``hidden_states``.

    Args:
        hidden_states: Input activations; the last dimension is projected by
            ``q_proj`` / ``k_proj`` / ``v_proj`` and split into heads of size
            ``self.head_dim``.
        position_embeddings: ``(cos, sin)`` pair applied to queries and keys
            through ``apply_rotary_pos_emb``.
        attention_mask: In training it is handed to ``fused_flex_attention``
            as the block mask. In eval it is converted to bool; ``None`` or
            an all-True mask selects the unmasked ``flash_attn_func`` path,
            anything else goes through scaled-dot-product attention.
        past_key_value: Optional cache. With ``store_kv=True`` in ``kwargs``
            the new keys/values are written into it; otherwise, if it already
            holds an entry for this layer, that entry is only read and
            prepended to the new keys/values.
        cache_position: Accepted for interface compatibility; not used here.
        **kwargs: Only ``store_kv`` is read in this method.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has gone through
        ``o_proj``. ``attn_weights`` is ``None`` in eval; in training it is
        the second value returned by the flex-attention call (requested with
        ``return_lse=True``), cast to the value dtype.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    # q_norm / k_norm act on the head dimension, so they are applied after
    # the view into heads and before the transpose.
    query_states = self.q_norm(
        self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    key_states = self.k_norm(
        self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(
        hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin)

    store_kv = kwargs.get("store_kv", False)
    if past_key_value is not None:
        if store_kv:
            # Commit this step's keys/values to the cache.
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx)
        elif len(past_key_value) > self.layer_idx:
            # Only retrieve; do not store the new keys/values.
            past_key_states, past_value_states = past_key_value[self.layer_idx]
            key_states = torch.cat([past_key_states, key_states], dim=-2)
            value_states = torch.cat(
                [past_value_states, value_states], dim=-2)

    if self.training:
        attn_output, attn_weights = fused_flex_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attention_mask=attention_mask,
            enable_gqa=True,
            scale=self.scaling,
            return_lse=True,
        )
        attn_weights = attn_weights.to(
            value_states.dtype) if attn_weights is not None else None
        attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
    else:
        attn_weights = None
        bool_mask = attention_mask.bool() if attention_mask is not None else None
        # A missing mask means "attend to everything", exactly like an
        # all-True mask. Checking for None first avoids torch.all(None),
        # which raises TypeError.
        if bool_mask is None or torch.all(bool_mask):  # decoding
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                causal=False,
                softmax_scale=self.scaling,
            )
            attn_output = rearrange(attn_output, 'b l h d -> b l (h d)')
        else:  # prefilling
            attn_output = F.scaled_dot_product_attention(
                query=query_states,
                key=key_states,
                value=value_states,
                attn_mask=bool_mask,
                is_causal=False,
                scale=self.scaling,
                enable_gqa=True,
            )
            attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    store_kv: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    # necessary, but kept here for BC
    position_embeddings: Optional[Tuple[torch.Tensor,
                                        torch.Tensor]] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Apply one pre-norm decoder layer: attention, then MLP, each with a residual.

    All arguments other than ``hidden_states`` are forwarded unchanged, by
    keyword, to ``self.self_attn``.

    Returns:
        ``(hidden_states,)``, or ``(hidden_states, self_attn_weights)`` when
        ``output_attentions`` is truthy.
    """
    # Attention sub-block: normalise, attend, add back onto the input.
    attn_out, self_attn_weights = self.self_attn(
        hidden_states=self.input_layernorm(hidden_states),
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        store_kv=store_kv,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
        **kwargs,
    )
    hidden_states = hidden_states + attn_out

    # Feed-forward sub-block, same pre-norm + residual pattern.
    mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
    hidden_states = hidden_states + mlp_out

    if output_attentions:
        return (hidden_states, self_attn_weights)
    return (hidden_states,)
def _init_weights(self, module):
    """Initialise a single submodule in place.

    Linear and embedding weights are drawn from a normal distribution with
    mean 0 and standard deviation ``config.initializer_range``; linear biases
    and the embedding padding row are zeroed; ``SDARRMSNorm`` weights are set
    to one. Any other module type is left untouched.
    """
    init_std = self.config.initializer_range

    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=init_std)
        if module.bias is not None:
            module.bias.data.zero_()
        return

    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=init_std)
        if module.padding_idx is not None:
            # Keep the padding embedding at exactly zero.
            module.weight.data[module.padding_idx].zero_()
        return

    if isinstance(module, SDARRMSNorm):
        module.weight.data.fill_(1.0)
dynamic rope) + @dynamic_rope_update + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand( + position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance( + x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ + position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class SDARModel(SDARPreTrainedModel): + def __init__(self, config: SDARConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [SDARDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = SDARRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = SDARRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + store_kv: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: 
Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError( + "The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length( + ) if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # causal_mask = self._update_causal_mask( + # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + # ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in 
self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + store_kv=store_kv, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, - + 1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + seq_len_q, seq_len_kv = attention_mask.shape + assert seq_len_q == seq_len_kv, f"got {attention_mask.shape=}" + attention_mask = create_block_mask( + # 2d bool tensor, shape: [2*seqlen, 2*seqlen] + lambda b, h, q_idx, kv_idx: attention_mask[q_idx, kv_idx], + B=None, H=None, Q_LEN=seq_len_q, KV_LEN=seq_len_kv, + ) + else: + # Here we pass in flex mask computed externally + assert isinstance(attention_mask, BlockMask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length( + ) if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance( + past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + 
if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + config: SDARConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. 
+ target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`SDARConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + text_config = config.get_text_config() + if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( + cache_position.reshape(-1, 1) - + text_config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, + :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if 
attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): + ... + + +@auto_docstring +class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = SDARModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_for_bd_training(self, inputs_ids, position_ids, prompt_mask): + bsz, seq_len = inputs_ids.shape + num_tokens = calculate_token_nums(position_ids) # List[torch.Tensor] + noisy_inputs_ids, logits_to_keep_half, p_mask = forward_add_noise_packed( + inputs_ids=inputs_ids, + num_tokens_list=num_tokens, + prompt_mask=prompt_mask, + mask_id=self.config.mask_token_id, + ) + router_noisy_part_list = [] + for i in range(bsz): + cur_router_noisy_part = (torch.arange(num_tokens[i].shape[0] *2) % 2 == 0).to(inputs_ids.device) + cur_router_noisy_part = 
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, SDARForCausalLM

    >>> model = SDARForCausalLM.from_pretrained("DiffuOpen/SDAR-1.7B-Chat")
    >>> tokenizer = AutoTokenizer.from_pretrained("DiffuOpen/SDAR-1.7B-Chat")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    if self.training:
        # Validate up front with real exceptions: `assert` is stripped under
        # `python -O`, and checking labels only after the forward pass wastes
        # a full model call on an input that cannot produce a loss.
        if inputs_embeds is not None:
            raise ValueError("only support input_ids during training")
        if labels is None:
            raise ValueError("Labels must be provided for training.")

        # Positions labelled -100 are treated as prompt tokens.
        prompt_mask = labels == -100
        position_ids = modify_padded_position_ids_2d(position_ids)
        (
            concat_inputs_ids,
            concat_position_ids,
            flex_attention_mask_3d,
            logits_to_keep_half,
            logits_to_keep,
            p_mask,
        ) = self.prepare_for_bd_training(input_ids, position_ids, prompt_mask)
        outputs = self.model(
            input_ids=concat_inputs_ids,
            attention_mask=flex_attention_mask_3d,
            position_ids=concat_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        # Keep only the hidden states selected by the boolean mask.
        hidden_states = outputs.last_hidden_state[logits_to_keep].contiguous()
        # Clamp so a batch with no supervised token yields a finite loss
        # instead of dividing by zero.
        answer_len = (labels != -100).sum().clamp(min=1)
        loss_fct = FusedLinearDiffusionCrossEntropyLoss(reduction='sum')
        # NOTE(review): the original comment said this returns
        # (sum_loss, unreduced_loss), yet the result is divided directly
        # below -- presumably a scalar with reduction='sum'; confirm.
        loss = loss_fct(
            # conduct `view(-1, V)` inside the function
            x=hidden_states,
            target=labels[logits_to_keep_half].contiguous(),
            weight=self.lm_head.weight,
            bias=self.lm_head.bias,
            p_mask=p_mask,
        )
        loss = loss / answer_len
        logits = None
    else:
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits. An int keeps the last
        # `logits_to_keep` positions (0 keeps all); a tensor is used as an
        # index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(
            logits_to_keep, int) else logits_to_keep
        hidden_states = hidden_states[:, slice_indices, :].contiguous()
        # This branch only runs when `self.training` is False, so the former
        # "fused cross-entropy while training" shortcut could never trigger
        # here; the logits are always materialised.
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size), labels.view(-1))

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
a/src/utils/models.py b/src/utils/models.py index d5cd66b..2f30ea8 100644 --- a/src/utils/models.py +++ b/src/utils/models.py @@ -8,25 +8,33 @@ def load_pretrained_model(cfg: DictConfig, **model_kwargs) -> PreTrainedModel: """ Load a pretrained model based on the configuration. """ - from ..models import LLaDAModelLM, DreamModel + from ..models import LLaDAModelLM, DreamModel, SDARForCausalLM model_family = cfg.model.name.split("-")[0] if model_family == "llada": return LLaDAModelLM.from_pretrained(cfg.model.path, **model_kwargs) elif model_family == "dream": return DreamModel.from_pretrained(cfg.model.path, **model_kwargs) + elif model_family == "sdar": + # Prefer the in-repo SDAR implementation (avoids `trust_remote_code`). + # Avoid overriding SDAR's attention implementation (often "flex_attention") via global config. + model_kwargs.pop("attn_implementation", None) + model_kwargs.pop("trust_remote_code", None) + return SDARForCausalLM.from_pretrained(cfg.model.path, **model_kwargs) raise ValueError(f"Unsupported pretrained model: {cfg.model.name}") def load_eval_model(cfg: DictConfig, **model_kwargs): - from ..models import LLaDAEval, DreamEval + from ..models import LLaDAEval, DreamEval, SDAREval model_family = cfg.model.name.split("-")[0] if model_family == "llada": eval_model = LLaDAEval(cfg, **model_kwargs) elif model_family == "dream": eval_model = DreamEval(cfg, **model_kwargs) + elif model_family == "sdar": + eval_model = SDAREval(cfg, **model_kwargs) else: raise NotImplementedError( f"Model family {model_family} is not implemented for evaluation." @@ -72,4 +80,8 @@ def load_tokenizer(cfg: DictConfig, **tokenizer_kwargs): tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids( tokenizer.eot_token ) + case "sdar": + # SDAR relies on mask_token_id for block diffusion; expose it on the tokenizer. + # Some tokenizers don't define a mask token, but we only need the id. 
+ tokenizer.mask_token_id = cfg.generation.mask_token_id # type: ignore[attr-defined] return tokenizer