From d7b9e42958e41f423e02aa3f1f4a724e69361225 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 04:47:53 +0200 Subject: [PATCH 01/25] update --- src/diffusers/models/attention.py | 477 +++++++++++++++++- src/diffusers/models/attention_processor.py | 248 ++------- .../models/transformers/transformer_flux.py | 457 ++++++++++++----- 3 files changed, 863 insertions(+), 319 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index ae51d3ab1349..b174cb093d58 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,23 +11,494 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn import torch.nn.functional as F -from torch import nn from ..utils import deprecate, logging +from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU -from .attention_processor import Attention, JointAttnProcessor2_0 +from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX +if is_xformers_available(): + import xformers as xops +else: + xops = None + + logger = logging.get_logger(__name__) +class AttentionMixin: + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + """ + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + for module in self.modules(): + if isinstance(module, AttentionModuleMixin): + module.fuse_projections() + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is ๐Ÿงช experimental. + + + """ + for module in self.modules(): + if isinstance(module, AttentionModuleMixin): + module.unfuse_projections() + + +class AttentionModuleMixin: + _default_processor_cls = None + _available_processors = [] + fused_projections = False + + def set_processor(self, processor: AttentionProcessor) -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + """ + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + """ + Set whether to use NPU flash attention from `torch_npu` or not. + + Args: + use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not. + """ + + if use_npu_flash_attention: + if not is_torch_npu_available(): + raise ImportError("torch_npu is not available") + + self.set_attention_backend("_native_npu") + + def set_use_xla_flash_attention( + self, + use_xla_flash_attention: bool, + partition_spec: Optional[Tuple[Optional[str], ...]] = None, + is_flux=False, + ) -> None: + """ + Set whether to use XLA flash attention from `torch_xla` or not. + + Args: + use_xla_flash_attention (`bool`): + Whether to use pallas flash attention kernel from `torch_xla` or not. + partition_spec (`Tuple[]`, *optional*): + Specify the partition specification if using SPMD. Otherwise None. + is_flux (`bool`, *optional*, defaults to `False`): + Whether the model is a Flux model. 
+ """ + if use_xla_flash_attention: + if not is_torch_xla_available(): + raise ImportError("torch_xla is not available") + + self.set_attention_backend("_native_xla") + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + """ + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + if use_memory_efficient_attention_xformers: + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + if is_xformers_available(): + dtype = None + if attention_op is not None: + op_fw, op_bw = attention_op + dtype, *_ = op_fw.SUPPORTED_DTYPES + q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) + _ = xops.memory_efficient_attention(q, q, q) + except Exception as e: + raise e + + self.set_attention_backend("xformers") + + @torch.no_grad() + def fuse_projections(self): + """ + Fuse the query, key, and value projections into a single projection for efficiency. + """ + # Skip if already fused + if getattr(self, "fused_projections", False): + return + + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if hasattr(self, "is_cross_attention") and self.is_cross_attention: + # Fuse cross-attention key-value projections + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if hasattr(self, "use_bias") and self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + else: + # Fuse self-attention projections + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if hasattr(self, "use_bias") and self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + # Handle added projections for models like SD3, Flux, etc. 
+ if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + """ + Unfuse the query, key, and value projections back to separate projections. + """ + # Skip if not fused + if not getattr(self, "fused_projections", False): + return + + # Remove fused projection layers + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + + if hasattr(self, "to_added_qkv"): + delattr(self, "to_added_qkv") + + self.fused_projections = False + + def set_attention_slice(self, slice_size: int) -> None: + """ + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + processor = None + + # Try to get a compatible processor for sliced attention + if slice_size is not None: + processor = self._get_compatible_processor("sliced") + + # If no processor was found or slice_size is None, use default processor + if processor is None: + processor = self.default_processor_cls() + + self.set_processor(processor) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + """ + Reshape the tensor for multi-head attention processing. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Compute the attention scores. 
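+
+        Scores are computed as `softmax(attention_mask + scale * query @ key.transpose(-1, -2))`;
+        inputs are upcast to float32 when `upcast_attention` / `upcast_softmax` are enabled.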
+ + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + """ + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): The attention mask to prepare. + target_length (`int`): The target length of the attention mask. + batch_size (`int`): The batch size for repeating the attention mask. + out_dim (`int`, *optional*, defaults to `3`): Output dimension. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + """ + Normalize the encoder hidden states. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). 
In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): # "feed_forward_chunk_size" can be used to save memory if hidden_states.shape[chunk_dim] % chunk_size != 0: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4760cfd40b3c..2306bdbc9dbd 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2272,102 +2272,16 @@ def __call__( return hidden_states -class FluxAttnProcessor2_0: +class FluxAttnProcessor2_0_NPU: """Attention processor used typically in processing the SD3-like self-attention projections.""" def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + deprecation_message = ( + "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." ) + deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" @@ -2470,107 +2384,16 @@ def __call__( return hidden_states -class FusedFluxAttnProcessor2_0: +class FusedFluxAttnProcessor2_0_NPU: """Attention processor used typically in processing the SD3-like self-attention projections.""" def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. 
- qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + deprecation_message = ( + "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." 
) + deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" @@ -3459,6 +3282,12 @@ class XLAFluxFlashAttnProcessor2_0: """ def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): + deprecation_message = ( + "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " + "alternative solution to using XLA Flash Attention will be provided in the future." + ) + deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." @@ -5992,17 +5821,6 @@ def __init__(self): pass -class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead." - deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message) - super().__init__() - - class SanaLinearAttnProcessor2_0: r""" Processor for implementing scaled dot-product linear attention. @@ -6167,6 +5985,40 @@ def __call__( return hidden_states +class FluxAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`" + deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxAttnProcessor + + return FluxAttnProcessor(*args, **kwargs) + + +class FluxSingleAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __new__(cls, *args, **kwargs): + deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead." + deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxAttnProcessor + + return FluxAttnProcessor(*args, **kwargs) + + +class FusedFluxAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. 
Please use `FluxAttnProcessor`" + deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxAttnProcessor + + return FluxAttnProcessor(*args, **kwargs) + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 608ea7f6e5f0..8682bcbb603e 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -13,27 +13,26 @@ # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import ( - Attention, - AttentionProcessor, - FluxAttnProcessor2_0, - FluxAttnProcessor2_0_NPU, - FusedFluxAttnProcessor2_0, -) +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..cache_utils import CacheMixin -from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + FluxPosEmbed, + apply_rotary_emb, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -42,6 +41,322 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class FluxAttnProcessor: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def _get_projections(self, attn, hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_projections = None + if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"): + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_projections = (encoder_query, encoder_key, encoder_value) + + return query, key, value, encoder_projections + + def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + encoder_projections = None + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1) + encoder_projections = (encoder_query, encoder_key, encoder_value) + + return query, key, value, encoder_projections + + def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None): + if hasattr(attn, "to_qkv") and attn.fused_projections: + return self._get_fused_projections(attn, hidden_states, encoder_hidden_states) + return self._get_projections(attn, hidden_states, encoder_hidden_states) + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + query, key, value, encoder_projections = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if encoder_projections is not None: + encoder_query, encoder_key, encoder_value = encoder_projections + encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + # Concatenate for joint attention + query = torch.cat([encoder_query, query], dim=2) + key = torch.cat([encoder_key, key], dim=2) + value = torch.cat([encoder_value, value], dim=2) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = 
hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FluxIPAdapterAttnProcessor(torch.nn.Module): + """Flux Attention processor for IP-Adapter.""" + + def __init__( + self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + self.to_v_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ip_hidden_states: Optional[List[torch.Tensor]] = None, + ip_adapter_masks: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + hidden_states_query_proj = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
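+            # The context (text-stream) projections computed here are concatenated in front of the
+            # image-stream tokens below, so after joint attention the first
+            # `encoder_hidden_states.shape[1]` tokens are split back out as the encoder branch.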
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            # IP-adapter
+            ip_query = hidden_states_query_proj
+            ip_attn_output = torch.zeros_like(hidden_states)
+
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+                ip_attn_output += scale * current_ip_hidden_states
+
+            return hidden_states, encoder_hidden_states, ip_attn_output
+        else:
+            return hidden_states
+
+
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = FluxAttnProcessor
+    _available_processors = [
+        FluxAttnProcessor,
+        FluxIPAdapterAttnProcessor,
+    ]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        qk_norm: Optional[str] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int = None,
+        context_pre_only: Optional[bool] = None,
pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + assert qk_norm == "rms_norm", "Flux uses RMSNorm" + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + @maybe_allow_in_graph class FluxSingleTransformerBlock(nn.Module): def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): @@ -54,6 +369,8 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) if is_torch_npu_available(): + from ..attention_processor import FluxAttnProcessor2_0_NPU + deprecation_message = ( "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " "should be set explicitly using the `set_attn_processor` method." 
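For orientation, a minimal sketch of driving the new `FluxAttention` module directly with the default `FluxAttnProcessor`. The dimensions mirror the standard Flux configuration (24 heads, head dim 128) and are illustrative only, and the import path assumes this branch:

    import torch
    from diffusers.models.transformers.transformer_flux import FluxAttention, FluxAttnProcessor

    # Joint (dual-stream) attention as wired up in FluxTransformerBlock: `added_kv_proj_dim`
    # enables the context projections, and the call returns (image_branch, text_branch).
    attn = FluxAttention(
        query_dim=3072,
        heads=24,
        dim_head=128,
        bias=True,
        qk_norm="rms_norm",
        added_kv_proj_dim=3072,
        context_pre_only=False,
        processor=FluxAttnProcessor(),
    )
    image_tokens = torch.randn(1, 4096, 3072)
    text_tokens = torch.randn(1, 512, 3072)
    image_out, text_out = attn(image_tokens, encoder_hidden_states=text_tokens)
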
@@ -61,11 +378,10 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, deprecate("npu_processor", "0.34.0", deprecation_message) processor = FluxAttnProcessor2_0_NPU() else: - processor = FluxAttnProcessor2_0() + processor = FluxAttnProcessor() - self.attn = Attention( + self.attn = FluxAttention( query_dim=dim, - cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, @@ -118,16 +434,15 @@ def __init__( self.norm1 = AdaLayerNormZero(dim) self.norm1_context = AdaLayerNormZero(dim) - self.attn = Attention( + self.attn = FluxAttention( query_dim=dim, - cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, - processor=FluxAttnProcessor2_0(), + processor=FluxAttnProcessor(), qk_norm=qk_norm, eps=eps, ) @@ -152,6 +467,7 @@ def forward( encoder_hidden_states, emb=temb ) joint_attention_kwargs = joint_attention_kwargs or {} + # Attention. attention_outputs = self.attn( hidden_states=norm_hidden_states, @@ -180,7 +496,6 @@ def forward( hidden_states = hidden_states + ip_attn_output # Process attention outputs for the `encoder_hidden_states`. - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output @@ -196,7 +511,13 @@ def forward( class FluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, + AttentionMixin, ): """ The Transformer model introduced in Flux. @@ -292,106 +613,6 @@ def __init__( self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. 
- - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is ๐Ÿงช experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedFluxAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is ๐Ÿงช experimental. 
- - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - def forward( self, hidden_states: torch.Tensor, From 7e97e43efcad3ac918dd1cd9ef92523136f8da17 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 04:56:55 +0200 Subject: [PATCH 02/25] update --- src/diffusers/models/embeddings.py | 41 +++++-------------- .../models/transformers/transformer_flux.py | 33 ++++++++++++++- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4f268bfa018f..262e57a3a050 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1238,37 +1238,6 @@ def apply_1d_rope(tokens, pos, cos, sin): return x -class FluxPosEmbed(nn.Module): - # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 - def __init__(self, theta: int, axes_dim: List[int]): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - - def forward(self, ids: torch.Tensor) -> torch.Tensor: - n_axes = ids.shape[-1] - cos_out = [] - sin_out = [] - pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - for i in range(n_axes): - cos, sin = get_1d_rotary_pos_embed( - self.axes_dim[i], - pos[:, i], - theta=self.theta, - repeat_interleave_real=True, - use_real=True, - freqs_dtype=freqs_dtype, - ) - cos_out.append(cos) - sin_out.append(sin) - freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) - freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) - return freqs_cos, freqs_sin - - class TimestepEmbedding(nn.Module): def __init__( self, @@ -2619,3 +2588,13 @@ def forward(self, image_embeds: List[torch.Tensor]): projected_image_embeds.append(image_embed) return projected_image_embeds + + +class FluxPosEmbed(nn.Module): + def __new__(cls, *args, **kwargs): + deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`." 
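+        # i.e. `from diffusers.models.embeddings import FluxPosEmbed` should become
+        # `from diffusers.models.transformers.transformer_flux import FluxPosEmbed`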
+ deprecate("FluxPosEmbed", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxPosEmbed + + return FluxPosEmbed(*args, **kwargs) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 8682bcbb603e..8218b2ae6e93 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -30,8 +30,8 @@ from ..embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, - FluxPosEmbed, apply_rotary_emb, + get_1d_rotary_pos_embed, ) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -510,6 +510,37 @@ def forward( return encoder_hidden_states, hidden_states +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + class FluxTransformer2DModel( ModelMixin, ConfigMixin, From ecabd2a46e8dbe6e719f3045e368c012a0f36ff4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 07:38:04 +0200 Subject: [PATCH 03/25] add coauthor Co-Authored-By: Dhruv Nair From ff21b7fe8b0e3d2c0bbd0341c8273a1f4bb62a7c Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 07:46:32 +0200 Subject: [PATCH 04/25] improve test --- tests/pipelines/flux/test_pipeline_flux.py | 11 ++++------- tests/pipelines/test_pipelines_common.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 0df0e028ff06..4541521c8941 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -28,8 +28,7 @@ FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, + check_qkv_fused_layers_exist, ) @@ -171,12 +170,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. 
Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 13c25ccaa469..387eb6a614f9 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -37,6 +37,7 @@ from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin +from diffusers.models.attention import AttentionModuleMixin from diffusers.models.attention_processor import AttnProcessor from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel @@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model): return all(p.startswith("Fused") for p in proc_names) +def check_qkv_fused_layers_exist(model, layer_names): + is_fused_submodules = [] + for submodule in model.modules(): + if not isinstance(submodule, AttentionModuleMixin): + continue + is_fused_attribute_set = submodule.fused_projections + is_fused_layer = True + for layer in layer_names: + is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None + is_fused = is_fused_attribute_set and is_fused_layer + is_fused_submodules.append(is_fused) + return all(is_fused_submodules) + + class SDFunctionTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. From b8f7fe61e1c5136cb8f88ee7ebe14c7b7c95fb13 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 08:21:47 +0200 Subject: [PATCH 05/25] handle ip adapter params correctly --- src/diffusers/loaders/ip_adapter.py | 9 +- src/diffusers/loaders/transformer_flux.py | 6 +- src/diffusers/models/attention_processor.py | 156 ++---------------- .../models/transformers/transformer_flux.py | 14 +- 4 files changed, 29 insertions(+), 156 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index e05d53687a24..dca4758ba038 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -40,8 +40,6 @@ from ..models.attention_processor import ( AttnProcessor, AttnProcessor2_0, - FluxAttnProcessor2_0, - FluxIPAdapterJointAttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, @@ -867,6 +865,9 @@ def unload_ip_adapter(self): >>> ... 
``` """ + # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level + from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor + # remove CLIP image encoder if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: self.image_encoder = None @@ -886,9 +887,9 @@ def unload_ip_adapter(self): # restore original Transformer attention processors layers attn_procs = {} for name, value in self.transformer.attn_processors.items(): - attn_processor_class = FluxAttnProcessor2_0() + attn_processor_class = FluxAttnProcessor() attn_procs[name] = ( - attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__() + attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__() ) self.transformer.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index af03d09029c1..0873e8edd08f 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -87,9 +87,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us return image_projection def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - from ..models.attention_processor import ( - FluxIPAdapterJointAttnProcessor2_0, - ) + from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor if low_cpu_mem_usage: if is_accelerate_available(): @@ -121,7 +119,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_ else: cross_attention_dim = self.config.joint_attention_dim hidden_size = self.inner_dim - attn_processor_class = FluxIPAdapterJointAttnProcessor2_0 + attn_processor_class = FluxIPAdapterAttnProcessor num_image_text_embeds = [] for state_dict in state_dicts: if "proj.weight" in state_dict["image_proj"]: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2306bdbc9dbd..e64bd45eb42d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2501,152 +2501,6 @@ def __call__( return hidden_states -class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module): - """Flux Attention processor for IP-Adapter.""" - - def __init__( - self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None - ): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
- ) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - - if not isinstance(num_tokens, (tuple, list)): - num_tokens = [num_tokens] - - if not isinstance(scale, list): - scale = [scale] * len(num_tokens) - if len(scale) != len(num_tokens): - raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") - self.scale = scale - - self.to_k_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - self.to_v_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ip_hidden_states: Optional[List[torch.Tensor]] = None, - ip_adapter_masks: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - hidden_states_query_proj = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # IP-adapter - ip_query = hidden_states_query_proj - ip_attn_output = torch.zeros_like(hidden_states) - - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) - ip_attn_output += scale * current_ip_hidden_states - - return hidden_states, encoder_hidden_states, ip_attn_output - else: - return hidden_states - - class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -6019,6 +5873,16 @@ def __new__(cls, *args, **kwargs): return FluxAttnProcessor(*args, **kwargs) +class FluxIPAdapterJointAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. 
Please use `FluxIPAdapterAttnProcessor`" + deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxIPAdapterAttnProcessor + + return FluxIPAdapterAttnProcessor(*args, **kwargs) + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 8218b2ae6e93..7898bdb1f010 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import inspect from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -241,7 +241,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = torch.nn.functional(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -354,6 +356,14 @@ def forward( image_rotary_emb: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) From 0cda91d467636458bf77beb69cfa6ab62ceab324 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 07:51:58 +0200 Subject: [PATCH 06/25] fix chroma qkv fusion test --- tests/pipelines/chroma/test_pipeline_chroma.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py index fc5749f96cd8..5121a2b52d75 100644 --- a/tests/pipelines/chroma/test_pipeline_chroma.py +++ b/tests/pipelines/chroma/test_pipeline_chroma.py @@ -7,12 +7,7 @@ from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers.utils.testing_utils import torch_device -from ..test_pipelines_common import ( - FluxIPAdapterTesterMixin, - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist class ChromaPipelineFastTests( @@ -126,12 +121,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." 
+ self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images From bc64f12c98acb5237634e4864660d72563416a06 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 08:01:42 +0200 Subject: [PATCH 07/25] fix fastercache implementation --- src/diffusers/hooks/faster_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py index 1be5e1436294..a6c250b50ca4 100644 --- a/src/diffusers/hooks/faster_cache.py +++ b/src/diffusers/hooks/faster_cache.py @@ -18,6 +18,7 @@ import torch +from ..models.attention import AttentionModuleMixin from ..models.attention_processor import Attention, MochiAttention from ..models.modeling_outputs import Transformer2DModelOutput from ..utils import logging @@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float: _apply_faster_cache_on_denoiser(module, config) for name, submodule in module.named_modules(): - if not isinstance(submodule, _ATTENTION_CLASSES): + if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): continue if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS): _apply_faster_cache_on_attention_class(name, submodule, config) From a0b276da538e456a48153d7b28027d807a103b08 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 08:30:26 +0200 Subject: [PATCH 08/25] fix more tests --- .../hooks/pyramid_attention_broadcast.py | 3 ++- .../chroma/test_pipeline_chroma_img2img.py | 15 ++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index bbdd1c3f68d4..1c8787194196 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -18,6 +18,7 @@ import torch +from ..models.attention import AttentionModuleMixin from ..models.attention_processor import Attention, MochiAttention from ..utils import logging from .hooks import HookRegistry, ModelHook @@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt config.spatial_attention_block_skip_range = 2 for name, submodule in module.named_modules(): - if not isinstance(submodule, _ATTENTION_CLASSES): + if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB # cannot be applied to this layer. For custom layers, users can extend this functionality and implement # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. 
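
The faster_cache.py and pyramid_attention_broadcast.py changes above both widen the attention-layer check to include the new AttentionModuleMixin base class. A minimal sketch of what that detection amounts to for the refactored Flux transformer (the tiny constructor arguments are purely illustrative; only the isinstance check is taken from the hunks above):

    from diffusers import FluxTransformer2DModel
    from diffusers.models.attention import AttentionModuleMixin

    # Deliberately small configuration, for illustration only.
    model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)

    attention_layers = [
        name
        for name, submodule in model.named_modules()
        if isinstance(submodule, AttentionModuleMixin)
    ]
    # Both double-stream and single-stream attention modules should match, so the
    # caching hooks can still locate them once the processor-based classes are
    # replaced by mixin-based ones.
    print(attention_layers)
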
diff --git a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py index 02b20527b2f9..d518e1b7b8d1 100644 --- a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py +++ b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py @@ -8,12 +8,7 @@ from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers.utils.testing_utils import floats_tensor, torch_device -from ..test_pipelines_common import ( - FluxIPAdapterTesterMixin, - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist class ChromaImg2ImgPipelineFastTests( @@ -129,12 +124,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images From c1415207145b1417e0839fd39862b7eeb38bdaea Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 09:49:17 +0200 Subject: [PATCH 09/25] fight more tests --- .../models/transformers/transformer_flux.py | 2 +- .../test_controlnet_flux_img2img.py | 14 ++++---------- tests/pipelines/flux/test_pipeline_flux_control.py | 14 ++++---------- .../flux/test_pipeline_flux_control_inpaint.py | 14 ++++---------- 4 files changed, 13 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 7898bdb1f010..706438569fca 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -242,7 +242,7 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb) hidden_states = torch.nn.functional.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index 8d63619c402b..ab4cf3273489 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -16,11 +16,7 @@ ) from diffusers.utils.torch_utils import randn_tensor -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ 
-170,12 +166,10 @@ def test_fused_qkv_projections(self): original_image_slice = image[0, -3:, -3:, -1] pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py index d8d0774e1e32..42283da6fd03 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control.py +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -8,11 +8,7 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel from diffusers.utils.testing_utils import torch_device -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -140,12 +136,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py index a2f7c9171082..0abd08e37300 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py @@ -15,11 +15,7 @@ torch_device, ) -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -134,12 +130,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." 
+ self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images From 4dcd6729071306251ea9f49639938c0ea9be0672 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 12:29:28 +0200 Subject: [PATCH 10/25] add back set_attention_backend --- src/diffusers/models/attention.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b174cb093d58..2d5eaaa69122 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -160,6 +160,16 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce """ if not return_deprecated_lora: return self.processor + + def set_attention_backend(self, backend: str): + from .attention_dispatch import AttentionBackendName + + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend.lower()) + self.processor._attention_backend = backend def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: """ From 576da52f45b77bb73ec7ef355011e5e8bc572f80 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 11:43:48 +0200 Subject: [PATCH 11/25] update --- src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention_dispatch.py | 1155 +++++++++++++++++ src/diffusers/models/modeling_utils.py | 50 + .../models/transformers/transformer_flux.py | 104 +- src/diffusers/utils/__init__.py | 6 + src/diffusers/utils/constants.py | 2 + src/diffusers/utils/import_utils.py | 60 + 8 files changed, 1332 insertions(+), 51 deletions(-) create mode 100644 src/diffusers/models/attention_dispatch.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bc81c24f7347..dd46e44991d3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -163,6 +163,7 @@ [ "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", + "AttentionBackendName", "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", @@ -237,6 +238,7 @@ "VQModel", "WanTransformer3DModel", "WanVACETransformer3DModel", + "attention_backend", ] ) _import_structure["modular_pipelines"].extend( @@ -809,6 +811,7 @@ from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, + AttentionBackendName, AuraFlowTransformer2DModel, AutoencoderDC, AutoencoderKL, @@ -882,6 +885,7 @@ VQModel, WanTransformer3DModel, WanVACETransformer3DModel, + attention_backend, ) from .modular_pipelines import ( ComponentsManager, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 73903a627415..f019b35b0fe2 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,6 +26,7 @@ if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] + _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"] _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = 
["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] @@ -111,6 +112,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter + from .attention_dispatch import AttentionBackendName, attention_backend from .auto_model import AutoModel from .autoencoders import ( AsymmetricAutoencoderKL, diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py new file mode 100644 index 000000000000..141a7fee858b --- /dev/null +++ b/src/diffusers/models/attention_dispatch.py @@ -0,0 +1,1155 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import inspect +import math +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch + +from ..utils import ( + get_logger, + is_flash_attn_3_available, + is_flash_attn_available, + is_flash_attn_version, + is_sageattention_available, + is_sageattention_version, + is_torch_npu_available, + is_torch_version, + is_torch_xla_available, + is_torch_xla_version, + is_xformers_available, + is_xformers_version, +) +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): + from flash_attn import flash_attn_func, flash_attn_varlen_func +else: + logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.") + flash_attn_func = None + flash_attn_varlen_func = None + + +if is_flash_attn_3_available(): + from flash_attn_interface import flash_attn_func as flash_attn_3_func + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func +else: + flash_attn_3_func = None + flash_attn_3_varlen_func = None + + +if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp16_triton, + sageattn_varlen, + ) +else: + logger.warning( + "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`." + ) + sageattn = None + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + sageattn_varlen = None + + +if is_torch_version(">=", "2.5.0"): + # We cannot import the flex_attention function from the package directly because it is expected (from the + # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the + # compiled function. 
+ import torch.nn.attention.flex_attention as flex_attention + + +if is_torch_npu_available(): + from torch_npu import npu_fusion_attention +else: + npu_fusion_attention = None + + +if is_torch_xla_available() and is_torch_xla_version(">", "2.2"): + from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention +else: + xla_flash_attention = None + + +if is_xformers_available() and is_xformers_version(">=", "0.0.29"): + import xformers.ops as xops +else: + logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.") + xops = None + + +# TODO(aryan): Add support for the following: +# - Sage Attention++ +# - block sparse, radial and other attention methods +# - CP with sage attention, flex, xformers, other missing backends +# - Add support for normal and CP training with backends that don't support it yet + + +_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] +_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] +_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] + + +class AttentionBackendName(str, Enum): + # EAGER = "eager" + + # `flash-attn` + FLASH = "flash" + FLASH_VARLEN = "flash_varlen" + _FLASH_3 = "_flash_3" + _FLASH_VARLEN_3 = "_flash_varlen_3" + + # PyTorch native + FLEX = "flex" + NATIVE = "native" + _NATIVE_CUDNN = "_native_cudnn" + _NATIVE_EFFICIENT = "_native_efficient" + _NATIVE_FLASH = "_native_flash" + _NATIVE_MATH = "_native_math" + _NATIVE_NPU = "_native_npu" + _NATIVE_XLA = "_native_xla" + + # `sageattention` + SAGE = "sage" + SAGE_VARLEN = "sage_varlen" + _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" + _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" + _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" + _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # TODO: let's not add support for Sparge Attention now because it requires tuning per model + # We can look into supporting something "autotune"-ing in the future + # SPARGE = "sparge" + + # `xformers` + XFORMERS = "xformers" + + +class _AttentionBackendRegistry: + _backends = {} + _constraints = {} + _supported_arg_names = {} + _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) + _checks_enabled = DIFFUSERS_ATTN_CHECKS + + @classmethod + def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None): + logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") + + def decorator(func): + cls._backends[backend] = func + cls._constraints[backend] = constraints or [] + cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) + return func + + return decorator + + @classmethod + def get_active_backend(cls): + return cls._active_backend, cls._backends[cls._active_backend] + + @classmethod + def list_backends(cls): + return list(cls._backends.keys()) + + +@contextlib.contextmanager +def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE): + """ + Context manager to set the active attention backend. 
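+
+    Example (a minimal, illustrative sketch; `dispatch_attention_fn` is defined in this module and the
+    torch-native backends are always registered):
+
+    ```py
+    import torch
+
+    from diffusers.models.attention_dispatch import (
+        AttentionBackendName,
+        attention_backend,
+        dispatch_attention_fn,
+    )
+
+    # (batch_size, seq_len, num_heads, head_dim)
+    query = key = value = torch.randn(1, 16, 8, 64)
+
+    with attention_backend(AttentionBackendName._NATIVE_MATH):
+        out = dispatch_attention_fn(query, key, value)
+    ```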
+ """ + if backend not in _AttentionBackendRegistry._backends: + raise ValueError(f"Backend {backend} is not registered.") + + old_backend = _AttentionBackendRegistry._active_backend + _AttentionBackendRegistry._active_backend = backend + + try: + yield + finally: + _AttentionBackendRegistry._active_backend = old_backend + + +def dispatch_attention_fn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + *, + backend: Optional[AttentionBackendName] = None, +) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + + if backend is None: + # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment + # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager + backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() + else: + backend_name = AttentionBackendName(backend) + backend_fn = _AttentionBackendRegistry._backends.get(backend_name) + + kwargs = { + "query": query, + "key": key, + "value": value, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa, + **attention_kwargs, + } + + if _AttentionBackendRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) + if removed_kwargs: + logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.") + for check in _AttentionBackendRegistry._constraints.get(backend_name): + check(**kwargs) + + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} + return backend_fn(**kwargs) + + +# ===== Checks ===== +# A list of very simple functions to catch common errors quickly when debugging. + + +def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: + if attn_mask is not None and is_causal: + raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") + + +def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.device != key.device or query.device != value.device: + raise ValueError("Query, key, and value must be on the same device.") + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError("Query, key, and value must have the same dtype.") + + +def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device(query, key, value) + if query.device.type != "cuda": + raise ValueError("Query, key, and value must be on a CUDA device.") + + +def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: + def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device_cuda(query, key, value) + if torch.cuda.get_device_capability(query.device) < (major, minor): + raise ValueError( + f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." 
+ ) + + return check_device_cuda + + +def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.dtype != key.dtype: + raise ValueError("Query and key must have the same dtype.") + if query.dtype != value.dtype: + raise ValueError("Query and value must have the same dtype.") + + +def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_qkv_dtype_match(query, key, value) + if query.dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Query, key, and value must be either bfloat16 or float16.") + + +def _check_shape( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> None: + if query.shape[-1] != key.shape[-1]: + raise ValueError("Query and key must have the same last dimension.") + if query.shape[-2] != value.shape[-2]: + raise ValueError("Query and value must have the same second to last dimension.") + if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: + raise ValueError("Attention mask must match the key's second to last dimension.") + + +# ===== Helper functions ===== + + +@functools.lru_cache(maxsize=128) +def _prepare_for_flash_attn_or_sage_varlen_without_mask( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + device: Optional[torch.device] = None, +): + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _prepare_for_flash_attn_or_sage_varlen_with_mask( + batch_size: int, + seq_len_q: int, + attn_mask: torch.Tensor, + device: Optional[torch.device] = None, +): + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _prepare_for_flash_attn_or_sage_varlen( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + attn_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, +) -> None: + if attn_mask is None: + return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device) + return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device) + + +def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: + """ + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in + FlashAttention/Sage varlen. + + Supports 1D to 4D shapes and common broadcasting patterns. 
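+
+    For example (illustrative), a boolean mask of shape `[batch_size, seq_len_q, seq_len_k]` is reduced over
+    the query dimension via `any(dim=1)`, yielding a `[batch_size, seq_len_k]` keep-mask from which the
+    per-batch key lengths used by the varlen kernels can be summed.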
+ """ + if attn_mask.dtype != torch.bool: + raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") + + if attn_mask.ndim == 1: + # [seq_len_k] -> broadcast across batch + attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 2: + # [batch_size, seq_len_k]. Maybe broadcast across batch + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 3: + # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen. + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." + ) + attn_mask = attn_mask.any(dim=1) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 4: + # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] + attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] + + else: + raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") + + if attn_mask.shape != (batch_size, seq_len_k): + raise ValueError( + f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" + ) + + return attn_mask + + +def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + +# ===== torch op registrations ===== +# Registrations are required for fullgraph tracing compatibility + + +# TODO: library.custom_op and register_fake probably need version guards? +# TODO: this is only required because the beta release FA3 does not have it. 
There is a PR adding +# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 +@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") +def _wrapped_flash_attn_3_original( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = flash_attn_3_func(query, key, value) + lse = lse.permute(0, 2, 1) + return out, lse + + +@torch.library.register_fake("flash_attn_3::_flash_attn_forward") +def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_len, num_heads, head_dim = query.shape + lse_shape = (batch_size, seq_len, num_heads) + return torch.empty_like(query), query.new_empty(lse_shape) + + +# ===== Attention backends ===== + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + out = flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH_VARLEN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = flash_attn_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + 
cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + out = out.unflatten(0, (batch_size, -1)) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_3, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + out, lse, *_ = flash_attn_3_func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + attention_chunk=0, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_VARLEN_3, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out, lse, *_ = flash_attn_3_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=None, + seqused_k=None, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + out = out.unflatten(0, (batch_size, -1)) + + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + 
AttentionBackendName.FLEX, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _native_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + # TODO: should we LRU cache the block mask creation? + score_mod = None + block_mask = None + batch_size, seq_len_q, num_heads, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): + block_mask = attn_mask + elif is_causal: + block_mask = flex_attention.create_block_mask( + _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device + ) + elif torch.is_tensor(attn_mask): + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + + attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) + + if attn_mask.dtype == torch.bool: + # TODO: this probably does not work but verify! + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return attn_mask[batch_idx, head_idx, q_idx, kv_idx] + + block_mask = flex_attention.create_block_mask( + mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + else: + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] + else: + raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = flex_attention.flex_attention( + query=query, + key=key, + value=value, + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + enable_gqa=enable_gqa, + return_lse=return_lse, + kernel_options=kernel_options, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.NATIVE, + constraints=[_check_device, _check_shape], +) +def _native_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_CUDNN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_cudnn_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + 
out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_EFFICIENT, + constraints=[_check_device, _check_shape], +) +def _native_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_FLASH, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=None, # not supported + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_MATH, + constraints=[_check_device, _check_shape], +) +def _native_math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_NPU, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_npu_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, +) -> torch.Tensor: + return npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tokens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + + +# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853 +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_XLA, + constraints=[_check_device, _check_shape], +) +def _native_xla_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, +) -> torch.Tensor: + query, key, 
value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + query = query / math.sqrt(query.shape[-1]) + out = xla_flash_attention( + q=query, + k=key, + v=value, + causal=is_causal, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE_VARLEN, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + is_causal: bool = False, + scale: Optional[float] = None, + smooth_k: bool = True, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = sageattn_varlen( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + ) + out = out.unflatten(0, (batch_size, -1)) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + 
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_sm90_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_cuda( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_triton_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_triton( + q=query, + k=key, + v=value, + tensor_layout="NHD", + quantization_backend=quantization_backend, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.XFORMERS, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _xformers_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, seq_len_q, num_heads_q, _ = query.shape + _, seq_len_kv, num_heads_kv, _ = key.shape + + if is_causal: + attn_mask = xops.LowerTriangularMask() + elif attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + elif attn_mask.ndim != 4: + raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + + if enable_gqa: + if num_heads_q % num_heads_kv != 0: + raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") + num_heads_per_group = num_heads_q // num_heads_kv + query = query.unflatten(2, (num_heads_kv, -1)) + key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, 
num_heads_per_group, -1) + value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + + out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + + if enable_gqa: + out = out.flatten(2, 3) + + return out diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d7b2136b4afc..4918fae91d8a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -606,6 +606,56 @@ def enable_group_offload( offload_to_disk_path=offload_to_disk_path, ) + def set_attention_backend(self, backend: str) -> None: + """ + Set the attention backend for the model. + + Args: + backend (`str`): + The name of the backend to set. Must be one of the available backends defined in + `AttentionBackendName`. Available backends can be found in + `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product + attention as backend. + """ + from .attention import AttentionModuleMixin + from .attention_dispatch import AttentionBackendName + + # TODO: the following will not be required when everything is refactored to AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + backend = backend.lower() + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend) + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = backend + + def reset_attention_backend(self) -> None: + """ + Resets the attention backend for the model. Following calls to `forward` will use the environment default or + the torch native scaled dot product attention. + """ + from .attention import AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = None + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 706438569fca..2d4bc172a721 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -26,6 +26,7 @@ from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, @@ -42,6 +43,8 @@ class FluxAttnProcessor: + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") @@ -51,31 +54,25 @@ def _get_projections(self, attn, hidden_states, encoder_hidden_states=None): key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - encoder_projections = None - if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"): + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) - encoder_projections = (encoder_query, encoder_key, encoder_value) - return query, key, value, encoder_projections + return query, key, value, encoder_query, encoder_key, encoder_value def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) - encoder_projections = None + encoder_query = encoder_key = encoder_value = (None,) if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1) - encoder_projections = (encoder_query, encoder_key, encoder_value) + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) - return query, key, value, encoder_projections + return query, key, value, encoder_query, encoder_key, encoder_value def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None): - if hasattr(attn, "to_qkv") and attn.fused_projections: + if attn.fused_projections: return self._get_fused_projections(attn, hidden_states, encoder_hidden_states) return self._get_projections(attn, hidden_states, encoder_hidden_states) @@ -87,53 +84,43 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - query, key, value, encoder_projections = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads + query, key, value, encoder_query, encoder_key, encoder_value = self.get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) + query = attn.norm_q(query) + key = attn.norm_k(key) - if encoder_projections is not None: - encoder_query, encoder_key, encoder_value = encoder_projections - encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.added_kv_proj_dim is not 
None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) - if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) - if attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) - # Concatenate for joint attention - query = torch.cat([encoder_query, query], dim=2) - key = torch.cat([encoder_key, key], dim=2) - value = torch.cat([encoder_value, value], dim=2) + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, backend=self._attention_backend + ) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) - hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) @@ -146,6 +133,8 @@ def __call__( class FluxIPAdapterAttnProcessor(torch.nn.Module): """Flux Attention processor for IP-Adapter.""" + _attention_backend = None + def __init__( self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None ): @@ -241,8 +230,14 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -273,8 +268,14 @@ def __call__( ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = torch.nn.functional.scaled_dot_product_attention( - ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + current_ip_hidden_states = dispatch_attention_fn( + ip_query, + ip_key, + ip_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim @@ -323,6 +324,7 @@ def __init__( self.context_pre_only = context_pre_only self.pre_only = pre_only self.heads = out_dim // dim_head if 
out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim self.added_proj_bias = added_proj_bias self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 2df05cb8eb36..cadcedb98a14 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,6 +67,9 @@ is_bitsandbytes_version, is_bs4_available, is_cosmos_guardrail_available, + is_flash_attn_3_available, + is_flash_attn_available, + is_flash_attn_version, is_flax_available, is_ftfy_available, is_gguf_available, @@ -90,6 +93,8 @@ is_peft_version, is_pytorch_retinaface_available, is_safetensors_available, + is_sageattention_available, + is_sageattention_version, is_scipy_available, is_sentencepiece_available, is_tensorboard_available, @@ -108,6 +113,7 @@ is_unidecode_available, is_wandb_available, is_xformers_available, + is_xformers_version, requires_backends, ) from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 7c04287d33ed..f8f04cc03abd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -41,6 +41,8 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DIFFUSERS_REQUEST_TIMEOUT = 60 +DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") +DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f12e9de33172..a27c2da648f4 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -220,6 +220,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity") _nltk_available, _nltk_version = _is_package_available("nltk") _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail") +_sageattention_available, _sageattention_version = _is_package_available("sageattention") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") +_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") def is_torch_available(): @@ -378,6 +381,18 @@ def is_hpu_available(): return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch")) +def is_sageattention_available(): + return _sageattention_available + + +def is_flash_attn_available(): + return _flash_attn_available + + +def is_flash_attn_3_available(): + return _flash_attn_3_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str): return compare_versions(parse(_optimum_quanto_version), operation, version) +def is_xformers_version(operation: str, version: str): + """ + Compares the current xformers version to a given reference with an operation. 
+ + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _xformers_available: + return False + return compare_versions(parse(_xformers_version), operation, version) + + +def is_sageattention_version(operation: str, version: str): + """ + Compares the current sageattention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _sageattention_available: + return False + return compare_versions(parse(_sageattention_version), operation, version) + + +def is_flash_attn_version(operation: str, version: str): + """ + Compares the current flash-attention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _flash_attn_available: + return False + return compare_versions(parse(_flash_attn_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects From e909b7355fcd7055334e245e532e79349df79a92 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 12:25:51 +0200 Subject: [PATCH 12/25] update --- src/diffusers/models/embeddings.py | 12 ++++++++++-- .../models/transformers/transformer_flux.py | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 262e57a3a050..4d3d246e4815 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1176,6 +1176,7 @@ def apply_rotary_emb( freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, + sequence_dim: int = 2, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. 
This function applies rotary embeddings @@ -1193,8 +1194,15 @@ def apply_rotary_emb( """ if use_real: cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = sin[None, None] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 2d4bc172a721..e5b45bbcd583 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -108,8 +108,8 @@ def __call__( value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, backend=self._attention_backend From 1e7217f82daa10d55b799419300bc746a9aebfc1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 12:30:34 +0200 Subject: [PATCH 13/25] make style --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2d5eaaa69122..c720b379551f 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -160,7 +160,7 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce """ if not return_deprecated_lora: return self.processor - + def set_attention_backend(self, backend: str): from .attention_dispatch import AttentionBackendName From 4f52e3499c68c5f618f70362053947625415a061 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 12:30:45 +0200 Subject: [PATCH 14/25] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 247769306b53..4ec3a3234c96 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -258,6 +258,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AttentionBackendName(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AuraFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1353,6 +1368,10 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +def attention_backend(*args, **kwargs): + requires_backends(attention_backend, ["torch"]) + + class ComponentsManager(metaclass=DummyObject): _backends = ["torch"] From d9c1683b07a46cbfcb0c845a1facca7681157910 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 13:30:52 +0200 Subject: [PATCH 15/25] make ip adapter processor compatible with attention dispatcher --- .../models/transformers/transformer_flux.py | 142 ++++++++---------- 1 file changed, 59 
insertions(+), 83 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index e5b45bbcd583..7640d8d13cdc 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -42,39 +42,42 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class FluxAttnProcessor: - _attention_backend = None +def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) - def _get_projections(self, attn, hidden_states, encoder_hidden_states=None): - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) + return query, key, value, encoder_query, encoder_key, encoder_value - encoder_query = encoder_key = encoder_value = None - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - return query, key, value, encoder_query, encoder_key, encoder_value +def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) - def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): - query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) - encoder_query = encoder_key = encoder_value = (None,) - if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + return query, key, value, encoder_query, encoder_key, encoder_value - return query, key, value, encoder_query, encoder_key, encoder_value - def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None): - if attn.fused_projections: - return self._get_fused_projections(attn, hidden_states, encoder_hidden_states) - return self._get_projections(attn, hidden_states, encoder_hidden_states) +def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class FluxAttnProcessor: + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") def __call__( self, @@ -84,7 +87,7 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - query, key, value, encoder_query, encoder_key, encoder_value = self.get_qkv_projections( + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( attn, hidden_states, encoder_hidden_states ) @@ -180,55 +183,35 @@ def __call__( ip_hidden_states: Optional[List[torch.Tensor]] = None, ip_adapter_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + batch_size = hidden_states.shape[0] - # `sample` projections. - hidden_states_query_proj = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) - hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) - if attn.norm_q is not None: - hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) - if attn.norm_k is not None: - key = attn.norm_k(key) + query = attn.norm_q(query) + key = attn.norm_k(key) + ip_query = query - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( query, @@ -239,23 +222,18 @@ def __call__( is_causal=False, backend=self._attention_backend, ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) - - # linear proj hidden_states = attn.to_out[0](hidden_states) - # dropout hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # IP-adapter - ip_query = hidden_states_query_proj ip_attn_output = torch.zeros_like(hidden_states) for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( @@ -264,10 +242,9 @@ def __call__( ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 + ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) + ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) + current_ip_hidden_states = dispatch_attention_fn( ip_query, ip_key, @@ -277,9 +254,7 @@ def __call__( is_causal=False, backend=self._attention_backend, ) - 
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) + current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) ip_attn_output += scale * current_ip_hidden_states @@ -316,6 +291,7 @@ def __init__( super().__init__() assert qk_norm == "rms_norm", "Flux uses RMSNorm" + self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.query_dim = query_dim self.use_bias = bias From a73cb396bacfd4aa36242d85a4868ba7d7e62b5c Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 13:53:27 +0200 Subject: [PATCH 16/25] refactor chroma as well --- .../models/transformers/transformer_chroma.py | 130 ++---------------- 1 file changed, 15 insertions(+), 115 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 0f6dd677ac5c..bf4df9df93d1 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -24,19 +24,13 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import ( - Attention, - AttentionProcessor, - FluxAttnProcessor2_0, - FluxAttnProcessor2_0_NPU, - FusedFluxAttnProcessor2_0, -) +from ..attention import AttentionMixin, FeedForward from ..cache_utils import CacheMixin from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm +from .transformer_flux import FluxAttention, FluxAttnProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -223,6 +217,8 @@ def __init__( self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) if is_torch_npu_available(): + from ..attention_processor import FluxAttnProcessor2_0_NPU + deprecation_message = ( "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " "should be set explicitly using the `set_attn_processor` method." 
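
With this change, `ChromaTransformer2DModel` picks up its attention-processor plumbing from the shared `AttentionMixin` rather than carrying per-model copies (those copies are removed further down in this diff), and its blocks build `FluxAttention` with the plain `FluxAttnProcessor`. A minimal usage sketch — the checkpoint id is a placeholder, the `"flash"` backend assumes flash-attn is installed, and `set_attention_backend` is the `ModelMixin` helper added earlier in this series:

import torch
from diffusers import ChromaTransformer2DModel

transformer = ChromaTransformer2DModel.from_pretrained(
    "<placeholder/chroma-repo>", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Every attention layer now reports the shared FluxAttnProcessor.
print({type(proc).__name__ for proc in transformer.attn_processors.values()})

# Both helpers are inherited via AttentionMixin / ModelMixin after this refactor.
transformer.fuse_qkv_projections()
transformer.set_attention_backend("flash")
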
@@ -230,11 +226,10 @@ def __init__( deprecate("npu_processor", "0.34.0", deprecation_message) processor = FluxAttnProcessor2_0_NPU() else: - processor = FluxAttnProcessor2_0() + processor = FluxAttnProcessor() - self.attn = Attention( + self.attn = FluxAttention( query_dim=dim, - cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, @@ -292,16 +287,15 @@ def __init__( self.norm1 = ChromaAdaLayerNormZeroPruned(dim) self.norm1_context = ChromaAdaLayerNormZeroPruned(dim) - self.attn = Attention( + self.attn = FluxAttention( query_dim=dim, - cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, - processor=FluxAttnProcessor2_0(), + processor=FluxAttnProcessor(), qk_norm=qk_norm, eps=eps, ) @@ -376,7 +370,13 @@ def forward( class ChromaTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, + AttentionMixin, ): """ The Transformer model introduced in Flux, modified for Chroma. @@ -475,106 +475,6 @@ def __init__( self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is ๐Ÿงช experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedFluxAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is ๐Ÿงช experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - def forward( self, hidden_states: torch.Tensor, From 1e6b1c51a8691c144adb50f2c9d21c53adf5d5e2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 12:39:30 +0200 Subject: [PATCH 17/25] remove rmsnorm assert --- src/diffusers/models/transformers/transformer_chroma.py | 2 -- src/diffusers/models/transformers/transformer_flux.py | 4 ---- 2 files changed, 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index bf4df9df93d1..5823ae9d3da6 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -235,7 +235,6 @@ def __init__( out_dim=dim, bias=True, processor=processor, - qk_norm="rms_norm", eps=1e-6, pre_only=True, ) @@ -296,7 +295,6 @@ def __init__( context_pre_only=False, bias=True, processor=FluxAttnProcessor(), - qk_norm=qk_norm, eps=eps, ) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 7640d8d13cdc..9080cd508de4 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -277,7 +277,6 @@ def __init__( dim_head: int = 64, dropout: float = 0.0, bias: bool = False, - qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, added_proj_bias: Optional[bool] = True, out_bias: bool = True, @@ -289,7 +288,6 @@ def __init__( processor=None, ): super().__init__() - assert qk_norm == "rms_norm", "Flux uses RMSNorm" self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads @@ -375,7 +373,6 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, out_dim=dim, bias=True, 
processor=processor, - qk_norm="rms_norm", eps=1e-6, pre_only=True, ) @@ -431,7 +428,6 @@ def __init__( context_pre_only=False, bias=True, processor=FluxAttnProcessor(), - qk_norm=qk_norm, eps=eps, ) From 251bb619250900c60273621a071bef22b1971dca Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 12:50:54 +0200 Subject: [PATCH 18/25] minify and deprecate npu/xla processors --- src/diffusers/models/attention_processor.py | 396 +++----------------- 1 file changed, 61 insertions(+), 335 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e64bd45eb42d..990245de1742 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2272,235 +2272,6 @@ def __call__( return hidden_states -class FluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - deprecation_message = ( - "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " - "alternative solution to use NPU Flash Attention will be provided in the future." - ) - deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - deprecation_message = ( - "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " - "alternative solution to use NPU Flash Attention will be provided in the future." - ) - deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. 
- qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -3130,112 +2901,6 @@ def __call__( return hidden_states -class XLAFluxFlashAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. 
- """ - - def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): - deprecation_message = ( - "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " - "alternative solution to using XLA Flash Attention will be provided in the future." - ) - deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - if is_torch_xla_version("<", "2.3"): - raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") - if is_spmd() and is_torch_xla_version("<", "2.4"): - raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") - self.partition_spec = partition_spec - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - query /= math.sqrt(head_dim) - hidden_states = flash_attention(query, key, value, causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - class MochiVaeAttnProcessor2_0: r""" Attention processor used in Mochi VAE. @@ -5883,6 +5548,67 @@ def __new__(cls, *args, **kwargs): return FluxIPAdapterAttnProcessor(*args, **kwargs) +class FluxAttnProcessor2_0_NPU: + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." + ) + deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) + + from .transformers.transformer_flux import FluxAttnProcessor + + processor = FluxAttnProcessor() + processor._attention_backend = "_native_npu" + return processor + + +class FusedFluxAttnProcessor2_0_NPU: + def __new__(self): + deprecation_message = ( + "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." + ) + deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) + + from .transformers.transformer_flux import FluxAttnProcessor + + processor = FluxAttnProcessor() + processor._attention_backend = "_fused_npu" + return processor + + +class XLAFluxFlashAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. 
+ """ + + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " + "alternative solution to using XLA Flash Attention will be provided in the future." + ) + deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + + if is_torch_xla_version("<", "2.3"): + raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") + if is_spmd() and is_torch_xla_version("<", "2.4"): + raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") + + from .transformers.transformer_flux import FluxAttnProcessor + + if len(args) > 0 or kwargs.get("partition_spec", None) is not None: + deprecation_message = ( + "partition_spec was not used in the processor implementation when it was added. Passing it " + "is a no-op and support for it will be removed." + ) + deprecate("partition_spec", "1.0.0", deprecation_message) + + processor = FluxAttnProcessor(*args, **kwargs) + processor._attention_backend = "_native_xla" + return processor + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, From 51fed50837edd36ec9a708c499f0bde79e358325 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 17:19:30 +0200 Subject: [PATCH 19/25] update --- src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/context_parallel.py | 275 +++++++ src/diffusers/models/_modeling_parallel.py | 105 +++ src/diffusers/models/attention_dispatch.py | 683 ++++++++++++++---- src/diffusers/models/modeling_utils.py | 47 ++ .../models/transformers/transformer_flux.py | 10 + 6 files changed, 972 insertions(+), 149 deletions(-) create mode 100644 src/diffusers/hooks/context_parallel.py create mode 100644 src/diffusers/models/_modeling_parallel.py diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 525a0747da8b..524a92ea9966 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -16,6 +16,7 @@ if is_torch_available(): + from .context_parallel import apply_context_parallel from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py new file mode 100644 index 000000000000..c7b88f5b8df0 --- /dev/null +++ b/src/diffusers/hooks/context_parallel.py @@ -0,0 +1,275 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
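+
+# A sketch of the intended entry point (the exact `ParallelConfig` fields and the plan
+# format are defined in `diffusers.models._modeling_parallel`; only `cp_mesh` and the
+# name-to-plan mapping used below are relied on in this module):
+#
+#     from diffusers.hooks import apply_context_parallel
+#
+#     # `plan` maps submodule names to ContextParallelModelPlan entries describing which
+#     # inputs to shard across the context-parallel mesh and which outputs to gather back.
+#     apply_context_parallel(transformer, parallel_config, plan)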
+ +import inspect +from dataclasses import dataclass +from typing import Dict, List, Type, Union + +import torch +import torch.distributed._functional_collectives as funcol + +from ..models._modeling_parallel import ( + ContextParallelInput, + ContextParallelModelPlan, + ContextParallelOutput, + ParallelConfig, +) +from ..models.attention_dispatch import _parallel_context +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook" +_CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE = "cp_input---{}" +_CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE = "cp_output---{}" + + +# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata +@dataclass +class ModuleForwardMetadata: + cached_parameter_indices: Dict[str, int] = None + _cls: Type = None + + def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): + kwargs = kwargs or {} + + if identifier in kwargs: + return kwargs[identifier], True, None + + if self.cached_parameter_indices is not None: + index = self.cached_parameter_indices.get(identifier, None) + if index is None: + raise ValueError(f"Parameter '{identifier}' not found in cached indices.") + return args[index], False, index + + if self._cls is None: + raise ValueError("Model class is not set for metadata.") + + parameters = list(inspect.signature(self._cls.forward).parameters.keys()) + parameters = parameters[1:] # skip `self` + self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)} + + if identifier not in self.cached_parameter_indices: + raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.") + + index = self.cached_parameter_indices[identifier] + + if index >= len(args): + raise ValueError(f"Expected {index} arguments but got {len(args)}.") + + return args[index], False, index + + +def apply_context_parallel( + module: torch.nn.Module, + parallel_config: ParallelConfig, + plan: Dict[str, ContextParallelModelPlan], +) -> None: + """Apply context parallel on a model.""" + logger.debug(f"Applying context parallel with CP mesh: {parallel_config.cp_mesh} and plan: {plan}") + + for module_id, cp_model_plan in plan.items(): + submodule = _get_submodule_by_name(module, module_id) + if not isinstance(submodule, list): + submodule = [submodule] + + logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules") + + for m in submodule: + if isinstance(cp_model_plan, dict): + hook = ContextParallelSplitHook(cp_model_plan, parallel_config) + hook_name = _CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE.format(module_id) + elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)): + if isinstance(cp_model_plan, ContextParallelOutput): + cp_model_plan = [cp_model_plan] + if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan): + raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}") + hook = ContextParallelGatherHook(cp_model_plan, parallel_config) + hook_name = _CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE.format(module_id) + else: + raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}") + registry = HookRegistry.check_if_exists_or_initialize(m) + registry.register_hook(hook, hook_name) + + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = 
ContextParallelModelHook(parallel_config) + registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK) + + +class ContextParallelModelHook(ModelHook): + def __init__(self, parallel_config: ParallelConfig) -> None: + super().__init__() + self.parallel_config = parallel_config + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + with _parallel_context(self.parallel_config): + return self.fn_ref.original_forward(*args, **kwargs) + + +class ContextParallelSplitHook(ModelHook): + def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None: + super().__init__() + self.metadata = metadata + self.parallel_config = parallel_config + self.module_forward_metadata = None + + def initialize_hook(self, module): + cls = unwrap_module(module).__class__ + self.module_forward_metadata = ModuleForwardMetadata(_cls=cls) + return module + + def pre_forward(self, module, *args, **kwargs): + args_list = list(args) + + for name, cpm in self.metadata.items(): + if isinstance(cpm, ContextParallelInput) and cpm.split_output: + continue + + # Maybe the parameter was passed as a keyword argument + input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs( + name, args_list, kwargs + ) + + if input_val is None: + continue + + # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard + # the output instead of input for a particular layer by setting split_output=True + if isinstance(input_val, torch.Tensor): + input_val = self._prepare_cp_input(input_val, cpm) + elif isinstance(input_val, (list, tuple)): + if len(input_val) != len(cpm): + raise ValueError( + f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}." + ) + sharded_input_val = [] + for i, x in enumerate(input_val): + if torch.is_tensor(x) and not cpm[i].split_output: + x = self._prepare_cp_input(x, cpm[i]) + sharded_input_val.append(x) + input_val = sharded_input_val + else: + raise ValueError(f"Unsupported input type: {type(input_val)}") + + if is_kwarg: + kwargs[name] = input_val + elif index is not None and index < len(args_list): + args_list[index] = input_val + else: + raise ValueError( + f"An unexpected error occurred while processing the input '{name}'. Please open an " + f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible " + f"example along with the full stack trace." 
+ ) + + return tuple(args_list), kwargs + + def post_forward(self, module, output): + is_tensor = isinstance(output, torch.Tensor) + is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output) + + if not is_tensor and not is_tensor_list: + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") + + output = [output] if is_tensor else list(output) + for index, cpm in self.metadata.items(): + if not isinstance(cpm, ContextParallelInput) or not cpm.split_output: + continue + if index >= len(output): + raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.") + current_output = output[index] + current_output = self._prepare_cp_input(current_output, cpm) + output[index] = current_output + + return output[0] if is_tensor else tuple(output) + + def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor: + if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: + raise ValueError( + f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions." + ) + return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) + + +class ContextParallelGatherHook(ModelHook): + def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None: + super().__init__() + self.metadata = metadata + self.parallel_config = parallel_config + + def post_forward(self, module, output): + is_tensor = isinstance(output, torch.Tensor) + + if is_tensor: + output = [output] + elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)): + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") + + output = list(output) + + if len(output) != len(self.metadata): + raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.") + + for i, cpm in enumerate(self.metadata): + if cpm is None: + continue + output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh) + + return output[0] if is_tensor else tuple(output) + + +class EquipartitionSharder: + @classmethod + @torch.compiler.disable + def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + assert tensor.size()[dim] % mesh.size() == 0 + return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()] + + @classmethod + @torch.compiler.disable + def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + tensor = tensor.contiguous() + tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group()) + return tensor + + +def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: + if name.count("*") > 1: + raise ValueError("Wildcard '*' can only be used once in the name") + return _find_submodule_by_name(model, name) + + +def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: + if name == "": + return model + first_atom, remaining_name = name.split(".", 1) if "." 
in name else (name, "") + if first_atom == "*": + if not isinstance(model, torch.nn.ModuleList): + raise ValueError("Wildcard '*' can only be used with ModuleList") + submodules = [] + for submodule in model: + subsubmodules = _find_submodule_by_name(submodule, remaining_name) + if not isinstance(subsubmodules, list): + subsubmodules = [subsubmodules] + submodules.extend(subsubmodules) + return submodules + else: + if hasattr(model, first_atom): + submodule = getattr(model, first_atom) + return _find_submodule_by_name(submodule, remaining_name) + else: + raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'") diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py new file mode 100644 index 000000000000..4fa7df3c47a2 --- /dev/null +++ b/src/diffusers/models/_modeling_parallel.py @@ -0,0 +1,105 @@ +# Experimental parallelism support for Diffusers. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Literal, Optional, Tuple, Union + +import torch + +from ..utils import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(aryan): add support for the following: +# - Unified Attention +# - More dispatcher attention backends +# - CFG/Data Parallel +# - Tensor Parallel + + +@dataclass +class ParallelConfig: + rank: int + world_size: int + ring_degree: int + ulysses_degree: int + device: torch.device + cp_mesh: torch.distributed.device_mesh.DeviceMesh + + # Whether to convert output and LSE to float32 for ring attention numerical stability + convert_to_fp32: bool = True + # TODO: support alltoall + rotate_method: Literal["allgather", "alltoall"] = "allgather" + + _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None + _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None + _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None + _ring_local_rank: int = None + _ulysses_local_rank: int = None + + def __post_init__(self): + if self.rotate_method != "allgather": + raise ValueError(f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}.") + if self._flattened_mesh is None: + self._flattened_mesh = self.cp_mesh._flatten() + if self._ring_mesh is None: + self._ring_mesh = self.cp_mesh["ring"] + if self._ulysses_mesh is None: + self._ulysses_mesh = self.cp_mesh["ulysses"] + if self._ring_local_rank is None: + self._ring_local_rank = self._ring_mesh.get_local_rank() + if self._ulysses_local_rank is None: + self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() + + +@dataclass(frozen=True) +class ContextParallelInput: + split_dim: int + expected_dims: Optional[int] = None + split_output: bool = False + + def __repr__(self): + return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})" + + +@dataclass(frozen=True) +class ContextParallelOutput: + 
gather_dim: int + expected_dims: Optional[int] = None + + def __repr__(self): + return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})" + + +# A dictionary where keys denote the input to be split across context parallel region, and the +# value denotes the sharding configuration. +# If the key is a string, it denotes the name of the parameter in the forward function. +# If the key is an integer, split_output must be set to True, and it denotes the index of the output +# to be split across context parallel region. +ContextParallelInputType = Dict[ + Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]] +] + +# A dictionary where keys denote the output to be gathered across context parallel region, and the +# value denotes the gathering configuration. +ContextParallelOutputType = Union[ + ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...] +] + +# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of +# the module should be split/gathered across context parallel region. +ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]] diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 141a7fee858b..483efc130345 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -17,9 +17,10 @@ import inspect import math from enum import Enum -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch +import torch.distributed._functional_collectives as funcol from ..utils import ( get_logger, @@ -38,15 +39,22 @@ from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +if TYPE_CHECKING: + from ._modeling_parallel import ParallelConfig + + logger = get_logger(__name__) # pylint: disable=invalid-name if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward else: logger.warning("`flash-attn` is not available or the version is too old. 
Please install `flash-attn>=2.6.3`.") flash_attn_func = None flash_attn_varlen_func = None + _flash_attn_forward = None + _flash_attn_backward = None if is_flash_attn_3_available(): @@ -104,6 +112,27 @@ xops = None +if torch.__version__ >= "2.4.0": + _custom_op = torch.library.custom_op + _register_fake = torch.library.register_fake +else: + + def _custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None): + def wrap(func): + return func + + return wrap if fn is None else fn + + def _register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1): + def wrap(func): + return func + + return wrap if fn is None else fn + + _custom_op = _custom_op_no_op + _register_fake = _register_fake_no_op + + # TODO(aryan): Add support for the following: # - Sage Attention++ # - block sparse, radial and other attention methods @@ -154,17 +183,25 @@ class _AttentionBackendRegistry: _backends = {} _constraints = {} _supported_arg_names = {} + _supports_context_parallel = {} _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) _checks_enabled = DIFFUSERS_ATTN_CHECKS + _parallel_config: Optional["ParallelConfig"] = None @classmethod - def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None): + def register( + cls, + backend: AttentionBackendName, + constraints: Optional[List[Callable]] = None, + supports_context_parallel: bool = False, + ): logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") def decorator(func): cls._backends[backend] = func cls._constraints[backend] = constraints or [] cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) + cls._supports_context_parallel[backend] = supports_context_parallel return func return decorator @@ -177,6 +214,17 @@ def get_active_backend(cls): def list_backends(cls): return list(cls._backends.keys()) + @classmethod + def _is_context_parallel_enabled(cls, backend: AttentionBackendName) -> bool: + if backend not in cls._supports_context_parallel: + raise ValueError(f"Backend {backend} is not registered.") + supports_context_parallel = cls._supports_context_parallel[backend] + is_degree_greater_than_1 = _AttentionBackendRegistry._parallel_config is not None and ( + _AttentionBackendRegistry._parallel_config.ring_degree > 1 + or _AttentionBackendRegistry._parallel_config.ulysses_degree > 1 + ) + return supports_context_parallel and is_degree_greater_than_1 + @contextlib.contextmanager def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE): @@ -195,6 +243,20 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV _AttentionBackendRegistry._active_backend = old_backend +@contextlib.contextmanager +def _parallel_context(parallel_config: "ParallelConfig"): + """ + Context manager to set the parallel configuration for attention backends that support it. 
+ """ + old_parallel_config = _AttentionBackendRegistry._parallel_config + _AttentionBackendRegistry._parallel_config = parallel_config + + try: + yield + finally: + _AttentionBackendRegistry._parallel_config = old_parallel_config + + def dispatch_attention_fn( query: torch.Tensor, key: torch.Tensor, @@ -218,6 +280,14 @@ def dispatch_attention_fn( backend_name = AttentionBackendName(backend) backend_fn = _AttentionBackendRegistry._backends.get(backend_name) + if ( + _AttentionBackendRegistry._parallel_config is not None + and not _AttentionBackendRegistry._is_context_parallel_enabled(backend_name) + ): + raise ValueError( + f"Backend {backend_name} does not support context parallelism, but a parallel configuration is provided." + ) + kwargs = { "query": query, "key": key, @@ -415,20 +485,398 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): # TODO: library.custom_op and register_fake probably need version guards? # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 -@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") -def _wrapped_flash_attn_3_original( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor +@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") +def _wrapped_flash_attn_3( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + qv: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + attention_chunk: int = 0, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + sm_margin: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: - out, lse = flash_attn_3_func(query, key, value) + # Hardcoded for now because pytorch does not support tuple/int type hints + window_size = (-1, -1) + out, lse, *_ = flash_attn_3_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=causal, + qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + sm_margin=sm_margin, + ) lse = lse.permute(0, 2, 1) return out, lse -@torch.library.register_fake("flash_attn_3::_flash_attn_forward") -def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size, seq_len, num_heads, head_dim = query.shape +@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward") +def _( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + qv: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + attention_chunk: int = 0, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + window_size = (-1, -1) # noqa: F841 + # A lot of the parameters here are not yet used in any way within diffusers. + # We can safely ignore for now and keep the fake op shape propagation simple. 
+ batch_size, seq_len, num_heads, head_dim = q.shape lse_shape = (batch_size, seq_len, num_heads) - return torch.empty_like(query), query.new_empty(lse_shape) + return torch.empty_like(q), q.new_empty(lse_shape) + + +# ===== Autograd functions ===== + + +class _cudnn_attention(torch.autograd.Function): + # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 + # forward declaration: + # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + # backward declaration: + # aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + enable_gqa: bool = False, + return_lse: bool = False, + ): + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.") + + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.attn_mask = attn_mask + + # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results + # if the input tensors are not contiguous. 
+ query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) + out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( + torch.ops.aten._scaled_dot_product_cudnn_attention( + query=query, + key=key, + value=value, + attn_bias=attn_mask, + compute_log_sumexp=return_lse, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + ) + ) + + ctx.max_q = max_q + ctx.max_k = max_k + ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) + + out = out.transpose(1, 2).contiguous() + if lse is not None: + lse = lse.transpose(1, 2).contiguous() + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors + grad_out = grad_out.transpose(1, 2).contiguous() + + # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341 + grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward( + grad_out, + query, + key, + value, + out, + logsumexp=lse, + philox_seed=philox_seed, + philox_offset=philox_offset, + attn_bias=ctx.attn_mask, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=ctx.max_q, + max_k=ctx.max_k, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + ) + grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value)) + + return grad_query, grad_key, grad_value, None, None, None, None, None + + +# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 +class _flash_attention_2(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + enable_gqa: bool = False, + return_lse: bool = False, + ): + if attn_mask is not None: + raise ValueError("`attn_mask` is not yet supported for flash-attn 2.") + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.") + + # Hardcoded for now + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + + if scale is None: + scale = query.shape[-1] ** (-0.5) + + # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround. 
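+ # To keep the LSE available when gradients or context parallelism need it, the code below
+ # bumps dropout_p to a negligible 1e-30, which leaves the attention output effectively
+ # unchanged.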
+ parallel_config = _AttentionBackendRegistry._parallel_config + if query.requires_grad or (parallel_config is not None and parallel_config.world_size > 1): + dropout_p = dropout_p if dropout_p > 0 else 1e-30 + + ctx.dropout_p = dropout_p + ctx.scale = scale + ctx.is_causal = is_causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + + out, lse, S_dmask, rng_state = _flash_attn_forward( + query, + key, + value, + dropout_p, + scale, + is_causal, + window_size[0], + window_size[1], + softcap, + alibi_slopes, + return_lse, + ) + + ctx.save_for_backward(query, key, value, out, lse, rng_state) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + query, key, value, out, lse, rng_state = ctx.saved_tensors + grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) + + lse_d = _flash_attn_backward( # noqa: F841 + grad_out, + query, + key, + value, + out, + lse, + grad_query, + grad_key, + grad_value, + ctx.dropout_p, + ctx.scale, + ctx.is_causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state, + ) + + # Head dimension may have been padded + grad_query = grad_query[..., : grad_out.shape[-1]] + grad_key = grad_key[..., : grad_out.shape[-1]] + grad_value = grad_value[..., : grad_out.shape[-1]] + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +# ===== Context parallel ===== + + +class TemplatedRingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + return_lse: bool, + op: torch.autograd.Function, + ): + parallel_config = _AttentionBackendRegistry._parallel_config + ring_mesh = parallel_config._ring_mesh + rank = parallel_config._ring_local_rank + world_size = parallel_config.ring_degree + + next_rank = (rank + 1) % world_size + prev_out = prev_lse = None + + kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() + kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) + kv_buffer = kv_buffer.chunk(world_size) + + for i in range(world_size): + if i > 0: + kv = kv_buffer[next_rank] + key = kv[: key.numel()].reshape_as(key) + value = kv[key.numel() :].reshape_as(value) + next_rank = (next_rank + 1) % world_size + + out, lse = op.apply(query, key, value, None, 0.0, None, False, False, True) + + if parallel_config.convert_to_fp32: + out = out.to(torch.float32) + lse = lse.to(torch.float32) + + lse = lse.unsqueeze(-1) + if prev_out is not None: + out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) + lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) + prev_out = out + prev_lse = lse + + out = out.to(query.dtype) + lse = lse.squeeze(-1) + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.") + + +class TemplatedUlyssesAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + return_lse: bool, + op: 
torch.autograd.Function, + ): + parallel_config = _AttentionBackendRegistry._parallel_config + ulysses_mesh = parallel_config._ulysses_mesh + world_size = parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + B, S_LOCAL, H, D = query.shape + H_LOCAL = H // world_size + query, key, value = ( + x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + for x in (query, key, value) + ) + query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value)) + query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) + + out = op.apply(query, key, value, None, 0.0, None, False, False, return_lse) + if return_lse: + out, lse, *_ = out + + out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = funcol.all_to_all_single(out, None, None, group=group).wait() + out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + if return_lse: + lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() + lse = funcol.all_to_all_single(lse, None, None, group=group).wait() + lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() + else: + lse = None + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.") + + +def _templated_context_parallel_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + *, + op: torch.autograd.Function, +): + if attn_mask is not None: + raise ValueError("Attention mask is not yet supported for templated attention.") + if is_causal: + raise ValueError("Causal attention is not yet supported for templated attention.") + if enable_gqa: + raise ValueError("GQA is not yet supported for templated attention.") + + parallel_config = _AttentionBackendRegistry._parallel_config + # TODO: add support for unified attention with ring/ulysses degree both being > 1 + if parallel_config.ring_degree > 1: + return TemplatedRingAttention.apply(query, key, value, return_lse, op) + elif parallel_config.ulysses_degree > 1: + return TemplatedUlyssesAttention.apply(query, key, value, return_lse, op) + else: + return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse) # ===== Attention backends ===== @@ -445,11 +893,7 @@ def _flash_attention( dropout_p: float = 0.0, scale: Optional[float] = None, is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False, - return_attn_probs: bool = False, + return_lse: bool = False, ) -> torch.Tensor: out = flash_attn_func( q=query, @@ -458,11 +902,7 @@ def _flash_attention( dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=return_attn_probs, + return_attn_probs=return_lse, ) return out @@ -475,19 +915,11 @@ def _flash_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - 
max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: Optional[float] = None, is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False, - return_attn_probs: bool = False, - attn_mask: Optional[torch.Tensor] = None, + return_lse: bool = False, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape @@ -495,16 +927,11 @@ def _flash_varlen_attention( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -527,11 +954,7 @@ def _flash_varlen_attention( dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=return_attn_probs, + return_attn_probs=return_lse, ) out = out.unflatten(0, (batch_size, -1)) @@ -548,30 +971,16 @@ def _flash_attention_3( value: torch.Tensor, scale: Optional[float] = None, is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - deterministic: bool = False, - return_attn_probs: bool = False, + return_lse: bool = False, ) -> torch.Tensor: - out, lse, *_ = flash_attn_3_func( + out, lse = _wrapped_flash_attn_3( q=query, k=key, v=value, softmax_scale=scale, causal=is_causal, - qv=None, - q_descale=None, - k_descale=None, - v_descale=None, - window_size=window_size, - attention_chunk=0, - softcap=softcap, - num_splits=1, - pack_gqa=None, - deterministic=deterministic, - sm_margin=0, ) - return (out, lse) if return_attn_probs else out + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -582,17 +991,10 @@ def _flash_varlen_attention_3( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, + attn_mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, is_causal: bool = False, - window_size: Tuple[int, int] = (-1, -1), - softcap: float = 0.0, - deterministic: bool = False, - return_attn_probs: bool = False, - attn_mask: Optional[torch.Tensor] = None, + return_lse: bool = False, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape @@ -600,16 +1002,11 @@ def _flash_varlen_attention_3( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), 
(cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -629,24 +1026,12 @@ def _flash_varlen_attention_3( cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, - seqused_q=None, - seqused_k=None, softmax_scale=scale, causal=is_causal, - qv=None, - q_descale=None, - k_descale=None, - v_descale=None, - window_size=window_size, - softcap=softcap, - num_splits=1, - pack_gqa=None, - deterministic=deterministic, - sm_margin=0, ) out = out.unflatten(0, (batch_size, -1)) - return (out, lse) if return_attn_probs else out + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -662,7 +1047,6 @@ def _native_flex_attention( scale: Optional[float] = None, enable_gqa: bool = False, return_lse: bool = False, - kernel_options: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: # TODO: should we LRU cache the block mask creation? score_mod = None @@ -707,7 +1091,6 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): scale=scale, enable_gqa=enable_gqa, return_lse=return_lse, - kernel_options=kernel_options, ) out = out.permute(0, 2, 1, 3) return out @@ -726,7 +1109,10 @@ def _native_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, ) -> torch.Tensor: + if return_lse: + raise ValueError("Native attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( query=query, @@ -745,6 +1131,7 @@ def _native_attention( @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_CUDNN, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _native_cudnn_attention( query: torch.Tensor, @@ -755,21 +1142,33 @@ def _native_cudnn_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, ) -> torch.Tensor: - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, + parallel_config = _AttentionBackendRegistry._parallel_config + + lse = None + if parallel_config is None and not return_lse: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + else: + 
out = _templated_context_parallel_attention( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op=_cudnn_attention ) - out = out.permute(0, 2, 1, 3) - return out + if return_lse: + out, lse = out + + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -785,7 +1184,10 @@ def _native_efficient_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, ) -> torch.Tensor: + if return_lse: + raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): out = torch.nn.functional.scaled_dot_product_attention( @@ -814,7 +1216,10 @@ def _native_flash_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, ) -> torch.Tensor: + if return_lse: + raise ValueError("Native flash attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): out = torch.nn.functional.scaled_dot_product_attention( @@ -844,7 +1249,10 @@ def _native_math_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, ) -> torch.Tensor: + if return_lse: + raise ValueError("Native math attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): out = torch.nn.functional.scaled_dot_product_attention( @@ -871,7 +1279,10 @@ def _native_npu_attention( value: torch.Tensor, dropout_p: float = 0.0, scale: Optional[float] = None, + return_lse: bool = False, ) -> torch.Tensor: + if return_lse: + raise ValueError("NPU attention backend does not support setting `return_lse=True`.") return npu_fusion_attention( query, key, @@ -898,7 +1309,10 @@ def _native_xla_attention( key: torch.Tensor, value: torch.Tensor, is_causal: bool = False, + return_lse: bool = False, ) -> torch.Tensor: + if return_lse: + raise ValueError("XLA attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) query = query / math.sqrt(query.shape[-1]) out = xla_flash_attention( @@ -942,31 +1356,25 @@ def _sage_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_k: Optional[int] = None, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, - smooth_k: bool = True, - attn_mask: Optional[torch.Tensor] = None, + return_lse: bool = False, ) -> torch.Tensor: + if return_lse: + raise ValueError("Sage varlen backend does not support setting `return_lse=True`.") + batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, 
seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) - else: - seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) - cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) - cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + ) key_valid, value_valid = [], [] for b in range(batch_size): @@ -988,7 +1396,6 @@ def _sage_varlen_attention( max_seqlen_k=max_seqlen_k, is_causal=is_causal, sm_scale=scale, - smooth_k=smooth_k, ) out = out.unflatten(0, (batch_size, -1)) @@ -1005,10 +1412,6 @@ def _sage_qk_int8_pv_fp8_cuda_attention( value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", - smooth_k: bool = True, - smooth_v: bool = False, return_lse: bool = False, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp8_cuda( @@ -1017,11 +1420,7 @@ def _sage_qk_int8_pv_fp8_cuda_attention( v=value, tensor_layout="NHD", is_causal=is_causal, - qk_quant_gran=qk_quant_gran, sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, - smooth_v=smooth_v, return_lse=return_lse, ) @@ -1036,9 +1435,6 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", - smooth_k: bool = True, return_lse: bool = False, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp8_cuda_sm90( @@ -1047,10 +1443,7 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( v=value, tensor_layout="NHD", is_causal=is_causal, - qk_quant_gran=qk_quant_gran, sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, return_lse=return_lse, ) @@ -1065,10 +1458,6 @@ def _sage_qk_int8_pv_fp16_cuda_attention( value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, - qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32", - smooth_k: bool = True, - smooth_v: bool = False, return_lse: bool = False, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp16_cuda( @@ -1077,11 +1466,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention( v=value, tensor_layout="NHD", is_causal=is_causal, - qk_quant_gran=qk_quant_gran, sm_scale=scale, - pv_accum_dtype=pv_accum_dtype, - smooth_k=smooth_k, - smooth_v=smooth_v, return_lse=return_lse, ) @@ -1096,8 +1481,6 @@ def _sage_qk_int8_pv_fp16_triton_attention( value: torch.Tensor, is_causal: bool = False, scale: Optional[float] = None, - quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", - smooth_k: bool = True, return_lse: bool = False, ) -> torch.Tensor: return sageattn_qk_int8_pv_fp16_triton( @@ -1105,10 +1488,8 @@ def _sage_qk_int8_pv_fp16_triton_attention( k=key, v=value, tensor_layout="NHD", - quantization_backend=quantization_backend, is_causal=is_causal, sm_scale=scale, - smooth_k=smooth_k, return_lse=return_lse, ) @@ -1126,7 +1507,11 @@ def _xformers_attention( is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, + return_lse: bool = False, ) -> torch.Tensor: + if return_lse: + raise ValueError("xformers attention backend does not support setting 
`return_lse=True`.") + batch_size, seq_len_q, num_heads_q, _ = query.shape _, seq_len_kv, num_heads_kv, _ = key.shape diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 01ebb1a91027..8dba8e45f5a7 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -271,6 +271,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _skip_layerwise_casting_patterns = None _supports_group_offloading = True _repeated_blocks = [] + _cp_plan = None def __init__(self): super().__init__() @@ -1492,6 +1493,52 @@ def compile_repeated_blocks(self, *args, **kwargs): f"Regional compilation failed because {repeated_blocks} classes are not found in the model. " ) + def parallelize(self, *, ring_degree: int = 1, ulysses_degree: int = 1, cp_plan=None): + from ..hooks.context_parallel import ParallelConfig, apply_context_parallel + + # TODO(aryan): add cp_plan type hint + logger.warning( + "`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning." + ) + + if not torch.distributed.is_initialized(): + raise RuntimeError("torch.distributed must be initialized before calling `parallelize`.") + if ring_degree < 1 or ulysses_degree < 1: + raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") + if ring_degree > 1 and ulysses_degree > 1: + raise ValueError( + "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." + ) + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + if ring_degree * ulysses_degree > world_size: + raise ValueError( + f"The product of `ring_degree` ({ring_degree}) and `ulysses_degree` ({ulysses_degree}) must not exceed the world size ({world_size})." 
+ ) + + device_type = torch._C._get_accelerator().type + device_module = torch.get_device_module(device_type) + device = torch.device(device_type, rank % device_module.device_count()) + + cp_mesh = torch.distributed.device_mesh.init_device_mesh( + device_type=device_type, + mesh_shape=(ring_degree, ulysses_degree), + mesh_dim_names=("ring", "ulysses"), + ) + parallel_config = ParallelConfig( + rank=rank, + world_size=world_size, + ring_degree=ring_degree, + ulysses_degree=ulysses_degree, + device=device, + cp_mesh=cp_mesh, + ) + cp_plan = cp_plan if cp_plan is not None else self._cp_plan + + apply_context_parallel(self, parallel_config, cp_plan) + @classmethod def _load_pretrained_model( cls, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 9080cd508de4..95f0f129d956 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -569,6 +570,15 @@ class FluxTransformer2DModel( _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } @register_to_config def __init__( From 79736265c5f54e3c6bb20d753a8074a49c2089da Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 21:24:24 +0200 Subject: [PATCH 20/25] refactor --- src/diffusers/hooks/context_parallel.py | 8 ++++---- src/diffusers/models/attention_dispatch.py | 22 ++++++++++++++++++---- src/diffusers/models/modeling_utils.py | 4 ++++ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index c7b88f5b8df0..788d030afa39 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -34,8 +34,8 @@ logger = get_logger(__name__) # pylint: disable=invalid-name _CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook" -_CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE = "cp_input---{}" -_CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE = "cp_output---{}" +_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}" +_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}" # TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata @@ -92,14 +92,14 @@ def apply_context_parallel( for m in submodule: if isinstance(cp_model_plan, dict): hook = ContextParallelSplitHook(cp_model_plan, parallel_config) - hook_name = _CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE.format(module_id) + hook_name = 
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id) elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)): if isinstance(cp_model_plan, ContextParallelOutput): cp_model_plan = [cp_model_plan] if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan): raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}") hook = ContextParallelGatherHook(cp_model_plan, parallel_config) - hook_name = _CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE.format(module_id) + hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id) else: raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}") registry = HookRegistry.check_if_exists_or_initialize(m) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 483efc130345..07a9c09b5eed 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -751,6 +751,11 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + scale: Optional[float], + is_causal: bool, + enable_gqa: bool, return_lse: bool, op: torch.autograd.Function, ): @@ -773,7 +778,7 @@ def forward( value = kv[key.numel() :].reshape_as(value) next_rank = (next_rank + 1) % world_size - out, lse = op.apply(query, key, value, None, 0.0, None, False, False, True) + out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True) if parallel_config.convert_to_fp32: out = out.to(torch.float32) @@ -806,6 +811,11 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + scale: Optional[float], + is_causal: bool, + enable_gqa: bool, return_lse: bool, op: torch.autograd.Function, ): @@ -823,7 +833,7 @@ def forward( query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value)) query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) - out = op.apply(query, key, value, None, 0.0, None, False, False, return_lse) + out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse) if return_lse: out, lse, *_ = out @@ -872,9 +882,13 @@ def _templated_context_parallel_attention( parallel_config = _AttentionBackendRegistry._parallel_config # TODO: add support for unified attention with ring/ulysses degree both being > 1 if parallel_config.ring_degree > 1: - return TemplatedRingAttention.apply(query, key, value, return_lse, op) + return TemplatedRingAttention.apply( + query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op + ) elif parallel_config.ulysses_degree > 1: - return TemplatedUlyssesAttention.apply(query, key, value, return_lse, op) + return TemplatedUlyssesAttention.apply( + query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op + ) else: return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 8dba8e45f5a7..19e2f1662bfe 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1535,6 +1535,10 @@ def parallelize(self, *, ring_degree: int = 1, ulysses_degree: int = 1, cp_plan= device=device, cp_mesh=cp_mesh, ) + if cp_plan is None and self._cp_plan is None: 
+ raise ValueError( + "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute." + ) cp_plan = cp_plan if cp_plan is not None else self._cp_plan apply_context_parallel(self, parallel_config, cp_plan) From f859fdf7ba742e66c1e4f809cfebda59cd94aa64 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 21:31:53 +0200 Subject: [PATCH 21/25] refactor; support flash attention 2 with cp --- src/diffusers/models/attention_dispatch.py | 53 ++++++++++++++-------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 07a9c09b5eed..6357bdad0f92 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -571,8 +571,8 @@ def forward( value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, - scale: Optional[float] = None, is_causal: bool = False, + scale: Optional[float] = None, enable_gqa: bool = False, return_lse: bool = False, ): @@ -653,8 +653,8 @@ def forward( value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, - scale: Optional[float] = None, is_causal: bool = False, + scale: Optional[float] = None, enable_gqa: bool = False, return_lse: bool = False, ): @@ -753,8 +753,8 @@ def forward( value: torch.Tensor, attn_mask: Optional[torch.Tensor], dropout_p: float, - scale: Optional[float], is_causal: bool, + scale: Optional[float], enable_gqa: bool, return_lse: bool, op: torch.autograd.Function, @@ -778,7 +778,7 @@ def forward( value = kv[key.numel() :].reshape_as(value) next_rank = (next_rank + 1) % world_size - out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True) + out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True) if parallel_config.convert_to_fp32: out = out.to(torch.float32) @@ -813,8 +813,8 @@ def forward( value: torch.Tensor, attn_mask: Optional[torch.Tensor], dropout_p: float, - scale: Optional[float], is_causal: bool, + scale: Optional[float], enable_gqa: bool, return_lse: bool, op: torch.autograd.Function, @@ -833,7 +833,7 @@ def forward( query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value)) query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) - out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse) + out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse) if return_lse: out, lse, *_ = out @@ -883,14 +883,14 @@ def _templated_context_parallel_attention( # TODO: add support for unified attention with ring/ulysses degree both being > 1 if parallel_config.ring_degree > 1: return TemplatedRingAttention.apply( - query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op ) elif parallel_config.ulysses_degree > 1: return TemplatedUlyssesAttention.apply( - query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op ) else: - return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse) + return op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse) # ===== Attention backends 
===== @@ -905,20 +905,33 @@ def _flash_attention( key: torch.Tensor, value: torch.Tensor, dropout_p: float = 0.0, - scale: Optional[float] = None, is_causal: bool = False, + scale: Optional[float] = None, return_lse: bool = False, ) -> torch.Tensor: - out = flash_attn_func( - q=query, - k=key, - v=value, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - return_attn_probs=return_lse, - ) - return out + parallel_config = _AttentionBackendRegistry._parallel_config + + lse = None + if parallel_config is None: + out = flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_attn_probs=return_lse, + ) + if return_lse: + out, lse, *_ = out + else: + out = _templated_context_parallel_attention( + query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2 + ) + if return_lse: + out, lse = out + + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( From e76fc948b0df0cc0357a4ab3c1cf4b2991f3377d Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 21:41:09 +0200 Subject: [PATCH 22/25] fix --- src/diffusers/models/attention_dispatch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 6357bdad0f92..8d149cb7a216 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -899,6 +899,7 @@ def _templated_context_parallel_attention( @_AttentionBackendRegistry.register( AttentionBackendName.FLASH, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _flash_attention( query: torch.Tensor, From 171152f2757d868a40bc2dd0737c279986a1d2c5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 21:54:11 +0200 Subject: [PATCH 23/25] support sage attention with cp --- src/diffusers/models/attention_dispatch.py | 94 ++++++++++++++++++---- 1 file changed, 77 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 8d149cb7a216..9323c45acbec 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -556,7 +556,7 @@ def _( # ===== Autograd functions ===== -class _cudnn_attention(torch.autograd.Function): +class _cudnn_attention_af(torch.autograd.Function): # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 # forward declaration: # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) @@ -614,7 +614,7 @@ def forward( def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, - *args: torch.Tensor, + *args, ): query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors grad_out = grad_out.transpose(1, 2).contiguous() @@ -644,7 +644,7 @@ def backward( # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 -class _flash_attention_2(torch.autograd.Function): +class _flash_attention_2_af(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, @@ -707,7 +707,7 @@ def forward( def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, - *args: torch.Tensor, + *args, ): query, key, value, out, lse, rng_state = ctx.saved_tensors grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) @@ -741,6 +741,51 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None +class _sage_attention_af(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + ): + if attn_mask is not None: + raise ValueError("`attn_mask` is not yet supported for Sage attention.") + if dropout_p > 0.0: + raise ValueError("`dropout_p` is not yet supported for Sage attention.") + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for Sage attention.") + + out = sageattn( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + lse = None + if return_lse: + out, lse, *_ = out + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + raise NotImplementedError("Backward pass is not implemented for Sage attention.") + + # ===== Context parallel ===== @@ -799,7 +844,7 @@ def forward( def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, - *args: torch.Tensor, + *args, ): raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.") @@ -854,7 +899,7 @@ def forward( def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, - *args: torch.Tensor, + *args, ): raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.") @@ -927,7 +972,7 @@ def _flash_attention( out, lse, *_ = out else: out = _templated_context_parallel_attention( - query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2 + query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2_af ) if return_lse: out, lse = out @@ -1191,7 +1236,7 @@ def _native_cudnn_attention( out = out.permute(0, 2, 1, 3) else: out = _templated_context_parallel_attention( - query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op=_cudnn_attention + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, 
return_lse, op=_cudnn_attention_af ) if return_lse: out, lse = out @@ -1356,6 +1401,7 @@ def _native_xla_attention( @_AttentionBackendRegistry.register( AttentionBackendName.SAGE, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _sage_attention( query: torch.Tensor, @@ -1365,15 +1411,29 @@ def _sage_attention( scale: Optional[float] = None, return_lse: bool = False, ) -> torch.Tensor: - return sageattn( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) + parallel_config = _AttentionBackendRegistry._parallel_config + + lse = None + if parallel_config is None: + out = sageattn( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + if return_lse: + out, lse, *_ = out + else: + out = _templated_context_parallel_attention( + query, key, value, None, 0.0, is_causal, scale, False, return_lse, op=_sage_attention_af + ) + if return_lse: + out, lse = out + + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( From 62f164d04dd6be6e0c2d1066aeb2f5f29ca5bcae Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 23:55:20 +0200 Subject: [PATCH 24/25] make torch compile compatible --- src/diffusers/hooks/context_parallel.py | 50 ++++++++++++++++------ src/diffusers/models/attention_dispatch.py | 22 +++++++--- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 788d030afa39..c3697f967dc8 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -105,19 +105,41 @@ def apply_context_parallel( registry = HookRegistry.check_if_exists_or_initialize(m) registry.register_hook(hook, hook_name) - registry = HookRegistry.check_if_exists_or_initialize(module) - hook = ContextParallelModelHook(parallel_config) - registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK) + # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward + # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks + # available in pytorch to set the parallel context before/after the forward/backward pass. + # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet. 
+ # The previous/older implementation simply did this: + # def new_forward(self, ...): + # with _parallel_context(parallel_config): + # return self.fn_ref.original_forward(*args, **kwargs) + # TODO: ask help from Pytorch team on how to improve this + @torch.compiler.disable + def forward_pre_hook(module, args): + module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config) + module._diffusers_parallel_config_setter_context.__enter__() + @torch.compiler.disable + def forward_hook(module, args, output): + if module._diffusers_parallel_config_setter_context is not None: + module._diffusers_parallel_config_setter_context.__exit__(None, None, None) + module._diffusers_parallel_config_setter_context = None -class ContextParallelModelHook(ModelHook): - def __init__(self, parallel_config: ParallelConfig) -> None: - super().__init__() - self.parallel_config = parallel_config + @torch.compiler.disable + def backward_pre_hook(module, grad_output): + module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config) + module._diffusers_parallel_config_setter_context.__enter__() - def new_forward(self, module: torch.nn.Module, *args, **kwargs): - with _parallel_context(self.parallel_config): - return self.fn_ref.original_forward(*args, **kwargs) + @torch.compiler.disable + def backward_hook(module, grad_output, grad_input): + if module._diffusers_parallel_config_setter_context is not None: + module._diffusers_parallel_config_setter_context.__exit__(None, None, None) + module._diffusers_parallel_config_setter_context = None + + module.register_forward_pre_hook(forward_pre_hook) + module.register_forward_hook(forward_hook) + module.register_full_backward_pre_hook(backward_pre_hook) + module.register_full_backward_hook(backward_hook) class ContextParallelSplitHook(ModelHook): @@ -234,13 +256,15 @@ def post_forward(self, module, output): class EquipartitionSharder: @classmethod - @torch.compiler.disable def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: assert tensor.size()[dim] % mesh.size() == 0 - return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()] + + # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank) + # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()] + + return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())] @classmethod - @torch.compiler.disable def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: tensor = tensor.contiguous() tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group()) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 9323c45acbec..28837c06b812 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -21,6 +21,7 @@ import torch import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor from ..utils import ( get_logger, @@ -245,9 +246,6 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV @contextlib.contextmanager def _parallel_context(parallel_config: "ParallelConfig"): - """ - Context manager to set the parallel configuration for attention backends that support it. 
- """ old_parallel_config = _AttentionBackendRegistry._parallel_config _AttentionBackendRegistry._parallel_config = parallel_config @@ -789,6 +787,16 @@ def backward( # ===== Context parallel ===== +# Reference: +# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827 +# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246 +# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method): +def _wait_tensor(tensor): + if isinstance(tensor, funcol.AsyncCollectiveTensor): + tensor = tensor.wait() + return tensor + + class TemplatedRingAttention(torch.autograd.Function): @staticmethod def forward( @@ -875,7 +883,9 @@ def forward( x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() for x in (query, key, value) ) - query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value)) + query, key, value = ( + _wait_tensor(funcol.all_to_all_single(x, None, None, group=group)) for x in (query, key, value) + ) query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse) @@ -883,12 +893,12 @@ def forward( out, lse, *_ = out out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() - out = funcol.all_to_all_single(out, None, None, group=group).wait() + out = _wait_tensor(funcol.all_to_all_single(out, None, None, group=group)) out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() if return_lse: lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() - lse = funcol.all_to_all_single(lse, None, None, group=group).wait() + lse = _wait_tensor(funcol.all_to_all_single(lse, None, None, group=group)) lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() else: lse = None From 26a5a5c9b0fc36cde3468e6254f6ff5bab2cb19f Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 17 Jul 2025 14:08:51 +0200 Subject: [PATCH 25/25] update --- src/diffusers/models/_modeling_parallel.py | 4 +++- src/diffusers/models/attention_dispatch.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 4fa7df3c47a2..7df474f04bcd 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -1,4 +1,6 @@ -# Experimental parallelism support for Diffusers. +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ Experimental parallelism support for Diffusers ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# Experimental changes are subject to change and APIs may break without warning. + # Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 28837c06b812..d4d01f082f38 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -21,7 +21,6 @@ import torch import torch.distributed._functional_collectives as funcol -import torch.distributed.tensor from ..utils import ( get_logger,