
🤗 flux-faster

A fork of huggingface/flux-fast that makes flux-fast even faster with cache-dit: a 3.3x speedup on NVIDIA L20 while still maintaining high precision.

pip3 install -U cache-dit # or: pip3 install git+https://github.com/vipshop/cache-dit.git

Under the cache-dit + F1B0 + no warmup + TaylorSeer configuration, inference takes only 7.42 seconds on an NVIDIA L20, a cumulative speedup of 3.36x over the 24.94-second baseline, while still maintaining high precision with a PSNR of 23.23:

FLUX.1-dev, 28 steps:

  • Baseline (BF16 + w/o torch.compile + w/o cache-dit): PSNR inf, L20 24.94s (output)
  • BF16 + cache-dit + F12B12 + warmup 8 steps: PSNR 34.23, L20 20.85s (output_cache)
  • BF16 + cache-dit + F12B12 + warmup 8 steps + compile: PSNR 34.16, L20 17.39s (output_cache_compile)

  • Baseline (BF16 + w/o torch.compile + w/o cache-dit): PSNR inf, L20 24.94s (output)
  • BF16 + compile + qkv projection + channels_last + float8 quant + inductor flags: PSNR 21.77, L20 13.26s (bf16_compile_qkv_chan_quant_flags_trn)
  • BF16 + compile + qkv projection + channels_last + float8 quant + inductor flags + cache-dit + F1B0 + no warmup + TaylorSeer: PSNR 23.23, L20 7.42s (bf16_cache_F1B0W0M0_taylorseer_compile_qkv_chan_quant_flags_trn)
More results
  • Baseline (FLUX.1-dev 28 steps): BF16 + w/o torch.compile + w/o cache-dit, L20 24.94s (output)
  • BF16 + cache-dit + F12B12 + warmup 8 steps: PSNR 34.23, L20 20.85s (output_cache)
  • BF16 + cache-dit + F12B12 + warmup 8 steps + compile: PSNR 34.16, L20 17.39s (output_cache_compile)

  • BF16 + compile: PSNR 19.28, L20 20.24s (bf16_compile)
  • BF16 + compile + qkv projection + channels_last + float8 quant + inductor flags: PSNR 18.07, L20 13.29s (bf16_compile_qkv_chan_quant_flags)
  • BF16 + compile + qkv projection + channels_last + float8 quant + inductor flags + cache-dit + F12B12 + warmup 8 steps: PSNR 22.24, L20 11.21s (bf16_cache_compile_qkv_chan_quant_flags)

  • BF16 + compile transformer block only: PSNR 39.72, L20 20.49s (bf16_compile_trn)
  • BF16 + compile transformer block only + qkv projection + channels_last + float8 quant + inductor flags: PSNR 21.77, L20 13.26s (bf16_compile_qkv_chan_quant_flags_trn)
  • BF16 + compile transformer block only + qkv projection + channels_last + float8 quant + inductor flags + cache-dit + F12B12 + warmup 8 steps: PSNR 21.89, L20 11.14s (bf16_cache_compile_qkv_chan_quant_flags_trn)

  • BF16 + compile transformer block only + qkv projection + channels_last + float8 quant + inductor flags + cache-dit + F8B0 + no warmup: PSNR 21.82, L20 8.98s (bf16_cache_F8B0W0M0_compile_qkv_chan_quant_flags_trn)
  • BF16 + compile transformer block only + qkv projection + channels_last + float8 quant + inductor flags + cache-dit + F1B0 + no warmup: PSNR 20.93, L20 7.41s (bf16_cache_F1B0W0M0_compile_qkv_chan_quant_flags_trn)
  • BF16 + compile transformer block only + qkv projection + channels_last + float8 quant + inductor flags + cache-dit + F1B0 + no warmup + TaylorSeer: PSNR 23.23, L20 7.42s (bf16_cache_F1B0W0M0_taylorseer_compile_qkv_chan_quant_flags_trn)

Important Notes

  1. Please add the --enable_cache_dit flag to use cache-dit. cache-dit does not currently work with torch.export: it extends Flux with dynamic Python operations, so the model may not be exportable via torch.export.
  2. Compiling the entire transformer appears to introduce precision loss in my tests on an NVIDIA L20 device (tested with PyTorch 2.7.1). Add the --only_compile_transformer_blocks flag to compile only the transformer blocks if you want to keep higher precision.

Experiments

Please run the experiments_cache.sh script to reproduce the results. For example:

# bfloat16 + only compile transformer blocks + qkv projection + channels_last + float8 quant + inductor flags 
# + cache: F1B0 + no warmup steps + no limit cached steps + TaylorSeer
python run_benchmark.py \
    --ckpt black-forest-labs/FLUX.1-dev \
    --trace-file bf16_cache_F1B0W0M0_taylorseer_compile_qkv_chan_quant_flags_trn.json.gz \
    --compile_export_mode compile \
    --only_compile_transformer_blocks \
    --disable_fa3 \
    --num_inference_steps 28 \
    --enable_cache_dit \
    --Fn 1 --Bn 0 \
    --warmup_steps 0 \
    --max_cached_steps -1 \
    --enable_taylorseer \
    --output-file bf16_cache_F1B0W0M0_taylorseer_compile_qkv_chan_quant_flags_trn.png \
    > bf16_cache_F1B0W0M0_taylorseer_compile_qkv_chan_quant_flags_trn.txt 2>&1

flux-fast

Making Flux go brrr on GPUs. With simple recipes from this repo, we enabled a ~2.5x speedup on Flux.1-Schnell and Flux.1-Dev using (mainly) pure PyTorch code and a beefy GPU like an H100. This repo is NOT meant to be a library or an out-of-the-box solution. So, please fork the repo, hack into the code, and share your results 🤗

Check out the accompanying blog post here.

Updates

July 1, 2025: This repository now supports AMD MI300X GPUs using AITER kernels (PR). The README has been updated to provide instructions on how to run on AMD GPUs.

June 28, 2025: This repository now supports Flux.1 Kontext Dev. We enabled ~2.5x speedup on it. Check out this section for more details.

Results

  • Flux.1-Schnell: new_flux_schnell_plot
  • Flux.1-Dev: flux_dev_result_plot

Summary of the optimizations:

  • Running with the bfloat16 precision
  • torch.compile
  • Combining q,k,v projections for attention computation
  • torch.channels_last memory format for the decoder output
  • Flash Attention v3 (FA3) with (unscaled) conversion of inputs to torch.float8_e4m3fn
  • Dynamic float8 quantization and quantization of Linear layer weights via torchao's float8_dynamic_activation_float8_weight
  • Inductor flags:
    • conv_1x1_as_mm = True
    • epilogue_fusion = False
    • coordinate_descent_tuning = True
    • coordinate_descent_check_all_directions = True
  • torch.export + Ahead-of-time Inductor (AOTI) + CUDAGraphs
  • cache acceleration with cache-dit: DBCache

All of the above optimizations are lossless (outside of minor numerical differences sometimes introduced through the use of torch.compile / torch.export) EXCEPT FOR dynamic float8 quantization. Disable quantization if you want the same quality results as the baseline while still being quite a bit faster.
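
For example, to keep the lossless setup in the benchmark script, you could drop only the quantization step via its toggle flag (a hypothetical invocation; the output file name is illustrative, and the full flag list is documented below):

# all optimizations except dynamic float8 quantization
python run_benchmark.py --disable_quant --output-file output_no_quant.png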

Here are some example outputs with Flux.1-Schnell for prompt "A cat playing with a ball of yarn":

  • Baseline: baseline_output
  • Fully-optimized (with quantization): fast_output

Setup

We rely primarily on pure PyTorch for the optimizations. Currently, a relatively recent nightly version of PyTorch is required.

The numbers reported here were gathered using:

For NVIDIA:

  • torch==2.8.0.dev20250605+cu126 - note that we rely on some fixes since 2.7
  • torchao==0.12.0.dev20250610+cu126 - note that we rely on a fix in the 06/10 nightly
  • diffusers - with this fix included
  • flash_attn_3==3.0.0b1

For AMD:

  • torch==2.8.0.dev20250605+rocm6.4 - note that we rely on some fixes since 2.7
  • torchao==0.12.0.dev20250610+rocm6.4 - note that we rely on a fix in the 06/10 nightly
  • diffusers - with this fix included
  • aiter-0.1.4.dev17+gd0384d4

To install deps on NVIDIA:

pip install -U huggingface_hub[hf_xet] accelerate transformers
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --pre torchao==0.12.0.dev20250610+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126

(For NVIDIA) To install flash attention v3, follow the instructions in https://github.com/Dao-AILab/flash-attention#flashattention-3-beta-release.

To install deps on AMD:

pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install --pre torchao==0.12.0.dev20250610+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install git+https://github.com/ROCm/aiter

(For AMD) Instead of flash attention v3, we use AITER, which provides the required fp8 MHA kernels.

For hardware, we used a 96GB 700W H100 GPU and a 192GB MI300X GPU. Some of the optimizations applied (BFloat16, torch.compile, combining q,k,v projections, dynamic float8 quantization) are available on CPU as well.
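
If you want to experiment on CPU, the benchmark script exposes a --device flag; a hypothetical CPU run (skipping the GPU-only FA3 path and the export + AOTI flow) might look like:

# CPU sketch: disable FA3 (GPU-only) and fall back to plain torch.compile
python run_benchmark.py --device cpu --disable_fa3 --compile_export_mode compile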

Run the optimized pipeline

On NVIDIA:

python gen_image.py --prompt "An astronaut standing next to a giant lemon" --output-file output.png --use-cached-model

This will include all optimizations and will attempt to use pre-cached binary models generated via torch.export + AOTI. To generate these binaries for subsequent runs, run the above command without the --use-cached-model flag.
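
That is, the first run, which exports and serializes the binaries for later reuse, simply drops that flag:

# first run: export + AOTI-compile the models and cache the binaries
python gen_image.py --prompt "An astronaut standing next to a giant lemon" --output-file output.png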

Important

The binaries won't work for hardware that is sufficiently different from the hardware they were obtained on. For example, if the binaries were obtained on an H100, they won't work on A100. Further, the binaries are currently Linux-only and include dependencies on specific versions of system libs such as libstdc++; they will not work if they were generated in a sufficiently different environment than the one present at runtime. The PyTorch Compiler team is working on solutions for more portable binaries / artifact caching.

On AMD:

python gen_image.py --prompt "A cat playing with a ball of yarn" --output-file output.png --compile_export_mode compile

Currently, torch.export is the only piece that does not work as expected on AMD. Instead, use torch.compile as shown in the above command.

Benchmarking

run_benchmark.py is the main script for benchmarking the different optimization techniques. Usage:

usage: run_benchmark.py [-h] [--ckpt CKPT] [--prompt PROMPT] [--image IMAGE] [--cache-dir CACHE_DIR]
                        [--use-cached-model] [--device {cuda,cpu}] [--num_inference_steps NUM_INFERENCE_STEPS] 
                        [--output-file OUTPUT_FILE] [--seed SEED] [--trace-file TRACE_FILE] [--disable_bf16]
                        [--compile_export_mode {compile,export_aoti,disabled}] 
                        [--only_compile_transformer_blocks] [--disable_fused_projections] 
                        [--disable_channels_last] [--disable_fa3] [--disable_quant]
                        [--disable_inductor_tuning_flags] [--enable_cache_dit] 
                        [--Fn_compute_blocks FN_COMPUTE_BLOCKS] 
                        [--Bn_compute_blocks BN_COMPUTE_BLOCKS] 
                        [--warmup_steps WARMUP_STEPS]
                        [--max_cached_steps MAX_CACHED_STEPS] 
                        [--residual_diff_threshold RESIDUAL_DIFF_THRESHOLD] 
                        [--enable_taylorseer]

options:
  -h, --help            show this help message and exit
  --ckpt {black-forest-labs/FLUX.1-schnell,black-forest-labs/FLUX.1-dev,black-forest-labs/FLUX.1-Kontext-dev}
                        Model checkpoint path (default: black-forest-labs/FLUX.1-schnell)
  --prompt PROMPT       Text prompt (default: A cat playing with a ball of yarn)
  --image IMAGE         Image to use for Kontext (default: None)
  --cache-dir CACHE_DIR
                        Cache directory for storing exported models (default: /root/.cache/flux-fast)
  --use-cached-model    Attempt to use cached model only (don't re-export) (default: False)
  --device {cuda,cpu}   Device to use (default: cuda)
  --num_inference_steps NUM_INFERENCE_STEPS
                        Number of denoising steps (default: 4)
  --output-file OUTPUT_FILE
                        Output image file path (default: output.png)
  --seed SEED           Random seed to use (default: 42)
  --trace-file TRACE_FILE
                        Output PyTorch Profiler trace file path (default: None)
  --disable_bf16        Disables usage of torch.bfloat16 (default: False)
  --compile_export_mode {compile,export_aoti,disabled}
                        Configures how torch.compile or torch.export + AOTI are used (default: export_aoti)
  --only_compile_transformer_blocks
                        Only compile Transformer Blocks for higher precision (default: False)
  --disable_fused_projections
                        Disables fused q,k,v projections (default: False)
  --disable_channels_last
                        Disables usage of torch.channels_last memory format (default: False)
  --disable_fa3         Disables use of Flash Attention V3 (default: False)
  --disable_quant       Disables usage of dynamic float8 quantization (default: False)
  --disable_inductor_tuning_flags
                        Disables use of inductor tuning flags (default: False)
  --enable_cache_dit    Enables use of cache-dit: DBCache (default: False)
  --Fn_compute_blocks FN_COMPUTE_BLOCKS, --Fn FN_COMPUTE_BLOCKS
                        Fn compute blocks of cache-dit: DBCache (default: 1)
  --Bn_compute_blocks BN_COMPUTE_BLOCKS, --Bn BN_COMPUTE_BLOCKS
                        Bn compute blocks of cache-dit: DBCache (default: 0)
  --warmup_steps WARMUP_STEPS
                        Warmup steps of cache-dit: DBCache (default: 0)
  --max_cached_steps MAX_CACHED_STEPS
                        Max Cached steps of cache-dit: DBCache (default: -1)
  --residual_diff_threshold RESIDUAL_DIFF_THRESHOLD
                        Residual diff threshold of cache-dit: DBCache (default: 0.12)
  --enable_taylorseer   Enables use of cache-dit: DBCache with TaylorSeer (default: False)

Note that all optimizations are on by default and each can be individually toggled. Example run:

# Run with all optimizations and output a trace file alongside benchmark numbers
python run_benchmark.py --trace-file profiler_trace.json.gz

After an experiment has been run, you should expect to see mean / variance times in seconds for 10 benchmarking runs printed to STDOUT, as well as:

  • A .png image file corresponding to the experiment (e.g. output.png). The path can be configured via --output-file.
  • An optional PyTorch profiler trace (e.g. profiler_trace.json.gz). The path can be configured via --trace-file

Important

For benchmarking purposes, we use reasonable defaults. For example, for all the benchmarking experiments, we use the 1024x1024 resolution. For Schnell, we use 4 denoising steps, and for Dev and Kontext, we use 28.
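
For instance, a hypothetical Dev run at those defaults would only override the checkpoint and step count (the output file name here is illustrative):

# FLUX.1-dev at 28 denoising steps with all optimizations enabled
python run_benchmark.py \
    --ckpt black-forest-labs/FLUX.1-dev \
    --num_inference_steps 28 \
    --output-file flux_dev_output.png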

Flux.1 Kontext Dev

We ran the exact same setup as above on Flux.1 Kontext Dev and obtained the following result:

flux_kontext_plot

Here are some example outputs for prompt "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" and this image:

  • Baseline: baseline_output
  • Fully-optimized (with quantization): fast_output

Notes
  • You need to install diffusers with this fix included
  • You need to install torchao with this fix included
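
A hypothetical Kontext benchmark invocation (the input image path and output file name are illustrative) passes the edit prompt together with --image:

# FLUX.1-Kontext-dev: prompt-guided image editing, 28 denoising steps
python run_benchmark.py \
    --ckpt black-forest-labs/FLUX.1-Kontext-dev \
    --image pikachu.png \
    --prompt "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" \
    --num_inference_steps 28 \
    --output-file kontext_output.png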

Improvements, progressively

Baseline

For completeness, we demonstrate a (terrible) baseline here using the default torch.float32 dtype. There's no practical reason to do this over loading in torch.bfloat16, and the results are slow enough that they ruin the readability of the graph above when included (~7.5 sec).

from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell"
).to("cuda")

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
BFloat16
import torch
from diffusers import FluxPipeline

# Load the pipeline in bfloat16 and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
torch.compile
import torch
from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell"
).to("cuda")

# Compile the compute-intensive portions of the model: denoising transformer / decoder
# "max-autotune" mode tunes kernel hyperparameters and applies CUDAGraphs
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="max-autotune", fullgraph=True
)
pipeline.vae.decode = torch.compile(
    pipeline.vae.decode, mode="max-autotune", fullgraph=True
)

# warmup for a few iterations; trigger compilation
for _ in range(3):
    pipeline(
        "dummy prompt to trigger torch compilation",
        output_type="pil",
        num_inference_steps=4,
    ).images[0]

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
Combining attention projection matrices
import torch
from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell"
).to("cuda")

# Use channels_last memory format
pipeline.vae = pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]

Note that torch.compile is able to perform this fusion automatically, so we do not observe a speedup from the fusion (outside of noise) when torch.compile is enabled.

channels_last memory format
import torch
from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell"
).to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
Flash Attention V3 / aiter

Flash Attention V3 is substantially faster on H100s than the previous iteration FA2, due in large part to float8 support. As this kernel isn't quite available yet within PyTorch Core, we implement a custom attention processor FlashFusedFluxAttnProcessor3_0 that uses the flash_attn_interface python bindings directly. We also ensure proper PyTorch custom op integration so that the op integrates well with torch.compile / torch.export. Inputs are converted to float8 in an unscaled fashion before kernel invocation and outputs are converted back to the original dtype on the way out.

On AMD GPUs, we use aiter instead, which also provides fp8 MHA kernels.

import torch
from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell"
).to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# Use FA3; reference FlashFusedFluxAttnProcessor3_0 impl for details
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
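
The conversion itself is conceptually simple. Below is a minimal, hypothetical sketch (not the repo's FlashFusedFluxAttnProcessor3_0 implementation) of the unscaled-fp8 idea, registered as a PyTorch custom op so torch.compile / torch.export can treat it as an opaque call; it assumes the FA3 bindings expose flash_attn_interface.flash_attn_func and that q/k/v arrive as [batch, seqlen, heads, head_dim] bf16 tensors.

import torch
from flash_attn_interface import flash_attn_func  # FA3 python bindings (assumed available)

# Register as a custom op so torch.compile / torch.export see an opaque, traceable call.
@torch.library.custom_op("flux_sketch::fa3_fp8_attention", mutates_args=())
def fa3_fp8_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    orig_dtype = q.dtype
    # Unscaled conversion: plain casts to float8_e4m3fn, no per-tensor scaling factors.
    q8, k8, v8 = (t.to(torch.float8_e4m3fn) for t in (q, k, v))
    out = flash_attn_func(q8, k8, v8)
    if isinstance(out, tuple):  # some FA3 binding versions also return softmax statistics
        out = out[0]
    # Convert back to the original dtype on the way out.
    return out.to(orig_dtype)

# Fake (meta) implementation so the op can be traced without running the kernel.
@fa3_fp8_attention.register_fake
def _(q, k, v):
    return torch.empty_like(q)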
float8 quantization
import torch
from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell"
).to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# Use FA3; reference FlashFusedFluxAttnProcessor3_0 impl for details
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# Apply float8 quantization on weights and activations
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

quantize_(
    pipeline.transformer,
    float8_dynamic_activation_float8_weight(),
)

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
Inductor tuning flags
import torch
from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell"
).to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# Use FA3; reference FlashFusedFluxAttnProcessor3_0 impl for details
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# Apply float8 quantization on weights and activations
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

quantize_(
    pipeline.transformer,
    float8_dynamic_activation_float8_weight(),
)

# Tune Inductor flags
config = torch._inductor.config
config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matrix muls
# adjust autotuning algorithm
config.coordinate_descent_tuning = True
config.coordinate_descent_check_all_directions = True
config.epilogue_fusion = False  # do not fuse pointwise ops into matmuls

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
torch.export + Ahead-Of-Time Inductor (AOTI)

To avoid initial compilation times, we can use torch.export + Ahead-Of-Time Inductor (AOTI). This will serialize a binary, precompiled form of the model without initial compilation overhead.

import os
import pathlib

import torch

# Apply torch.export + AOTI. If serialize=True, writes out the exported models within the cache_dir.
# Otherwise, attempts to load previously-exported models from the cache_dir.
# This function also applies CUDAGraphs on the loaded models.
def use_export_aoti(pipeline, cache_dir, serialize=False):
    from torch._inductor.package import load_package

    # create cache dir if needed
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

    def _example_tensor(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # === Transformer export ===
    # torch.export requires a representative set of example args to be passed in
    transformer_kwargs = {
        "hidden_states": _example_tensor(1, 4096, 64),
        "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "guidance": None,
        "pooled_projections": _example_tensor(1, 768),
        "encoder_hidden_states": _example_tensor(1, 512, 4096),
        "txt_ids": _example_tensor(512, 3),
        "img_ids": _example_tensor(4096, 3),
        "joint_attention_kwargs": {},
        "return_dict": False,
    }

    # Possibly serialize model out
    transformer_package_path = os.path.join(cache_dir, "exported_transformer.pt2")
    if serialize:
        # Apply export
        exported_transformer: torch.export.ExportedProgram = torch.export.export(
            pipeline.transformer, args=(), kwargs=transformer_kwargs
        )

        # Apply AOTI
        path = torch._inductor.aoti_compile_and_package(
            exported_transformer,
            package_path=transformer_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )

    loaded_transformer = load_package(
        transformer_package_path, run_single_threaded=True
    )

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_transformer(**transformer_kwargs)

    # Apply CUDAGraphs. CUDAGraphs are utilized in torch.compile with mode="max-autotune", but
    # they must be manually applied for torch.export + AOTI.
    loaded_transformer = cudagraph(loaded_transformer)
    pipeline.transformer.forward = loaded_transformer

    # warmup after cudagraphing
    with torch.no_grad():
        pipeline.transformer(**transformer_kwargs)

    # hack to get around export's limitations
    pipeline.vae.forward = pipeline.vae.decode

    vae_decode_kwargs = {
        "return_dict": False,
    }

    # Possibly serialize model out
    decoder_package_path = os.path.join(cache_dir, "exported_decoder.pt2")
    if serialize:
        # Apply export
        exported_decoder: torch.export.ExportedProgram = torch.export.export(
            pipeline.vae, args=(_example_tensor(1, 16, 128, 128),), kwargs=vae_decode_kwargs
        )

        # Apply AOTI
        path = torch._inductor.aoti_compile_and_package(
            exported_decoder,
            package_path=decoder_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )

    loaded_decoder = load_package(decoder_package_path, run_single_threaded=True)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_decoder(_example_tensor(1, 16, 128, 128), **vae_decode_kwargs)

    loaded_decoder = cudagraph(loaded_decoder)
    pipeline.vae.decode = loaded_decoder

    # warmup for a few iterations
    for _ in range(3):
        pipeline(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]

    return pipeline

Note that, unlike for torch.compile, running a model loaded from the torch.export + AOTI workflow doesn't use CUDAGraphs by default. This was found to result in a ~5% performance decrease vs. torch.compile. To address this discrepancy, we manually record / replay CUDAGraphs over the exported models using the following helper:

# wrapper to automatically handle CUDAGraph record / replay over the given function
def cudagraph(f):
    from torch.utils._pytree import tree_map_only

    _graphs = {}
    def f_(*args, **kwargs):
        key = hash(tuple(tuple(kwargs[a].shape) for a in sorted(kwargs.keys())
                         if isinstance(kwargs[a], torch.Tensor)))
        if key in _graphs:
            # use the cached wrapper if one exists. this will perform CUDAGraph replay
            wrapped, *_ = _graphs[key]
            return wrapped(*args, **kwargs)

        # record a new CUDAGraph and cache it for future use
        g = torch.cuda.CUDAGraph()
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        f(*in_args, **in_kwargs) # stream warmup
        with torch.cuda.graph(g):
            out_tensors = f(*in_args, **in_kwargs)
        def wrapped(*args, **kwargs):
            # note that CUDAGraphs require inputs / outputs to be in fixed memory locations.
            # inputs must be copied into the fixed input memory locations.
            [a.copy_(b) for a, b in zip(in_args, args) if isinstance(a, torch.Tensor)]
            for key in kwargs:
                if isinstance(kwargs[key], torch.Tensor):
                    in_kwargs[key].copy_(kwargs[key])
            g.replay()
            # clone() outputs on the way out to disconnect them from the fixed output memory
            # locations. this allows for CUDAGraph reuse without accidentally overwriting memory
            return [o.clone() for o in out_tensors]

        # cache function that does CUDAGraph replay
        _graphs[key] = (wrapped, g, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)
    return f_

Finally, here is the fully-optimized form of the model:

import torch
from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell"
).to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# Use FA3; reference FlashFusedFluxAttnProcessor3_0 impl for details
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# Apply float8 quantization on weights and activations
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

quantize_(
    pipeline.transformer,
    float8_dynamic_activation_float8_weight(),
)

# Tune Inductor flags
config = torch._inductor.config
config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matrix muls
# adjust autotuning algorithm
config.coordinate_descent_tuning = True
config.coordinate_descent_check_all_directions = True
config.epilogue_fusion = False  # do not fuse pointwise ops into matmuls

# Apply torch.export + AOTI with CUDAGraphs
# (args.cache_dir is the parsed --cache-dir CLI argument)
pipeline = use_export_aoti(pipeline, cache_dir=args.cache_dir, serialize=False)

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
cache acceleration with cache-dit: DBCache

You can use cache-dit to further speed up the FLUX model; different configurations of compute blocks (F12B12, etc.) can be customized in cache-dit: DBCache. Please check cache-dit for more details. For example:

# Install: pip install -U cache-dit
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# cache-dit: DBCache configs
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 0,
    "max_cached_steps": -1,  # -1 means no limit
    "Fn_compute_blocks": 1,  # Fn, F1, F12, etc.
    "Bn_compute_blocks": 0,  # Bn, B0, B12, etc.
    "residual_diff_threshold": 0.12,
    # TaylorSeer options
    "enable_taylorseer": True,
    "enable_encoder_taylorseer": True,
    # TaylorSeer cache type can be hidden_states or residual
    "taylorseer_cache_type": "residual",
    "taylorseer_kwargs": {
         "n_derivatives": 2,
    },
}

apply_cache_on_pipe(pipeline, **cache_options)

By the way, cache-dit is designed to work compatibly with torch.compile. You can easily combine cache-dit with torch.compile to achieve even better performance. For example:

apply_cache_on_pipe(pipeline, **cache_options)

# The cache-dit relies heavily on dynamic Python operations to maintain the cache_context, 
# so it is necessary to introduce graph breaks at appropriate positions to be compatible 
# with torch.compile. Thus, we compile the transformer with `max-autotune-no-cudagraphs` 
# mode if cache-dit is enabled. Otherwise, we compile with `max-autotune` mode.
pipeline.transformer = torch.compile(
    pipeline.transformer, 
    mode="max-autotune-no-cudagraphs", 
    fullgraph=False, 
)

With cache-dit + F1B0 + no warmup + TaylorSeer on top of these optimizations, inference takes only 7.42 seconds on an NVIDIA L20, a cumulative speedup of 3.36x over the 24.94-second baseline, while still maintaining high precision with a PSNR of 23.23; see the results tables at the top of this README.
