[WIP] Layerwise dynamic upcasting to Diffusers Models #9177

Closed · wants to merge 9 commits
2 changes: 1 addition & 1 deletion src/diffusers/models/attention.py
@@ -449,7 +449,7 @@ def forward(
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ self.scale_shift_table[None].to(timestep.dtype) + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
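The cast added above follows the same pattern explained later in `modeling_utils.py`: the layerwise upcasting hooks are registered on leaf submodules only, so a parameter such as `scale_shift_table`, registered directly on a non-leaf block, is never touched by a hook and stays in the low-memory storage dtype while the activations are already in the upcast compute dtype. A minimal sketch of that mismatch (the module and names below are illustrative, not code from this PR, and require a PyTorch build with float8 support):

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)  # leaf module: a forward pre-hook can upcast it
        self.scale_shift_table = nn.Parameter(torch.randn(1, 8))  # lives on the parent: no hook sees it

    def forward(self, x):
        x = self.proj(x)
        # Without the explicit cast, `scale_shift_table` may still be in the float8 storage dtype
        # while `x` is already in the upcast compute dtype.
        return x + self.scale_shift_table.to(x.dtype)


block = Block().to(torch.float8_e4m3fn)  # store everything in the low-memory dtype
block.proj.to(torch.float32)             # emulate a leaf-level upcast hook
out = block(torch.randn(2, 8))
print(out.dtype, block.scale_shift_table.dtype)  # torch.float32 torch.float8_e4m3fn
```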
2 changes: 2 additions & 0 deletions src/diffusers/models/autoencoders/autoencoder_asym_kl.py
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""

_always_upcast_modules = ["MaskConditionDecoder"]
Review comment (Member): We should add `_always_upcast_modules` to ModelMixin?


@register_to_config
def __init__(
self,
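A rough sketch of what the reviewer's suggestion above (moving `_always_upcast_modules` onto `ModelMixin`) could look like; the class names below are illustrative and this is not part of the PR:

```python
from typing import List

import torch.nn as nn


class ModelMixinSketch(nn.Module):
    # Hypothetical default on the base class: nothing is upcast as a whole unless a subclass opts in.
    _always_upcast_modules: List[str] = []


class AsymmetricAutoencoderKLSketch(ModelMixinSketch):
    # Subclasses list the submodule class names that must be upcast entirely rather than layerwise.
    _always_upcast_modules = ["MaskConditionDecoder"]
```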
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -70,6 +70,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
_always_upcast_modules = ["Decoder"]

@register_to_config
def __init__(
@@ -192,6 +192,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = True
_always_upcast_modules = ["TemporalDecoder"]

@register_to_config
def __init__(
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_oobleck.py
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = False
_always_upcast_modules = ["OobleckEncoder", "OobleckDecoder"]

@register_to_config
def __init__(
@@ -330,7 +330,7 @@ def decode(
Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.

"""
- z = (z * self.config.scaling_factor - self.means) / self.stds
+ z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)
Review comment (Member): Hopefully this doesn't lead to performance regressions.


scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
2 changes: 2 additions & 0 deletions src/diffusers/models/autoencoders/vq_model.py
@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
"""

_always_upcast_modules = ["Decoder", "VectorQuantizer"]

@register_to_config
def __init__(
self,
74 changes: 74 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -263,6 +263,80 @@ def disable_xformers_memory_efficient_attention(self) -> None:
"""
self.set_use_memory_efficient_attention_xformers(False)

def enable_layerwise_upcasting(self, upcast_dtype=None):
Review comment (Member): I am okay with the name because it conveys the idea easily. Maybe we could note down the implementation detail that it is applied to the leaf nodes in the compute graph in the docstring?

r"""
Enable layerwise dynamic upcasting. This allows models to be loaded into the GPU in a low memory dtype e.g.
torch.float8_e4m3fn, but perform inference using a dtype that is supported by the GPU, by upcasting the
individual modules in the model to the appropriate dtype right before the foward pass.

The module is then moved back to the low memory dtype after the foward pass.
"""

        upcast_dtype = upcast_dtype or torch.float32
        original_dtype = self.dtype

        def upcast_dtype_hook_fn(module, *args, **kwargs):
            # `nn.Module.to` casts parameters and buffers in place, so the return value does not need to be used.
            module = module.to(upcast_dtype)

        def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
            module = module.to(original_dtype)

        def fn_recursive_upcast(module):
            """In certain cases modules will apply casting internally or reference the dtype of internal blocks.

            e.g.

            ```
            class MyModel(nn.Module):
                def forward(self, x):
                    dtype = next(iter(self.blocks.parameters())).dtype
                    x = self.blocks(x) + torch.ones(x.size()).to(dtype)
            ```
            Layerwise upcasting will not work here, since the internal blocks remain in the low memory dtype until
            their `forward` method is called. We need to add the upcast hook on the entire module in order for the
            operation to work.

            The `_always_upcast_modules` class attribute is a list of modules within the model that we must upcast
            entirely, rather than layerwise.

            """
            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
                # Upcast entire module and exist recursion
Review comment (Member), suggested change:
-                # Upcast entire module and exist recursion
+                # Upcast entire module and exit recursion

                module.register_forward_pre_hook(upcast_dtype_hook_fn)
                module.register_forward_hook(cast_to_original_dtype_hook_fn)

                return

            has_children = list(module.children())
            if not has_children:
                module.register_forward_pre_hook(upcast_dtype_hook_fn)
                module.register_forward_hook(cast_to_original_dtype_hook_fn)

            for child in module.children():
                fn_recursive_upcast(child)

        for module in self.children():
            fn_recursive_upcast(module)

    def disable_layerwise_upcasting(self):
        def fn_recursive_upcast(module):
            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
                module._forward_pre_hooks = OrderedDict()
                module._forward_hooks = OrderedDict()

                return

            has_children = list(module.children())
            if not has_children:
                module._forward_pre_hooks = OrderedDict()
                module._forward_hooks = OrderedDict()

            for child in module.children():
                fn_recursive_upcast(child)

        for module in self.children():
            fn_recursive_upcast(module)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
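For orientation, a rough usage sketch of the API added in this file; the checkpoint name, input shapes, and tensors below are placeholders, and the exact behaviour may still change while the PR is a work in progress:

```python
import torch
from diffusers import UNet2DConditionModel

# Placeholder checkpoint; any diffusers model that inherits from ModelMixin works the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Keep the weights in a low-memory dtype on the GPU.
unet = unet.to(torch.float8_e4m3fn).to("cuda")

# Each leaf module (or any module listed in `_always_upcast_modules`) is upcast to float32
# right before its forward pass and cast back to float8 afterwards.
unet.enable_layerwise_upcasting(upcast_dtype=torch.float32)

sample = torch.randn(1, 4, 64, 64, device="cuda")
timestep = torch.tensor([10], device="cuda")
encoder_hidden_states = torch.randn(1, 77, 768, device="cuda")

with torch.no_grad():
    out = unet(sample, timestep, encoder_hidden_states).sample

# Remove the hooks when layerwise upcasting is no longer wanted.
unet.disable_layerwise_upcasting()
```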
9 changes: 7 additions & 2 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -276,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):

_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
_supports_gradient_checkpointing = True
_always_upcast_modules = ["AuraFlowPatchEmbed"]

@register_to_config
def __init__(
@@ -457,11 +458,15 @@ def forward(

# Apply patch embedding, timestep embedding, and project the caption embeddings.
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
- temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
+ temb = self.time_step_embed(timestep).to(dtype=hidden_states.dtype)
temb = self.time_step_proj(temb)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
encoder_hidden_states = torch.cat(
- [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
+ [
+     self.register_tokens.to(encoder_hidden_states.dtype).repeat(encoder_hidden_states.size(0), 1, 1),
+     encoder_hidden_states,
+ ],
+ dim=1,
)

# MMDiT blocks.
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/dit_transformer_2d.py
@@ -65,6 +65,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = True
_always_upcast_modules = ["PatchEmbed"]

@register_to_config
def __init__(
6 changes: 5 additions & 1 deletion src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
"""

_always_upcast_modules = ["HunyuanDiTAttentionPool"]

@register_to_config
def __init__(
self,
@@ -484,7 +486,9 @@ def forward(
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

- encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+ encoder_hidden_states = torch.where(
+     text_embedding_mask, encoder_hidden_states, self.text_embedding_padding.to(encoder_hidden_states.dtype)
+ )

skips = []
for layer, block in enumerate(self.blocks):
5 changes: 4 additions & 1 deletion src/diffusers/models/transformers/latte_transformer_3d.py
@@ -64,6 +64,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
video_length (`int`, *optional*):
The number of frames in the video-like data.
"""
_always_upcast_modules = ["PatchEmbed"]

@register_to_config
def __init__(
@@ -301,7 +302,9 @@ def forward(
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ shift, scale = (self.scale_shift_table[None].to(embedded_timestep.dtype) + embedded_timestep[:, None]).chunk(
+     2, dim=1
+ )
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
4 changes: 3 additions & 1 deletion src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_always_upcast_modules = ["PatchEmbed"]

@register_to_config
def __init__(
@@ -422,7 +423,8 @@ def custom_forward(*inputs):

# 3. Output
shift, scale = (
- self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+ self.scale_shift_table[None].to(embedded_timestep.dtype)
+ + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/prior_transformer.py
@@ -289,7 +289,7 @@ def forward(

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
- timesteps_projected = timesteps_projected.to(dtype=self.dtype)
+ timesteps_projected = timesteps_projected.to(dtype=hidden_states.dtype)
time_embeddings = self.time_embedding(timesteps_projected)

if self.embedding_proj_norm is not None:
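The same `self.dtype` to input-dtype change recurs in the UNet hunks below. The reason, shown in a small illustrative sketch (not code from this PR): under layerwise upcasting, `self.dtype` reports the low-memory storage dtype of the parameters, while the activations flowing through the model are in the upcast compute dtype.

```python
import torch
import torch.nn as nn

# Storage dtype vs. compute dtype under layerwise upcasting.
layer = nn.Linear(4, 4).to(torch.float8_e4m3fn)
storage_dtype = next(layer.parameters()).dtype   # torch.float8_e4m3fn, what `self.dtype` would report

hidden_states = torch.randn(1, 4)                # activation produced by an already-upcast module
compute_dtype = hidden_states.dtype              # torch.float32

t_emb = torch.randn(1, 4)                        # timestep projection, always created in float32
# t_emb = t_emb.to(dtype=storage_dtype)          # old code: collapses the embedding to float8
t_emb = t_emb.to(dtype=compute_dtype)            # new code: keeps it in the compute dtype
print(storage_dtype, compute_dtype, t_emb.dtype)
```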
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -54,6 +54,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
"""

_supports_gradient_checkpointing = True
_always_upcast_modules = ["PatchEmbed"]

@register_to_config
def __init__(
2 changes: 1 addition & 1 deletion src/diffusers/models/unets/unet_2d.py
@@ -283,7 +283,7 @@ def forward(
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=self.dtype)
+ t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)

if self.class_embedding is not None:
2 changes: 1 addition & 1 deletion src/diffusers/models/unets/unet_3d_condition.py
@@ -641,7 +641,7 @@ def forward(
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=self.dtype)
+ t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
2 changes: 1 addition & 1 deletion src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -590,7 +590,7 @@ def forward(
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=self.dtype)
+ t_emb = t_emb.to(dtype=sample.dtype)
t_emb = self.time_embedding(t_emb, timestep_cond)

# 2. FPS
2 changes: 1 addition & 1 deletion src/diffusers/models/unets/unet_motion_model.py
@@ -2152,7 +2152,7 @@ def forward(
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=self.dtype)
+ t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
45 changes: 45 additions & 0 deletions tests/models/test_modeling_common.py
@@ -43,6 +43,8 @@
from diffusers.utils.hub_utils import _add_variant
from diffusers.utils.testing_utils import (
    CaptureLogger,
    disable_full_determinism,
    enable_full_determinism,
    get_python_version,
    is_torch_compile,
    require_torch_2,
@@ -984,6 +986,49 @@ def test_sharded_checkpoints_device_map(self):
        new_output = new_model(**inputs_dict)
        self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_gpu
    def test_layerwise_upcasting(self):
        disable_full_determinism()

        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_cached()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        torch.manual_seed(0)
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model.to(torch_device)

        model(**inputs_dict)
        base_max_memory = torch.cuda.max_memory_allocated()

        # Remove model
        model.to("cpu")
        del model

        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_cached()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        low_memory_dtype = torch.float8_e4m3fn
        upcast_dtype = torch.float32

        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        low_mem_model = self.model_class(**config).eval()
        low_mem_model.to(low_memory_dtype)
        low_mem_model.to(torch_device)
        layerwise_max_memory = torch.cuda.max_memory_allocated()
        low_mem_model.enable_layerwise_upcasting(upcast_dtype)
        low_mem_model(**inputs_dict)

        assert layerwise_max_memory < base_max_memory

        enable_full_determinism()


@is_staging_test
class ModelPushToHubTester(unittest.TestCase):