[WIP] Layerwise dynamic upcasting to Diffusers Models #9177
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@DN6, are you suggesting NOT adding any integrations like those in #9174? If so, could you please elaborate on your reasoning? It will be in a similar spirit as how it's done in
Thanks! This looks very clean and well thought out.
I left some comments, but most of them are quite minor in nature.
We should add a doc about it in the memory optimization section: https://huggingface.co/docs/diffusers/optimization/memory. Pinging @stevhliu: what would be the best location to add a section about this technique?
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
    Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """

    _always_upcast_modules = ["MaskConditionDecoder"]
We should add _always_upcast_modules to ModelMixin?
@@ -330,7 +330,7 @@ def decode(
            Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.

        """
-        z = (z * self.config.scaling_factor - self.means) / self.stds
+        z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)
Hopefully this doesn't lead to performance regressions.
""" | ||
if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules: | ||
# Upcast entire module and exist recursion |
# Upcast entire module and exist recursion
# Upcast entire module and exit recursion |
@@ -263,6 +263,80 @@ def disable_xformers_memory_efficient_attention(self) -> None:
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def enable_layerwise_upcasting(self, upcast_dtype=None):
I am okay with the name because it conveys the idea easily. Maybe we could note down in the docstring the implementation detail that it is applied to the leaf nodes in the compute graph?
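For intuition, here is a minimal sketch of what upcasting at the leaf nodes means with plain PyTorch hooks. This is illustrative only, not the PR's implementation; the hook names and dtypes are arbitrary:

```python
import torch
import torch.nn as nn

layer = nn.Linear(16, 16).to(torch.float8_e4m3fn)  # leaf module stored in a low-memory dtype

def upcast(module, args):
    module.to(torch.float16)  # upcast parameters right before this leaf's forward

def downcast(module, args, output):
    module.to(torch.float8_e4m3fn)  # return to the storage dtype right after

layer.register_forward_pre_hook(upcast)
layer.register_forward_hook(downcast)

out = layer(torch.randn(2, 16, dtype=torch.float16))
print(out.dtype, layer.weight.dtype)  # torch.float16 torch.float8_e4m3fn
```

Because only the layer currently executing holds its weights in the compute dtype, peak memory stays close to the low-precision footprint.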
The memory optimization doc you mention is perfect, @sayakpaul 💯
Thanks for this, so cool! I've tried it out with AnimateDiff and here are the results:

Duration

Code

import argparse
import gc
import time
import torch
from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter, UNet2DConditionModel
from diffusers.utils import export_to_video
def reset_memory(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)

def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.2f}")
    print(f"{max_memory=:.2f}")
    print(f"{max_reserved=:.2f}")

def main(load_dtype: torch.dtype, upcast_dtype: torch.dtype = torch.float16, filename: str = "output.mp4", decode_chunk_size: int = 16, device: str = "cuda"):
    motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=load_dtype)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=load_dtype)
    unet = UNet2DConditionModel.from_pretrained("stablediffusionapi/darksushimixv225", subfolder="unet", torch_dtype=load_dtype)
    pipe = AnimateDiffPipeline.from_pretrained("stablediffusionapi/darksushimixv225", motion_adapter=motion_adapter, vae=vae, unet=unet, torch_dtype=torch.float16)
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
    pipe.vae.enable_layerwise_upcasting(upcast_dtype)
    pipe.unet.enable_layerwise_upcasting(upcast_dtype)
    pipe.to(device)
    reset_memory(device)
    print("VAE dtype:", pipe.vae.dtype)
    print("UNet dtype:", pipe.unet.dtype)
    print_memory(device)
    width = 512
    height = 512
    prompt = "1boy, snowy winter day, trees in the background, snowflakes on leaves, close up"
    negative_prompt = "bad quality, worst quality"
    num_frames = 16
    guidance_scale = 7
    num_inference_steps = 20
    reset_memory(device)
    start_time = time.time()
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(0),
        decode_chunk_size=decode_chunk_size,
    )
    end_time = time.time()
    print_memory(device)
    print(f"Time (denoising + decoding): {end_time - start_time:.2f}")
    frames = output.frames[0]
    export_to_video(frames, filename, fps=8)

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--load_dtype", type=str, default="fp8_e5m2")
    parser.add_argument("--upcast_dtype", type=str, default="fp16")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--decode_chunk_size", type=int, default=16)
    return parser.parse_args()

DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp8_e5m2": torch.float8_e5m2,
    "fp8_e4m3": torch.float8_e4m3fn,
}

if __name__ == "__main__":
    args = get_args()
    load_dtype = DTYPE_MAPPING[args.load_dtype]
    upcast_dtype = DTYPE_MAPPING[args.upcast_dtype]
    output_filename = f"animatediff_{args.load_dtype}_{args.upcast_dtype}.mp4"
    main(load_dtype, upcast_dtype, output_filename, args.decode_chunk_size, args.device)

The loaded memory usage is half, as expected. The max memory here is due to VAE decode; it can be lowered by changing the decode_chunk_size. The results are different for each floating point type though (tried different prompts, seeds, etc.), but this is probably expected.
I also tried using it with the new FreeNoise branch to test long video generation:

Code

import argparse
import gc
import time
import torch
from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter, UNet2DConditionModel
from diffusers.utils import export_to_video
def reset_memory(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)

def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.2f}")
    print(f"{max_memory=:.2f}")
    print(f"{max_reserved=:.2f}")

def main(load_dtype: torch.dtype, upcast_dtype: torch.dtype = torch.float16, filename: str = "output.mp4", device: str = "cuda"):
    motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=load_dtype)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=load_dtype)
    unet = UNet2DConditionModel.from_pretrained("stablediffusionapi/darksushimixv225", subfolder="unet", torch_dtype=load_dtype)
    pipe = AnimateDiffPipeline.from_pretrained("stablediffusionapi/darksushimixv225", motion_adapter=motion_adapter, vae=vae, unet=unet, torch_dtype=torch.float16)
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
    pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora")
    pipe.set_adapters(["lcm_lora"], [0.8])
    pipe.enable_free_noise(context_length=16, context_stride=4)
    pipe.vae.enable_layerwise_upcasting(upcast_dtype)
    pipe.unet.enable_layerwise_upcasting(upcast_dtype)
    pipe.to(device)
    reset_memory(device)
    print("VAE dtype:", pipe.vae.dtype)
    print("UNet dtype:", pipe.unet.dtype)
    print_memory(device)
    width = 512
    height = 512
    prompt = {
        0: "a woman on a winter day, sparkly leaves in the background, snow flakes, close up",
        80: "a woman on a summer day, trees visible in the background, close up",
        160: "a woman on a autumn day, yellow leaves in the background, close up",
        240: "a woman on a rainy day, tropical leaves in the background, close up",
    }
    negative_prompt = "bad quality, worst quality"
    num_frames = 256
    guidance_scale = 2.5
    num_inference_steps = 10
    reset_memory(device)
    start_time = time.time()
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(0),
    )
    end_time = time.time()
    print_memory(device)
    print(f"Time (denoising + decoding): {end_time - start_time:.2f}")
    frames = output.frames[0]
    export_to_video(frames, filename, fps=16)

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--load_dtype", type=str, default="fp8_e5m2")
    parser.add_argument("--upcast_dtype", type=str, default="fp16")
    parser.add_argument("--device", type=str, default="cuda")
    return parser.parse_args()

DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp8_e5m2": torch.float8_e5m2,
    "fp8_e4m3": torch.float8_e4m3fn,
}

if __name__ == "__main__":
    args = get_args()
    load_dtype = DTYPE_MAPPING[args.load_dtype]
    upcast_dtype = DTYPE_MAPPING[args.upcast_dtype]
    output_filename = f"animatediff_{args.load_dtype}_{args.upcast_dtype}_multiprompt.mp4"
    main(load_dtype, upcast_dtype, output_filename, args.device)

Logs

Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 16.01it/s]
VAE dtype: torch.float8_e5m2
UNet dtype: torch.float8_e5m2
memory=2.04
max_memory=2.04
max_reserved=2.07
0%| | 0/10 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/home/aryan/work/diffusers/workflows/animatediff_layerwise_upcasting.py", line 104, in <module>
main(load_dtype, upcast_dtype, output_filename, args.device)
File "/home/aryan/work/diffusers/workflows/animatediff_layerwise_upcasting.py", line 62, in main
output = pipe(
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff.py", line 815, in __call__
noise_pred = self.unet(
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/models/unets/unet_motion_model.py", line 2200, in forward
sample, res_samples = downsample_block(
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/models/unets/unet_motion_model.py", line 543, in forward
hidden_states = resnet(hidden_states, temb)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aryan/work/diffusers/src/diffusers/models/resnet.py", line 341, in forward
hidden_states = self.conv1(hidden_states)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 1040, in forward
result = result + lora_B(lora_A(dropout(x))) * scaling
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
result = forward_call(*args, **kwargs)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (CUDAFloat8_e5m2Type) and weight type (torch.cuda.HalfTensor) should be the same
Regarding the error, does it happen without LoRAs too? A simpler and more minimal reproduction would be nice.
I would say that is expected because we're naively downcasting and upcasting the params here. In a way, we can think of that as some form of compression and decompression. Clever decompression would consider the quantization scale when performing the computations, as well. Libraries like quanto and bitsandbytes already do that when relevant. Also, the first code snippet seems to be missing.
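To make the compression/decompression analogy concrete, here is a small illustration in plain PyTorch (not the quanto/bitsandbytes APIs): a naive fp8 round-trip can mangle small-magnitude weights, while a scale-aware round-trip stores a per-tensor scale and applies it when "decompressing". The 448 constant is the largest finite value of float8_e4m3fn.

```python
import torch

w = torch.randn(1024) * 1e-3  # small-magnitude weights, near the bottom of the fp8 range

# Naive "compression": cast down and back up. Values below fp8's smallest representable
# magnitudes get flushed toward zero and the error is silently baked into the weights.
naive = w.to(torch.float8_e4m3fn).to(torch.float32)

# Scale-aware "compression": remember a per-tensor scale so the fp8 range is used well,
# then undo the scale on upcast (a much-simplified version of what quantization libraries do).
scale = w.abs().max() / 448.0
dequant = (w / scale).to(torch.float8_e4m3fn).to(torch.float32) * scale

print("naive round-trip error :", (w - naive).abs().mean().item())
print("scaled round-trip error:", (w - dequant).abs().mean().item())
```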
Oh sorry, edited the comment. Here's the minimal reproducer:

Minimal reproducer

import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel
model_id = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float8_e4m3fn)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float8_e4m3fn)
pipe = StableDiffusionPipeline.from_pretrained(model_id, vae=vae, unet=unet, torch_dtype=torch.float16)
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")
pipe.vae.enable_layerwise_upcasting(torch.float16)
pipe.unet.enable_layerwise_upcasting(torch.float16)
pipe.to("cuda")
print("VAE:", pipe.vae.dtype)
print("UNet:", pipe.unet.dtype)
prompt = "A pokemon with blue eyes."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png") Stack trace
And if we try fusing the LoRA:

Logs
Okay, this is a really good finding. Thank you! As we can see in the stack trace:
It's rather a problem coming from

- z = (z * self.config.scaling_factor - self.means) / self.stds
+ z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)
What happens if you call
This is what I'm trying:

Code

import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel
model_id = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float8_e4m3fn)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float8_e4m3fn)
pipe = StableDiffusionPipeline.from_pretrained(model_id, vae=vae, unet=unet, torch_dtype=torch.float16)
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")
pipe.to("cuda")
pipe.fuse_lora()
pipe.unload_lora_weights()
pipe.vae.enable_layerwise_upcasting(torch.float16)
pipe.unet.enable_layerwise_upcasting(torch.float16)
print("VAE:", pipe.vae.dtype)
print("UNet:", pipe.unet.dtype)
prompt = "A pokemon with blue eyes."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png") Traceback
This fails when calculating the delta weight to be added to the linear layers, since float8 matmul is not supported.
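A standalone check of that limitation (hedged: the exact error message depends on the PyTorch version and device, but plain matmul between float8 tensors is generally not implemented, which is what fusing a LoRA into fp8 base weights runs into):

```python
import torch

a = torch.randn(8, 8).to(torch.float8_e4m3fn)
b = torch.randn(8, 8).to(torch.float8_e4m3fn)

try:
    a @ b  # computing the LoRA delta weight boils down to a matmul like this
except RuntimeError as err:
    print("float8 matmul not supported:", err)
```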
is the problem as the
But isn't loading the models in fp8 or lower what we want, and doing the computation in a GPU supported dtype like fp16?
We do, but rightly
I tried to debug the LoRA situation a bit. Jumping into the debugger at this position right before the error:

> /home/name/work/forks/peft/src/peft/tuners/lora/layer.py(557)forward()
555 if not self.use_dora[active_adapter]:
556 import pdb;pdb.set_trace()
--> 557 result = result + lora_B(lora_A(dropout(x))) * scaling
558 else:
559 x = dropout(x)

At this point, we have:
Then when calling
[1] see footnote

Given the dtypes we saw, this error is not surprising: the LoRA linear layer is upcast to float16 and then clashes with the
Here, we explicitly cast
As to solutions that come to mind:
From all these options, 3 looks best to me, even if it is a bit wasteful. WDYT?

Footnote [1]: After the error, since the hooks are not cleaned up,

The new usage would be:

# instead of
pipe.vae.enable_layerwise_upcasting(torch.float16)
pipe.unet.enable_layerwise_upcasting(torch.float16)
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
pipe.vae.disable_layerwise_upcasting(torch.float16)
pipe.unet.disable_layerwise_upcasting(torch.float16)
# do this
with pipe.vae.enable_layerwise_upcasting(torch.float16):
    with pipe.unet.enable_layerwise_upcasting(torch.float16):
        image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]

# or
with pipe.enable_layerwise_upcasting(torch.float16, components=["vae", "unet"]):
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]

# or
with pipe.enable_layerwise_upcasting(vae=torch.float16, unet=torch.float16):
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
The context manager tip is very nice ❤️ IMO disabling the LoRA would be a shame though, given its usage. But I am not entirely opposed to it, either.
Could this work nicely if the changes aren't brutal on the PEFT side?
We could add another check for LoRA layers and upcast them entirely? I tried this locally, and it seems to work. I can push it up if you want to test @a-r-r-o-w

from peft.tuners.tuners_utils import BaseTunerLayer
if isinstance(module, BaseTunerLayer):
    # Upcast entire module and exit recursion
    module.register_forward_pre_hook(upcast_dtype_hook_fn)
    module.register_forward_hook(cast_to_original_dtype_hook_fn)
    return

LoRA support is pretty important, so skipping it isn't a good option IMO. If we do add a check like this, then I think the upcasting should be its own module/mixin instead of living in ModelMixin (we don't check for or have external dependencies in that class). @BenjaminBossan I like using the context manager as well, but it would require changing the inference paradigm of diffusers a bit; additionally, implementing it at the pipeline level is tricky because we can't support this just yet for
Okay, I don't have enough context to understand exactly what changes would be needed. I would still suggest, however, collecting the handles when the hooks are registered and cleaning them up when upcasting is disabled. Otherwise, when you do
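A hedged sketch of that bookkeeping (illustrative names, not the PR's code): register_forward_pre_hook/register_forward_hook return RemovableHandle objects, so a context manager can collect them and always remove them on exit, even if the forward pass raised:

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def layerwise_upcasting(model: nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype):
    def upcast(module, args):
        module.to(compute_dtype)

    def downcast(module, args, output):
        module.to(storage_dtype)

    handles = []
    try:
        for module in model.modules():
            if list(module.children()):  # only hook leaf modules
                continue
            handles.append(module.register_forward_pre_hook(upcast))
            handles.append(module.register_forward_hook(downcast))
        yield model
    finally:
        for handle in handles:
            handle.remove()  # no stale hooks left behind, even after an error
```

Usage would then look like `with layerwise_upcasting(pipe.unet, torch.float8_e4m3fn, torch.float16): ...`, in the spirit of the context-manager API sketched earlier in the thread.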
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
ping so bot does not mark this stale - there is massive value here once completed!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
ping so bot does not mark this stale - there is massive value here once completed - 2nd time ;)
@vladmandic indeed. Dhruv is still on leave for some days, but this will be revived :)
any chance this gets picked up?
@vladmandic I think we can get this in by release 0.33 (along with some other utilities for mixed precision inference)
What does this PR do?
Proposal to add the option to enable layerwise dynamic upcasting in order to save memory when running large models such as Flux.
This PR:

- Adds an enable_layerwise_upcasting method to the ModelMixin class. This method recursively adds pre-forward and forward hooks to each module in a model so that we can load/keep the model in a low-memory dtype on GPU, e.g. torch.float8_e4m3fn, and upcast only right before the forward method of a module is called. I believe Comfy does something similar.
- Introduces a class attribute _always_upcast_modules. In certain cases, modules will use/reference the dtype of a top-level module. This means that we need to apply dynamic upcasting to the entire top-level module and not just the layers inside it. This attribute lets us know which modules to apply this to (see the sketch below).

This approach does bring in considerable memory savings. Here is the VRAM usage for running Flux on a T4 with model offloading
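To illustrate the recursion described above, here is a hedged sketch (not the PR's actual implementation; the function and hook names are made up) of how the walk could upcast _always_upcast_modules classes as a whole and otherwise hook only the leaf layers:

```python
import torch
import torch.nn as nn

def add_layerwise_upcasting_hooks(module: nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, always_upcast=("MaskConditionDecoder",)):
    def upcast(mod, args):
        mod.to(compute_dtype)

    def downcast(mod, args, output):
        mod.to(storage_dtype)

    if module.__class__.__name__ in always_upcast:
        # This submodule references a top-level dtype internally, so upcast it as a
        # whole and stop recursing into its children.
        module.register_forward_pre_hook(upcast)
        module.register_forward_hook(downcast)
        return

    children = list(module.children())
    if not children:
        # Leaf layer: hold its parameters in the compute dtype only around its forward call.
        module.register_forward_pre_hook(upcast)
        module.register_forward_hook(downcast)
        return

    for child in children:
        add_layerwise_upcasting_hooks(child, storage_dtype, compute_dtype, always_upcast)
```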
TODOs/Considerations:
Fixes # (issue)
Before submitting
documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.