
[WIP] Layerwise dynamic upcasting to Diffusers Models #9177

Closed
wants to merge 9 commits

Conversation

@DN6 (Collaborator) commented Aug 14, 2024

What does this PR do?

Proposal to add the option to enable layerwise dynamic upcasting in order to save memory when running large models such as Flux.

This PR

  1. Adds an enable_layerwise_upcasting method to the ModelMixin class. This method recursively adds pre-forward and forward hooks to each module in a model so that we can load/keep the model in a low-memory dtype on GPU, e.g. torch.float8_e4m3fn, and upcast it only right before a module's forward method is called (a minimal sketch of this hook pattern follows this list). I believe Comfy does something similar.

  2. Introduces a class attribute _always_upcast_modules. In certain cases, modules will use/reference the dtype of a top-level module. This means that we need to apply dynamic upcasting to the entire top-level module and not just the layers inside it. This attribute lets us know which modules to apply this to.
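
As a rough illustration of the hook pattern in (1), the idea looks roughly like the snippet below. This is a sketch, not the PR's actual implementation, and the helper names are hypothetical:

import torch
import torch.nn as nn

def enable_layerwise_upcasting(model: nn.Module, upcast_dtype: torch.dtype = torch.float16):
    # Weights stay in the low-memory storage dtype (e.g. torch.float8_e4m3fn) and are
    # only upcast for the duration of each leaf module's forward pass.
    storage_dtype = next(model.parameters()).dtype

    def upcast_hook(module, args):
        module.to(dtype=upcast_dtype)

    def downcast_hook(module, args, output):
        module.to(dtype=storage_dtype)

    for module in model.modules():
        if len(list(module.children())) == 0:  # leaf modules only
            module.register_forward_pre_hook(upcast_hook)
            module.register_forward_hook(downcast_hook)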

This approach brings considerable memory savings. Here is the VRAM usage for running Flux on a T4 with model offloading:

Without layerwise upcasting:
Max Memory GB:  22.743045330047607
----------------------------------------
With layerwise upcasting:
Max Memory GB:  11.697221279144287
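
For reference, the numbers above could be gathered with a script along these lines. This is only a sketch; the exact checkpoint, generation settings, and offloading setup used for the T4 run are assumptions:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-schnell"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float8_e4m3fn)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)

pipe.transformer.enable_layerwise_upcasting(torch.bfloat16)  # method proposed in this PR
pipe.enable_model_cpu_offload()

torch.cuda.reset_peak_memory_stats()
image = pipe("a photo of a forest at dawn", num_inference_steps=4, guidance_scale=0.0).images[0]
print("Max Memory GB: ", torch.cuda.max_memory_allocated() / 1024**3)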

TODOs/Considerations:

  1. Not fully set on the name for this method. Technically it's applied to leaf nodes in the compute graph (ones that have no children). Open to suggestions.
  2. It might not be as relevant if dynamic upcasting can be handled through the quantization approach proposed here (via bitsandbytes): [Quantization] bring quantization to diffusers core #9174. I would prefer that quantization-related stuff be handled through a library that is built for it. The advantage of this approach is that we don't need to add any additional dependencies.
  3. In order to get it to work properly, we do have to change a bit how dtypes are assigned inside modules. This isn't a huge issue IMO, but it is something to keep in mind every time we add new models.



@sayakpaul (Member):

It might not be as relevant if dynamic upcasting can be handled through the quantization approach proposed here (via bitsandbytes): #9174. I would prefer that quantization related stuff be handled through a library that is built for it. The advantage with this approach is that we don't need to add any additional dependencies.

@DN6, are you suggesting NOT adding any integrations like those in #9174? If so, could you please elaborate on your reasoning? It would be in a similar spirit to how it's done in transformers and also how we integrated peft.

@sayakpaul (Member) left a review:

Thanks! This looks very clean and well thought out.

I left some comments but most of them are quite minor in nature.

We should add a doc about it in the memory optimization section: https://huggingface.co/docs/diffusers/optimization/memory. Pinging @stevhliu: what would be the best location to document this technique?

@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""

_always_upcast_modules = ["MaskConditionDecoder"]

Should we add _always_upcast_modules to ModelMixin?

@@ -330,7 +330,7 @@ def decode(
Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.

"""
z = (z * self.config.scaling_factor - self.means) / self.stds
z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)

Hopefully this doesn't lead to performance regressions.


"""
if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
# Upcast entire module and exist recursion

Suggested change
# Upcast entire module and exist recursion
# Upcast entire module and exit recursion

@@ -263,6 +263,80 @@ def disable_xformers_memory_efficient_attention(self) -> None:
"""
self.set_use_memory_efficient_attention_xformers(False)

def enable_layerwise_upcasting(self, upcast_dtype=None):

I am okay with the name because it conveys the idea easily. Maybe we could note in the docstring the implementation detail that it is applied to the leaf nodes of the compute graph?

@stevhliu (Member):

The memory optimization doc you mention is perfect @sayakpaul 💯

@a-r-r-o-w (Member) commented Aug 15, 2024

Thanks for this, so cool! I've tried it out with AnimateDiff and here are the results:

Duration
VAE dtype: torch.float16
UNet dtype: torch.float16
memory=3.59
max_memory=3.59
max_reserved=3.67
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:13<00:00,  1.51it/s]
memory=3.60
max_memory=11.60
max_reserved=20.29
Time (denoising + decoding): 16.66

------------

VAE dtype: torch.float8_e5m2
UNet dtype: torch.float8_e5m2
memory=1.91
max_memory=1.91
max_reserved=1.98
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:13<00:00,  1.45it/s]
memory=1.93
max_memory=9.99
max_reserved=18.38
Time (denoising + decoding): 17.61

------------

VAE dtype: torch.float8_e4m3fn
UNet dtype: torch.float8_e4m3fn
memory=1.91
max_memory=1.91
max_reserved=1.94
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:12<00:00,  1.55it/s]
memory=1.92
max_memory=9.98
max_reserved=18.37
Time (denoising + decoding): 18.89
Code
import argparse
import gc
import time

import torch
from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter, UNet2DConditionModel
from diffusers.utils import export_to_video


def reset_memory(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)


def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.2f}")
    print(f"{max_memory=:.2f}")
    print(f"{max_reserved=:.2f}")


def main(load_dtype: torch.dtype, upcast_dtype: torch.dtype = torch.float16, filename: str = "output.mp4", decode_chunk_size: int = 16, device: str = "cuda"):
    motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=load_dtype)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=load_dtype)
    unet = UNet2DConditionModel.from_pretrained("stablediffusionapi/darksushimixv225", subfolder="unet", torch_dtype=load_dtype)

    pipe = AnimateDiffPipeline.from_pretrained("stablediffusionapi/darksushimixv225", motion_adapter=motion_adapter, vae=vae, unet=unet, torch_dtype=torch.float16)
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

    pipe.vae.enable_layerwise_upcasting(upcast_dtype)
    pipe.unet.enable_layerwise_upcasting(upcast_dtype)
    pipe.to(device)
    
    reset_memory(device)
    print("VAE dtype:", pipe.vae.dtype)
    print("UNet dtype:", pipe.unet.dtype)
    print_memory(device)

    width = 512
    height = 512
    prompt = "1boy, snowy winter day, trees in the background, snowflakes on leaves, close up"
    negative_prompt = "bad quality, worst quality"
    num_frames = 16
    guidance_scale = 7
    num_inference_steps = 20

    reset_memory(device)
    start_time = time.time()
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(0),
        decode_chunk_size=decode_chunk_size,
    )
    end_time = time.time()

    print_memory(device)
    print(f"Time (denoising + decoding): {end_time - start_time:.2f}")

    frames = output.frames[0]
    export_to_video(frames, filename, fps=8)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--load_dtype", type=str, default="fp8_e5m2")
    parser.add_argument("--upcast_dtype", type=str, default="fp16")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--decode_chunk_size", type=int, default=16)
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp8_e5m2": torch.float8_e5m2,
    "fp8_e4m3": torch.float8_e4m3fn,
}

if __name__ == "__main__":
    args = get_args()

    load_dtype = DTYPE_MAPPING[args.load_dtype]
    upcast_dtype = DTYPE_MAPPING[args.upcast_dtype]
    output_filename = f"animatediff_{args.load_dtype}_{args.upcast_dtype}.mp4"

    main(load_dtype, upcast_dtype, output_filename, args.decode_chunk_size, args.device)

The loaded memory usage is halved, as expected. The max memory here is due to VAE decode. It can be lowered by changing decode_chunk_size to something like 4 instead of 16, which results in a max memory usage of 4 GB in fp8 and 5.6 GB in fp16.

The results are different for each floating-point type though (tried different prompts, seeds, etc.), but this is probably expected.

fp16 fp8_e5m2 fp8_e4m3
animatediff_fp16_fp16.webm
animatediff_fp8_e5m2_fp16.webm
animatediff_fp8_e4m3_dp16.webm

I also tried using it with the new FreeNoise branch to test long video generation: animatediff/freenoise-improvements. However, since I was using AnimateLCM LoRAs here, I found that layerwise upcasting does not yet work with it. I'm guessing some additional weight dtype changes might be needed to support LoRAs. Here's code for replication:

Code
import argparse
import gc
import time

import torch
from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter, UNet2DConditionModel
from diffusers.utils import export_to_video


def reset_memory(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)


def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.2f}")
    print(f"{max_memory=:.2f}")
    print(f"{max_reserved=:.2f}")


def main(load_dtype: torch.dtype, upcast_dtype: torch.dtype = torch.float16, filename: str = "output.mp4", device: str = "cuda"):
    motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=load_dtype)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=load_dtype)
    unet = UNet2DConditionModel.from_pretrained("stablediffusionapi/darksushimixv225", subfolder="unet", torch_dtype=load_dtype)

    pipe = AnimateDiffPipeline.from_pretrained("stablediffusionapi/darksushimixv225", motion_adapter=motion_adapter, vae=vae, unet=unet, torch_dtype=torch.float16)
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
    
    pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora")
    pipe.set_adapters(["lcm_lora"], [0.8])

    pipe.enable_free_noise(context_length=16, context_stride=4)
    pipe.vae.enable_layerwise_upcasting(upcast_dtype)
    pipe.unet.enable_layerwise_upcasting(upcast_dtype)
    pipe.to(device)
    
    reset_memory(device)
    print("VAE dtype:", pipe.vae.dtype)
    print("UNet dtype:", pipe.unet.dtype)
    print_memory(device)

    width = 512
    height = 512
    prompt = {
        0: "a woman on a winter day, sparkly leaves in the background, snow flakes, close up",
        80: "a woman on a summer day, trees visible in the background, close up",
        160: "a woman on a autumn day, yellow leaves in the background, close up",
        240: "a woman on a rainy day, tropical leaves in the background, close up",
    }
    negative_prompt = "bad quality, worst quality"
    num_frames = 256
    guidance_scale = 2.5
    num_inference_steps = 10

    reset_memory(device)
    start_time = time.time()
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(0),
    )
    end_time = time.time()

    print_memory(device)
    print(f"Time (denoising + decoding): {end_time - start_time:.2f}")

    frames = output.frames[0]
    export_to_video(frames, filename, fps=16)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--load_dtype", type=str, default="fp8_e5m2")
    parser.add_argument("--upcast_dtype", type=str, default="fp16")
    parser.add_argument("--device", type=str, default="cuda")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp8_e5m2": torch.float8_e5m2,
    "fp8_e4m3": torch.float8_e4m3fn,
}

if __name__ == "__main__":
    args = get_args()

    load_dtype = DTYPE_MAPPING[args.load_dtype]
    upcast_dtype = DTYPE_MAPPING[args.upcast_dtype]
    output_filename = f"animatediff_{args.load_dtype}_{args.upcast_dtype}_multiprompt.mp4"

    main(load_dtype, upcast_dtype, output_filename, args.device)
Logs
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 16.01it/s]
VAE dtype: torch.float8_e5m2
UNet dtype: torch.float8_e5m2
memory=2.04
max_memory=2.04
max_reserved=2.07
  0%|                                                                                                                                          | 0/10 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/workflows/animatediff_layerwise_upcasting.py", line 104, in <module>
    main(load_dtype, upcast_dtype, output_filename, args.device)
  File "/home/aryan/work/diffusers/workflows/animatediff_layerwise_upcasting.py", line 62, in main
    output = pipe(
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff.py", line 815, in __call__
    noise_pred = self.unet(
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/unets/unet_motion_model.py", line 2200, in forward
    sample, res_samples = downsample_block(
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/unets/unet_motion_model.py", line 543, in forward
    hidden_states = resnet(hidden_states, temb)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/resnet.py", line 341, in forward
    hidden_states = self.conv1(hidden_states)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 1040, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (CUDAFloat8_e5m2Type) and weight type (torch.cuda.HalfTensor) should be the same

@sayakpaul (Member) commented Aug 16, 2024

Regarding the error, does it happen without LoRAs too? A simpler and more minimal reproduction would be nice.

The results are different for each floating type though (tried different prompts, seed, etc.), but this is probably expected.

I would say that is expected because we're naively downcasting and upcasting the params here. In a way, we can think of that as some form of compression and decompression. Clever decompression would consider the quantization scale when performing the computations, as well. Libraries like quanto and bitsandbytes already do that when relevant.
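
A toy illustration of the difference (naive dtype round-tripping versus a scale-aware scheme, loosely in the spirit of what quantization libraries do; real implementations are block-wise and more involved):

import torch

x = torch.randn(8, dtype=torch.float16)

# Naive round trip: cast down to fp8 storage and straight back up, which is
# effectively what layerwise upcasting does to the weights.
naive = x.to(torch.float8_e4m3fn).to(torch.float16)

# Scale-aware round trip: absmax int8 quantization with a per-tensor scale.
scale = x.abs().max() / 127.0
q = (x / scale).round().clamp(-127, 127).to(torch.int8)
dequant = q.to(torch.float16) * scale

print("naive fp8 max error:  ", (x - naive).abs().max().item())
print("scaled int8 max error:", (x - dequant).abs().max().item())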

Also, the first code snippet seems to be missing.

@a-r-r-o-w (Member):

Also, the first code snippet seems to be missing.

Oh sorry, edited the comment. Here's the minimal reproducer:

Minimal reproducer
import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel

model_id = "runwayml/stable-diffusion-v1-5"

vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float8_e4m3fn)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float8_e4m3fn)

pipe = StableDiffusionPipeline.from_pretrained(model_id, vae=vae, unet=unet, torch_dtype=torch.float16)
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")

pipe.vae.enable_layerwise_upcasting(torch.float16)
pipe.unet.enable_layerwise_upcasting(torch.float16)

pipe.to("cuda")

print("VAE:", pipe.vae.dtype)
print("UNet:", pipe.unet.dtype)

prompt = "A pokemon with blue eyes."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
Stack trace
VAE: torch.float8_e4m3fn
UNet: torch.float8_e4m3fn
  0%|                                                                                                                                                        | 0/30 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/dump5.py", line 73, in <module>
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 1000, in __call__
    noise_pred = self.unet(
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1216, in forward
    sample, res_samples = downsample_block(
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 1288, in forward
    hidden_states = attn(
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_2d.py", line 442, in forward
    hidden_states = block(
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/attention.py", line 466, in forward
    attn_output = self.attn1(
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
  File "/home/aryan/work/diffusers/src/diffusers/models/attention_processor.py", line 2191, in __call__
    query = attn.to_q(hidden_states)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 556, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != c10::Half

And if we try fusing the lora:

Logs
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/dump5.py", line 63, in <module>
    pipe.fuse_lora()
  File "/home/aryan/work/diffusers/src/diffusers/loaders/lora_pipeline.py", line 475, in fuse_lora
    super().fuse_lora(
  File "/home/aryan/work/diffusers/src/diffusers/loaders/lora_base.py", line 451, in fuse_lora
    model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
  File "/home/aryan/work/diffusers/src/diffusers/loaders/peft.py", line 270, in fuse_lora
    self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 894, in apply
    module.apply(fn)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 894, in apply
    module.apply(fn)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 894, in apply
    module.apply(fn)
  [Previous line repeated 5 more times]
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 895, in apply
    fn(self)
  File "/home/aryan/work/diffusers/src/diffusers/loaders/peft.py", line 292, in _fuse_lora_apply
    module.merge(**merge_kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 453, in merge
    delta_weight = self.get_delta_weight(active_adapter)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 520, in get_delta_weight
    output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
RuntimeError: "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'

@sayakpaul (Member):

Okay, this is a really good finding. Thank you!

As we can see in the stack trace:

 File "/home/aryan/work/diffusers/src/diffusers/models/attention_processor.py", line 2191, in __call__
    query = attn.to_q(hidden_states)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 556, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != c10::Half

It's rather a problem coming from peft, because the intermediate casting is not there (understandably so), unlike in this PR. By intermediate casting, I mean this pattern:

- z = (z * self.config.scaling_factor - self.means) / self.stds
+ z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)

And if we try fusing the lora:

What happens if you call fuse_lora() and then unload_lora_weights(), both on pipe, and then do the dynamic casting?

@a-r-r-o-w (Member):

What happens if you call fuse_lora() and then call unload_lora_weights() both on pipe and then do the dynamic casting?

This is what I'm trying:

Code
import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel

model_id = "runwayml/stable-diffusion-v1-5"

vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float8_e4m3fn)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float8_e4m3fn)

pipe = StableDiffusionPipeline.from_pretrained(model_id, vae=vae, unet=unet, torch_dtype=torch.float16)
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")
pipe.to("cuda")

pipe.fuse_lora()
pipe.unload_lora_weights()

pipe.vae.enable_layerwise_upcasting(torch.float16)
pipe.unet.enable_layerwise_upcasting(torch.float16)

print("VAE:", pipe.vae.dtype)
print("UNet:", pipe.unet.dtype)

prompt = "A pokemon with blue eyes."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
Traceback
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/dump5.py", line 65, in <module>
    pipe.fuse_lora()
  File "/home/aryan/work/diffusers/src/diffusers/loaders/lora_pipeline.py", line 475, in fuse_lora
    super().fuse_lora(
  File "/home/aryan/work/diffusers/src/diffusers/loaders/lora_base.py", line 451, in fuse_lora
    model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
  File "/home/aryan/work/diffusers/src/diffusers/loaders/peft.py", line 270, in fuse_lora
    self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 894, in apply
    module.apply(fn)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 894, in apply
    module.apply(fn)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 894, in apply
    module.apply(fn)
  [Previous line repeated 5 more times]
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 895, in apply
    fn(self)
  File "/home/aryan/work/diffusers/src/diffusers/loaders/peft.py", line 292, in _fuse_lora_apply
    module.merge(**merge_kwargs)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 453, in merge
    delta_weight = self.get_delta_weight(active_adapter)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 520, in get_delta_weight
    output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
RuntimeError: "addmm_cuda" not implemented for 'Float8_e4m3fn'

This fails when calculating the delta weight to be added to the linear layers, since float8 matmul is not supported.
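
A quick way to see this independently of PEFT (a tiny check; the exact error message depends on device and PyTorch version):

import torch

a = torch.randn(4, 4).to(torch.float8_e4m3fn)
b = torch.randn(4, 4).to(torch.float8_e4m3fn)

try:
    # Plain matmul kernels are not implemented for float8 on most PyTorch builds;
    # fp8 matmul is only exposed via specialized paths on supported GPUs.
    a @ b
except RuntimeError as e:
    print(e)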

@sayakpaul (Member):

unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float8_e4m3fn)

is the problem, as the fuse_lora() step itself errors out here.

@a-r-r-o-w (Member) commented Aug 16, 2024

But isn't loading the models in fp8 or lower what we want, while doing the computation in a GPU-supported dtype like fp16?

@sayakpaul (Member):

We do, but fuse_lora() rightly cannot be called here, as we rely on peft for that. Cc'ing @BenjaminBossan in case he has more insights.

@BenjaminBossan (Member):

I tried to debug the LoRA situation a bit. Jumping into the debugger at this position right before the error:

> /home/name/work/forks/peft/src/peft/tuners/lora/layer.py(557)forward()
    555                 if not self.use_dora[active_adapter]:
    556                     import pdb;pdb.set_trace()
--> 557                     result = result + lora_B(lora_A(dropout(x))) * scaling
    558                 else:
    559                     x = dropout(x)

At this point, we have:

ipdb>  lora_A.weight.dtype
torch.float8_e4m3fn
ipdb>  lora_B.weight.dtype
torch.float8_e4m3fn
ipdb>  x.dtype
torch.float8_e4m3fn

Then when calling lora_A(dropout(x)), the error occurs [1]:

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != c10::Half

Given the dtypes we saw, this error is not surprising: the LoRA linear layer is upcast to float16 and then clashes with x, which is still torch.float8_e4m3fn. x should normally also be float16; the issue is this line in PEFT:

https://github.com/huggingface/peft/blob/4c3a76fa68e5b271ae450856873b9e6e26835227/src/peft/tuners/lora/layer.py#L553

Here, we explicitly cast x to the dtype of the LoRA weight to ensure that all the dtypes fit. Of course, we did not anticipate that the LoRA weight dtype would suddenly change during forward.

As to solutions that come to mind:

  1. Disable this new feature for LoRA
  2. Add some option to tell PEFT not to do the x casting (would be an extremely narrow feature though)
  3. Extend the pre forward hooks to cast x to the right dtype

From all these options, 3 looks best to me, even if it is a bit wasteful (a rough sketch of what that could look like follows). WDYT?
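
For concreteness, option 3 could be shaped something like this. It is only a sketch, not the hook from this PR, and it assumes PyTorch's with_kwargs pre-hook signature:

import torch

def make_upcast_inputs_hook(upcast_dtype=torch.float16):
    def hook(module, args, kwargs):
        # Upcast the module's own weights...
        module.to(dtype=upcast_dtype)
        # ...and also cast floating-point inputs, so layers (e.g. PEFT's LoRA
        # linear) that take their dtype cue from the weights stay consistent.
        args = tuple(
            a.to(upcast_dtype) if torch.is_tensor(a) and a.is_floating_point() else a
            for a in args
        )
        kwargs = {
            k: v.to(upcast_dtype) if torch.is_tensor(v) and v.is_floating_point() else v
            for k, v in kwargs.items()
        }
        return args, kwargs
    return hook

# registered per leaf module with:
# module.register_forward_pre_hook(make_upcast_inputs_hook(torch.float16), with_kwargs=True)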

Footnote [1]: After the error, since the hooks are not cleaned up, lora_A.weight permanently stays in float16. This leaves the whole net in a bad state and makes debugging a lot harder. I would consider using a context manager to implement this feature instead of enable+disable. Then, when exiting the context manager, cast all modules back to their original dtype. Also, collect all the hook handles and clean them all up after exiting the context manager. WDYT?

The new usage would be:

# instead of
pipe.vae.enable_layerwise_upcasting(torch.float16)
pipe.unet.enable_layerwise_upcasting(torch.float16)
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
pipe.vae.disable_layerwise_upcasting(torch.float16)
pipe.unet.disable_layerwise_upcasting(torch.float16)

# do this
with pipe.vae.enable_layerwise_upcasting(torch.float16):
    with pipe.unet.enable_layerwise_upcasting(torch.float16):
        image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]

# or
with pipe.enable_layerwise_upcasting(torch.float16, components=["vae", "unet"]):
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]

# or
with pipe.enable_layerwise_upcasting(vae=torch.float16, unet=torch.float16):
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
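
One way the context-manager variant could be structured internally, keeping the hook handles so that only these hooks are removed on exit (a sketch with assumed names, reusing the leaf-module hooks described earlier, not the PR's code):

from contextlib import contextmanager
import torch

@contextmanager
def layerwise_upcasting(model, upcast_dtype=torch.float16):
    storage_dtype = next(model.parameters()).dtype
    handles = []

    def upcast_hook(module, args):
        module.to(dtype=upcast_dtype)

    def downcast_hook(module, args, output):
        module.to(dtype=storage_dtype)

    try:
        for module in model.modules():
            if len(list(module.children())) == 0:  # leaf modules only
                handles.append(module.register_forward_pre_hook(upcast_hook))
                handles.append(module.register_forward_hook(downcast_hook))
        yield model
    finally:
        for handle in handles:
            handle.remove()       # removes only the hooks registered here
        model.to(storage_dtype)   # restore the original low-memory dtype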

@sayakpaul (Member) commented Aug 16, 2024

The context manager tip is very nice ❤️

IMO disabling LoRA would be a shame, though, given its usage. But I am not entirely opposed to it, either.

Add some option to tell PEFT not to do the x casting (would be an extremely narrow feature though)

Could this work nicely if the changes aren't brutal on the PEFT side?

@DN6 (Collaborator, Author) commented Aug 19, 2024

We could add another check for LoRA layers and upcast them entirely. I tried this locally, and it seems to work. I can push it up if you want to test it, @a-r-r-o-w:

            from peft.tuners.tuners_utils import BaseTunerLayer

            if isinstance(module, BaseTunerLayer):
                # Upcast entire module and exit recursion
                module.register_forward_pre_hook(upcast_dtype_hook_fn)
                module.register_forward_hook(cast_to_original_dtype_hook_fn)

                return

LoRA support is pretty important, so skipping it isn't a good option IMO. If we do add a check like this, then I think the upcasting should be its own module/mixin instead of living in ModelMixin (we don't check for or have external dependencies in that class).

@BenjaminBossan I like using the context manager as well. But it would require changing the inference paradigm of diffusers a bit; additionally, implementing it at the pipeline level is tricky because we can't yet support this for transformer models.

@BenjaminBossan (Member):

I like using the context manager as well. But it would require changing the inference paradigm of diffusers a bit, additionally, implementing it at the pipeline level is tricky because we can't support this just yet for transformer models.

Okay, I don't have enough context to understand exactly what changes would be needed. I would still suggest, however, collecting the handles when the hooks are registered and cleaning them up when upcasting is disabled. Otherwise, when you do module._forward_pre_hooks = OrderedDict(), you run the risk of purging other hooks that are unrelated to upcasting.

github-actions (bot):

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 14, 2024
@vladmandic (Contributor):

ping so bot does not mark this stale - there is massive value here once completed!

@github-actions github-actions bot removed the stale Issues that haven't received updates label Sep 15, 2024
github-actions (bot) commented Oct 9, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Oct 9, 2024
@vladmandic (Contributor):

ping so bot does not mark this stale - there is massive value here once completed - 2nd time ;)

@sayakpaul sayakpaul removed the stale Issues that haven't received updates label Oct 9, 2024
@sayakpaul (Member):

@vladmandic indeed. Dhruv is currently on leave for a few more days, but this will be revived :)

@vladmandic (Contributor):

any chance this gets picked up?

@DN6 (Collaborator, Author) commented Dec 19, 2024

@vladmandic I think we can get this in by release 0.33 (along with some other utilities for mixed precision inference)

@a-r-r-o-w mentioned this pull request Dec 23, 2024
@a-r-r-o-w (Member):

Thanks to @DN6 for starting this! I think we can close this now, since the other PR was merged and I just saw this again on the roadmap.

In follow-ups, I think we can look into vlad's suggestions in #10635 and, eventually, fp8 computation support (perhaps after the current release roadmap) :)

@a-r-r-o-w closed this Jan 27, 2025