From 6c15a1005ece7f39bfcf0c65c4e7b133300b627b Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 21 Jul 2025 13:37:20 +0200
Subject: [PATCH 1/3] update

---
 src/diffusers/models/attention_dispatch.py | 98 ++++++++++++++++++----
 src/diffusers/models/modeling_utils.py     |  4 +
 2 files changed, 85 insertions(+), 17 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 141a7fee858b..3ee320d09148 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -38,18 +38,29 @@
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
-
-if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
+_REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_SAGE_VERSION = "2.1.1"
+_REQUIRED_FLEX_VERSION = "2.5.0"
+_REQUIRED_XLA_VERSION = "2.2"
+_REQUIRED_XFORMERS_VERSION = "0.0.29"
+
+_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
+_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
+_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
+_CAN_USE_NPU_ATTN = is_torch_npu_available()
+_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
+_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+
+
+if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 else:
-    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
     flash_attn_func = None
     flash_attn_varlen_func = None
 
 
-if is_flash_attn_3_available():
+if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
@@ -57,7 +68,7 @@
     flash_attn_3_varlen_func = None
 
 
-if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
+if _CAN_USE_SAGE_ATTN:
     from sageattention import (
         sageattn,
         sageattn_qk_int8_pv_fp8_cuda,
@@ -67,9 +78,6 @@
         sageattn_varlen,
     )
 else:
-    logger.warning(
-        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
-    )
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
     sageattn_qk_int8_pv_fp16_triton = None
@@ -78,39 +86,39 @@
     sageattn_varlen = None
 
 
-if is_torch_version(">=", "2.5.0"):
+if _CAN_USE_FLEX_ATTN:
     # We cannot import the flex_attention function from the package directly because it is expected (from the
     # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
     # compiled function.
     import torch.nn.attention.flex_attention as flex_attention
 
 
-if is_torch_npu_available():
+if _CAN_USE_NPU_ATTN:
     from torch_npu import npu_fusion_attention
 else:
     npu_fusion_attention = None
 
 
-if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
+if _CAN_USE_XLA_ATTN:
     from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
 else:
     xla_flash_attention = None
 
 
-if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
+if _CAN_USE_XFORMERS_ATTN:
     import xformers.ops as xops
 else:
-    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
     xops = None
 
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 
-
 _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
@@ -171,6 +179,7 @@ def decorator(func):
 
     @classmethod
     def get_active_backend(cls):
+        _check_backend_requirements(cls._active_backend)
         return cls._active_backend, cls._backends[cls._active_backend]
 
     @classmethod
@@ -226,9 +235,10 @@ def dispatch_attention_fn(
         "dropout_p": dropout_p,
         "is_causal": is_causal,
         "scale": scale,
-        "enable_gqa": enable_gqa,
         **attention_kwargs,
     }
+    if is_torch_version(">=", "2.5.0"):
+        kwargs["enable_gqa"] = enable_gqa
 
     if _AttentionBackendRegistry._checks_enabled:
         removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
@@ -305,6 +315,60 @@ def _check_shape(
 # ===== Helper functions =====
 
 
+# LRU cache is hack to avoid checking the backend requirements multiple times. Maybe not needed
+# because CPU is running much farther ahead of the accelerator and this will not be blocking anyway.
+@functools.lru_cache(maxsize=16)
+def _check_backend_requirements(backend: AttentionBackendName) -> None:
+    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+        if not _CAN_USE_FLASH_ATTN:
+            raise RuntimeError(
+                f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
+            )
+
+    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+        if not _CAN_USE_FLASH_ATTN_3:
+            raise RuntimeError(
+                f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
+            )
+
+    elif backend in [
+        AttentionBackendName.SAGE,
+        AttentionBackendName.SAGE_VARLEN,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+    ]:
+        if not _CAN_USE_SAGE_ATTN:
+            raise RuntimeError(
+                f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
+            )
+
+    elif backend == AttentionBackendName.FLEX:
+        if not _CAN_USE_FLEX_ATTN:
+            raise RuntimeError(
+                f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
+            )
+
+    elif backend == AttentionBackendName._NATIVE_NPU:
+        if not _CAN_USE_NPU_ATTN:
+            raise RuntimeError(
+                f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
+            )
+
+    elif backend == AttentionBackendName._NATIVE_XLA:
+        if not _CAN_USE_XLA_ATTN:
+            raise RuntimeError(
+                f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
+            )
+
+    elif backend == AttentionBackendName.XFORMERS:
+        if not _CAN_USE_XFORMERS_ATTN:
+            raise RuntimeError(
+                f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
+            )
+
+
 @functools.lru_cache(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index fb01e7e01a1e..d486a34c45c3 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -627,6 +627,8 @@ def set_attention_backend(self, backend: str) -> None:
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
 
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
         backend = backend.lower()
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
@@ -651,6 +653,8 @@ def reset_attention_backend(self) -> None:
         from .attention import AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
 
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):

From e9fd0ca25c934a049f08d0f6a1bb95648e096e66 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 21 Jul 2025 13:49:38 +0200
Subject: [PATCH 2/3] update

---
 src/diffusers/models/attention_dispatch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 3ee320d09148..c677105b76e3 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -179,7 +179,6 @@ def decorator(func):
 
     @classmethod
     def get_active_backend(cls):
-        _check_backend_requirements(cls._active_backend)
         return cls._active_backend, cls._backends[cls._active_backend]
 
     @classmethod
@@ -227,6 +226,8 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
 
+    _check_backend_requirements(backend_name)
+
     kwargs = {
         "query": query,
         "key": key,

From 23e7548a4315f61c7ad882d32018f336ba177c81 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 21 Jul 2025 22:07:39 +0200
Subject: [PATCH 3/3] update

---
 src/diffusers/models/attention_dispatch.py | 12 +++++-------
 src/diffusers/models/modeling_utils.py     |  6 +++---
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index c677105b76e3..c00ec7dd6e41 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -187,13 +187,16 @@ def list_backends(cls):
 
 
 @contextlib.contextmanager
-def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
     Context manager to set the active attention backend.
     """
     if backend not in _AttentionBackendRegistry._backends:
         raise ValueError(f"Backend {backend} is not registered.")
 
+    backend = AttentionBackendName(backend)
+    _check_attention_backend_requirements(backend)
+
     old_backend = _AttentionBackendRegistry._active_backend
     _AttentionBackendRegistry._active_backend = backend
 
@@ -226,8 +229,6 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
 
-    _check_backend_requirements(backend_name)
-
     kwargs = {
         "query": query,
         "key": key,
@@ -316,10 +317,7 @@ def _check_shape(
 # ===== Helper functions =====
 
 
-# LRU cache is hack to avoid checking the backend requirements multiple times. Maybe not needed
-# because CPU is running much farther ahead of the accelerator and this will not be blocking anyway.
-@functools.lru_cache(maxsize=16)
-def _check_backend_requirements(backend: AttentionBackendName) -> None:
+def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
     if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
         if not _CAN_USE_FLASH_ATTN:
             raise RuntimeError(
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index d486a34c45c3..4941b6d2a7b5 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -622,7 +622,7 @@ def set_attention_backend(self, backend: str) -> None:
             attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName
+        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
 
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
@@ -633,10 +633,10 @@ def set_attention_backend(self, backend: str) -> None:
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
-
         backend = AttentionBackendName(backend)
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        _check_attention_backend_requirements(backend)
 
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue