
Refactor gradient checkpointing #10611

Open
a-r-r-o-w wants to merge 4 commits into main

Conversation

a-r-r-o-w (Member) commented Jan 20, 2025

Fixes #10124.

When fine-tuning, we currently apply gradient checkpointing to every transformer block. This works wonders for saving memory but can lead to slower throughput. An acceptable compromise, trading slightly higher memory usage for better throughput, is to checkpoint only certain blocks or to apply a different checkpointing strategy.

This PR refactors how we do gradient checkpointing so that users can supply their own checkpointing functions/strategies. Currently, only LTXVideo has been updated, in order to gather initial feedback on the changes. If all looks well, I will update all the other modeling implementations.
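For context, a rough sketch of the intended usage (assuming the callable signature exercised in the benchmark below; the exact API may still change based on feedback):

import torch
from diffusers import LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)

# Default: checkpoint every transformer block with torch.utils.checkpoint.
transformer.enable_gradient_checkpointing()

# Custom: pass a callable that decides how (and whether) each block is checkpointed.
def my_checkpointing_func(block, *args):
    return torch.utils.checkpoint.checkpoint(block.__call__, *args, use_reentrant=False)

transformer.enable_gradient_checkpointing(my_checkpointing_func)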

Benchmark
import gc

import torch
import torch.utils.benchmark as benchmark
from diffusers import LTXVideoTransformer3DModel


transformer = LTXVideoTransformer3DModel.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)

batch_size = 1
height = 512
width = 768
num_frames = 81

spatial_compression_ratio = 32
temporal_compression_ratio = 8

latent_height = height // spatial_compression_ratio
latent_width = width // spatial_compression_ratio
num_latent_frames = (num_frames - 1) // temporal_compression_ratio + 1


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    torch.cuda.synchronize()
    return t0.blocked_autorange().mean


def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def create_inputs(device: torch.device, dtype: torch.dtype):
    max_sequence_length = 128

    hidden_states = torch.randn(batch_size, num_latent_frames * latent_height * latent_width, transformer.config.in_channels, device=device, dtype=dtype)
    encoder_hidden_states = torch.randn(batch_size, max_sequence_length, transformer.config.caption_channels, device=device, dtype=dtype)
    timestep = torch.randint(0, 1000, device=device, dtype=torch.int64, size=(batch_size,))
    encoder_attention_mask = torch.randint(0, 2, device=device, dtype=torch.bool, size=(batch_size, max_sequence_length))

    return {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "timestep": timestep,
        "encoder_attention_mask": encoder_attention_mask,
        "num_frames": num_latent_frames,
        "height": latent_height,
        "width": latent_width,
    }


def forward_and_backward(model, **inputs):
    output = model(**inputs)[0]
    output.mean().backward()


clear_memory()

device = torch.device("cuda")
dtype = torch.bfloat16
num_iterations = 5

transformer.to(device, dtype=dtype)
torch.cuda.synchronize()

clear_memory()
print(f"Model memory: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")


# Warmup
inputs = create_inputs(device, dtype)
for _ in range(2):
    forward_and_backward(transformer, **inputs)


# Benchmark no gradient checkpointing
clear_memory()
inputs = create_inputs(device, dtype)
total_time = 0.0
for _ in range(num_iterations):
    time = benchmark_fn(forward_and_backward, transformer, **inputs)
    total_time += time
torch.cuda.synchronize()
print(f"No gradient checkpointing: memory={torch.cuda.max_memory_reserved() / 1024**3:.3f} GB, it/s={num_iterations / total_time:.3f}")


# Benchmark gradient checkpointing
clear_memory()
transformer.enable_gradient_checkpointing()
inputs = create_inputs(device, dtype)
total_time = 0.0
for _ in range(num_iterations):
    time = benchmark_fn(forward_and_backward, transformer, **inputs)
    total_time += time
torch.cuda.synchronize()
print(f"Gradient checkpointing: memory={torch.cuda.max_memory_reserved() / 1024**3:.3f} GB, it/s={num_iterations / total_time:.3f}")


# Benchmark gradient checkpointing with custom layers (every 2nd block)
clear_memory()
for index, layer in enumerate(transformer.transformer_blocks):
    layer.layer_index = index

def gradient_checkpointing_func(model, *args):
    if model.layer_index % 2 == 0:
        return torch.utils.checkpoint.checkpoint(model.__call__, *args, use_reentrant=False)
    return model(*args)

transformer.enable_gradient_checkpointing(gradient_checkpointing_func)
inputs = create_inputs(device, dtype)
total_time = 0.0
for _ in range(num_iterations):
    time = benchmark_fn(forward_and_backward, transformer, **inputs)
    total_time += time
torch.cuda.synchronize()
print(f"Gradient checkpointing (every 2nd block): memory={torch.cuda.max_memory_reserved() / 1024**3:.3f} GB, it/s={num_iterations / total_time:.3f}")


# Benchmark gradient checkpointing with custom layers (every 4th block)
clear_memory()
for index, layer in enumerate(transformer.transformer_blocks):
    layer.layer_index = index

def gradient_checkpointing_func(model, *args):
    if model.layer_index % 4 == 0:
        return torch.utils.checkpoint.checkpoint(model.__call__, *args, use_reentrant=False)
    return model(*args)

transformer.enable_gradient_checkpointing(gradient_checkpointing_func)
inputs = create_inputs(device, dtype)
total_time = 0.0
for _ in range(num_iterations):
    time = benchmark_fn(forward_and_backward, transformer, **inputs)
    total_time += time
torch.cuda.synchronize()
print(f"Gradient checkpointing (every 4th block): memory={torch.cuda.max_memory_reserved() / 1024**3:.3f} GB, it/s={num_iterations / total_time:.3f}")
Model memory: 4.025 GB
No gradient checkpointing: memory=26.719 GB, it/s=1.680
Gradient checkpointing: memory=9.316 GB, it/s=1.358
Gradient checkpointing (every 2nd block): memory=17.658 GB, it/s=1.539
Gradient checkpointing (every 4th block): memory=22.188 GB, it/s=1.658

Additional context: #9982 (comment)

cc @bghira

a-r-r-o-w requested review from DN6 and yiyixuxu on January 20, 2025, 12:38
bghira (Contributor) commented Jan 20, 2025

A thought I have had: we could invert the condition so that we "don't checkpoint" certain blocks instead of "do checkpoint" them, which allows increasing memory use more finely with larger intervals.
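For instance, reusing the layer_index tagging from the benchmark above, the inverted check might look roughly like this (a sketch, not part of the PR):

# Inverted condition: checkpoint all blocks EXCEPT every 4th one, so memory
# stays much closer to the fully checkpointed baseline.
def gradient_checkpointing_func(block, *args):
    if block.layer_index % 4 != 0:
        return torch.utils.checkpoint.checkpoint(block.__call__, *args, use_reentrant=False)
    return block(*args)

transformer.enable_gradient_checkpointing(gradient_checkpointing_func)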

a-r-r-o-w (Member, Author)

Yep, there are a number of variations to try out. To avoid limiting how checkpointing can be applied, it's best to give that control to the user, so the idea you mention should be possible here too. LMK if you expected something different or want to implement it with a better design.

yiyixuxu (Collaborator)

Are the numbers correct?

Model memory: 4.025 GB
No gradient checkpointing: memory=26.664 GB, it/s=0.582
Gradient checkpointing: memory=9.316 GB, it/s=1.339
Gradient checkpointing (every 2nd block): memory=17.658 GB, it/s=1.522
Gradient checkpointing (every 4th block): memory=22.188 GB, it/s=1.630

Why does the run without gradient checkpointing have the highest memory usage and the lowest throughput?

a-r-r-o-w (Member, Author) commented Jan 21, 2025

@yiyixuxu They were incorrect, my bad. I thought I was warming up the first backward call correctly, but I was not; looking at the profiles revealed that the backward-pass kernel launch was never warmed up. I updated the example so the warmup now happens before the "No gradient checkpointing" benchmark, and we're now seeing the correct numbers.

"No Gradient Checkpointing" has the highest memory usage because it has to keep ALL intermediate activation tensors in memory, whereas the gradient checkpointing ones only have to save a copy of the inputs at each layer where is is applied (and intermediate activation tensors between layers if we skipping a few blocks).

The memory numbers were correct in the previous version of the code as well; only the reported it/s was wrong, due to the badly done warmup.

bghira (Contributor) commented Jan 21, 2025

Yes, this looks great and will be a very useful addition, thank you. Would this approach also work for SDXL? Tracking all the checkpoints there was hard for me, and I ended up monkey-patching the checkpoint call in a way I'm not proud of in order to "make it work" easily.

yiyixuxu (Collaborator)

If we just want to allow users to skip certain layers, can we go with a simpler solution?

transformer.enable_gradient_checkpointing(skip_layers=[])

Then, in the code:

        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing and i not in self.skip_layers:
                 ...

a-r-r-o-w (Member, Author) commented Jan 22, 2025

Skipping layers is one use case. The idea is to not limit how the checkpointing is applied and instead give that control to the user (with us still providing the sensible default of torch.utils.checkpoint). For example, one could use a different checkpointing provider, such as DeepSpeed, without monkey-patching the forward pass or other intrusive solutions, or use a custom checkpoint implementation that offloads stored inputs to CPU and retrieves them when needed for recomputation.
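As a rough illustration (not part of this PR), a user-supplied function could even skip recomputation entirely and instead keep the saved activations on CPU, e.g. with torch.autograd.graph.save_on_cpu:

import torch

# Illustration only: instead of recomputing, keep the tensors saved for backward on
# CPU (pinned memory) and copy them back to the GPU when the backward pass needs them.
def cpu_offload_func(block, *args):
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        return block(*args)

transformer.enable_gradient_checkpointing(cpu_offload_func)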

a-r-r-o-w (Member, Author) commented Jan 22, 2025

@yiyixuxu LMK if you think the current changes look good, and I'll propagate them and add tests for all the other models to make sure this works as expected.

Edit: oh, just saw your message; we commented at the same time. I'll work on finishing this up.

a-r-r-o-w added the "roadmap" (Add to current release roadmap) label on Jan 24, 2025
Successfully merging this pull request may close this issue: [models] allow specifying block numbers in enable_gradient_checkpointing() to speed up training