Conversation

@ParagEkbote (Contributor) commented Sep 11, 2025

What does this PR do?

As discussed in the issue, this PR adds support for the kernels-community/flash-attn kernel. Could you please review?

Fixes #12308 (Support flash-attn kernel support for non-Hopper GPUs)

Who can review?

@sayakpaul

@sayakpaul (Member) commented
Thanks for this PR. Could you update it with some code examples and results?

@ParagEkbote (Contributor, Author) commented Sep 11, 2025

This is the test script, but I'm unable to generate images with it.

import os
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

# Debug: Verify the env var is set
print(f"DIFFUSERS_ENABLE_HUB_KERNELS = {os.environ.get('DIFFUSERS_ENABLE_HUB_KERNELS')}")

import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Debug: Check if diffusers sees the env var
from diffusers.models.attention_dispatch import DIFFUSERS_ENABLE_HUB_KERNELS
print(f"Diffusers sees DIFFUSERS_ENABLE_HUB_KERNELS = {DIFFUSERS_ENABLE_HUB_KERNELS}")

# Load the pipeline with 4-bit quantization
model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer"],
    ),
).to("cuda")

pipe.transformer.set_attention_backend("_flash_hub")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")
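
As a sanity check, this lists the attention backend names diffusers registers, to confirm "_flash_hub" shows up (a small sketch; I'm assuming AttentionBackendName lives in attention_dispatch, next to the flag imported above):

from diffusers.models.attention_dispatch import AttentionBackendName

# Print every registered backend name; "_flash_hub" should appear here
# once the hub kernel backend from this PR is wired in.
print([backend.value for backend in AttentionBackendName])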

@ParagEkbote (Contributor, Author) commented
I'm running into issues with some of the parameters; here is the traceback:

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/diffusers/main.py", line 34, in <module>
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 944, in __call__
    noise_pred = self.transformer(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 720, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 443, in forward
    attention_outputs = self.attn(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 342, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 116, in __call__
    hidden_states = dispatch_attention_fn(
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/attention_dispatch.py", line 304, in dispatch_attention_fn
    return backend_fn(**kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/attention_dispatch.py", line 765, in _flash_attention_hub
    out = flash_attn_func_hub(
TypeError: flash_attn_func() got an unexpected keyword argument 'alibi_slopes'

The same error occurs with the dropout_p parameter as well. WDYT?
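
One possible direction (just a sketch; _prune_unsupported_kwargs is a hypothetical helper, not an existing diffusers function) would be to filter the kwargs against the resolved kernel's signature before dispatching, so optional arguments like alibi_slopes and dropout_p are only forwarded when the kernel declares them:

import inspect

def _prune_unsupported_kwargs(fn, kwargs):
    # Hypothetical helper: keep only the keyword arguments that fn's
    # signature actually declares, dropping the rest.
    supported = set(inspect.signature(fn).parameters)
    return {k: v for k, v in kwargs.items() if k in supported}

# Stand-in for the hub kernel's signature (no alibi_slopes / dropout_p):
def fake_flash_attn_func(q, k, v, softmax_scale=None, causal=False):
    return q  # placeholder return for the demo

kwargs = {"q": 1, "k": 2, "v": 3, "alibi_slopes": None, "dropout_p": 0.0}
out = fake_flash_attn_func(**_prune_unsupported_kwargs(fake_flash_attn_func, kwargs))

One caveat: inspect.signature may not resolve for compiled extension functions, so the real fix might need an explicit allowlist per backend instead.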

cc: @sayakpaul
