diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e6858d842cbb..7ff16b0e90aa 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -449,7 +449,7 @@ def forward(
             norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
         elif self.norm_type == "ada_norm_single":
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+                self.scale_shift_table[None].to(timestep.dtype) + timestep.reshape(batch_size, 6, -1)
             ).chunk(6, dim=1)
             norm_hidden_states = self.norm1(hidden_states)
             norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
index 3f4d46557bf7..3c16b766c23d 100644
--- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
     Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """
 
+    _always_upcast_modules = ["MaskConditionDecoder"]
+
     @register_to_config
     def __init__(
         self,
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 161770c67cf8..fd197d4af983 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -70,6 +70,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+    _always_upcast_modules = ["Decoder"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
index 55449644ed03..d35f39cb2e51 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -192,6 +192,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["TemporalDecoder"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/src/diffusers/models/autoencoders/autoencoder_oobleck.py
index e8e372a709d7..7ca1a80ff02b 100644
--- a/src/diffusers/models/autoencoders/autoencoder_oobleck.py
+++ b/src/diffusers/models/autoencoders/autoencoder_oobleck.py
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = False
+    _always_upcast_modules = ["OobleckEncoder", "OobleckDecoder"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
index a97249f79473..92d251597e4d 100644
--- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py
+++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
@@ -330,7 +330,7 @@ def decode(
                 Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
 
         """
-        z = (z * self.config.scaling_factor - self.means) / self.stds
+        z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)
 
         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
         z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py
index ae8a118d719a..2b4b27f3f11a 100644
--- a/src/diffusers/models/autoencoders/vq_model.py
+++ b/src/diffusers/models/autoencoders/vq_model.py
@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin):
         Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
     """
 
+    _always_upcast_modules = ["Decoder", "VectorQuantizer"]
+
     @register_to_config
     def __init__(
         self,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index cfe692dcc54a..9f23b1bef396 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -263,6 +263,80 @@ def disable_xformers_memory_efficient_attention(self) -> None:
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
+    def enable_layerwise_upcasting(self, upcast_dtype=None):
+        r"""
+        Enable layerwise dynamic upcasting. This allows a model to be loaded onto the GPU in a low memory dtype,
+        e.g. torch.float8_e4m3fn, while running inference in a dtype that is supported by the GPU, by upcasting
+        each module in the model to the appropriate dtype right before its forward pass.
+
+        The module is then moved back to the low memory dtype after the forward pass.
+        """
+
+        upcast_dtype = upcast_dtype or torch.float32
+        original_dtype = self.dtype
+
+        def upcast_dtype_hook_fn(module, *args, **kwargs):
+            module = module.to(upcast_dtype)
+
+        def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
+            module = module.to(original_dtype)
+
+        def fn_recursive_upcast(module):
+            """In certain cases modules will apply casting internally or reference the dtype of internal blocks.
+
+            e.g.
+
+            ```
+            class MyModel(nn.Module):
+                def forward(self, x):
+                    dtype = next(iter(self.blocks.parameters())).dtype
+                    x = self.blocks(x) + torch.ones(x.size()).to(dtype)
+            ```
+            Layerwise upcasting will not work here, since the internal blocks remain in the low memory dtype until
+            their `forward` method is called. We need to add the upcast hook on the entire module in order for the
+            operation to work.
+
+            The `_always_upcast_modules` class attribute is a list of modules within the model that we must upcast
+            entirely, rather than layerwise.
+
+            """
+            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
+                # Upcast entire module and exit recursion
+                module.register_forward_pre_hook(upcast_dtype_hook_fn)
+                module.register_forward_hook(cast_to_original_dtype_hook_fn)
+
+                return
+
+            has_children = list(module.children())
+            if not has_children:
+                module.register_forward_pre_hook(upcast_dtype_hook_fn)
+                module.register_forward_hook(cast_to_original_dtype_hook_fn)
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
+    def disable_layerwise_upcasting(self):
+        def fn_recursive_upcast(module):
+            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+                return
+
+            has_children = list(module.children())
+            if not has_children:
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index ad64df0c0790..c1179ba95134 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -276,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
 
     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["AuraFlowPatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -457,11 +458,15 @@ def forward(
 
         # Apply patch embedding, timestep embedding, and project the caption embeddings.
         hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
-        temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
+        temb = self.time_step_embed(timestep).to(dtype=hidden_states.dtype)
         temb = self.time_step_proj(temb)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
         encoder_hidden_states = torch.cat(
-            [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
+            [
+                self.register_tokens.to(encoder_hidden_states.dtype).repeat(encoder_hidden_states.size(0), 1, 1),
+                encoder_hidden_states,
+            ],
+            dim=1,
         )
 
         # MMDiT blocks.
diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py
index 9f8957737dbc..4800a4142b36 100644
--- a/src/diffusers/models/transformers/dit_transformer_2d.py
+++ b/src/diffusers/models/transformers/dit_transformer_2d.py
@@ -65,6 +65,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["PatchEmbed"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index 7f3dab220aaa..49d8490e606b 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
     """
 
+    _always_upcast_modules = ["HunyuanDiTAttentionPool"]
+
     @register_to_config
     def __init__(
         self,
@@ -484,7 +486,9 @@ def forward(
             text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
             text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
 
-            encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+            encoder_hidden_states = torch.where(
+                text_embedding_mask, encoder_hidden_states, self.text_embedding_padding.to(encoder_hidden_states.dtype)
+            )
 
         skips = []
         for layer, block in enumerate(self.blocks):
diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py
index 71d19216e5ff..2f28e1626e1f 100644
--- a/src/diffusers/models/transformers/latte_transformer_3d.py
+++ b/src/diffusers/models/transformers/latte_transformer_3d.py
@@ -64,6 +64,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         video_length (`int`, *optional*): The number of frames in the video-like data.
     """
+    _always_upcast_modules = ["PatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -301,7 +302,9 @@ def forward(
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
         embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
 
-        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table[None].to(embedded_timestep.dtype) + embedded_timestep[:, None]).chunk(
+            2, dim=1
+        )
         hidden_states = self.norm_out(hidden_states)
         # Modulation
         hidden_states = hidden_states * (1 + scale) + shift
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 1e5cd5794517..a6933009b812 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+    _always_upcast_modules = ["PatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -422,7 +423,8 @@ def custom_forward(*inputs):
 
         # 3. Output
         shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+            self.scale_shift_table[None].to(embedded_timestep.dtype)
+            + embedded_timestep[:, None].to(self.scale_shift_table.device)
         ).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation
diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py
index fdb67384ff5e..9384a528978b 100644
--- a/src/diffusers/models/transformers/prior_transformer.py
+++ b/src/diffusers/models/transformers/prior_transformer.py
@@ -289,7 +289,7 @@ def forward(
 
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might be fp16, so we need to cast here.
-        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
+        timesteps_projected = timesteps_projected.to(dtype=hidden_states.dtype)
         time_embeddings = self.time_embedding(timesteps_projected)
 
         if self.embedding_proj_norm is not None:
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 9376c91d0756..c898ab44c5d6 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -54,6 +54,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
     """
 
     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["PatchEmbed"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py
index 5972505f2897..c99107af590f 100644
--- a/src/diffusers/models/unets/unet_2d.py
+++ b/src/diffusers/models/unets/unet_2d.py
@@ -283,7 +283,7 @@ def forward(
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
         emb = self.time_embedding(t_emb)
 
         if self.class_embedding is not None:
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index 3081fdc4700c..90e7c3797bfd 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -641,7 +641,7 @@ def forward(
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 6ab3a577b892..f12031bf9bf1 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -590,7 +590,7 @@ def forward(
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
         t_emb = self.time_embedding(t_emb, timestep_cond)
 
         # 2. FPS
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 73c9c70c4a11..04ca95c78d3b 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -2152,7 +2152,7 @@ def forward(
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
         aug_emb = None
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 64722e2d9797..861f4ee99ba3 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -43,6 +43,8 @@
 from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    disable_full_determinism,
+    enable_full_determinism,
     get_python_version,
     is_torch_compile,
     require_torch_2,
@@ -984,6 +986,49 @@ def test_sharded_checkpoints_device_map(self):
         new_output = new_model(**inputs_dict)
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_torch_gpu
+    def test_layerwise_upcasting(self):
+        disable_full_determinism()
+
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_cached()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model.to(torch_device)
+
+        model(**inputs_dict)
+        base_max_memory = torch.cuda.max_memory_allocated()
+
+        # Remove model
+        model.to("cpu")
+        del model
+
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_cached()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        low_memory_dtype = torch.float8_e4m3fn
+        upcast_dtype = torch.float32
+
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
+        low_mem_model = self.model_class(**config).eval()
+        low_mem_model.to(low_memory_dtype)
+        low_mem_model.to(torch_device)
+        layerwise_max_memory = torch.cuda.max_memory_allocated()
+        low_mem_model.enable_layerwise_upcasting(upcast_dtype)
+        low_mem_model(**inputs_dict)
+
+        assert layerwise_max_memory < base_max_memory
+
+        enable_full_determinism()
+
 
 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
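
A minimal usage sketch of the `enable_layerwise_upcasting` / `disable_layerwise_upcasting` API introduced in `modeling_utils.py` above. The checkpoint name, latent shape, and the `float32` upcast dtype are illustrative assumptions rather than part of the change; it assumes a CUDA device and a PyTorch build that ships `torch.float8_e4m3fn`.

```python
import torch

from diffusers import AutoencoderKL

# Load a VAE, store its weights in a low memory dtype, and move it to the GPU.
# "stabilityai/sd-vae-ft-mse" is only an illustrative checkpoint.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae = vae.to(torch.float8_e4m3fn).to("cuda")

# Each leaf module is upcast to float32 right before its forward pass and cast
# back to float8 afterwards; modules listed in `_always_upcast_modules`
# (e.g. `Decoder` for AutoencoderKL) are upcast as a whole.
vae.enable_layerwise_upcasting(upcast_dtype=torch.float32)

with torch.no_grad():
    latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float32)
    image = vae.decode(latents).sample

# Remove the hooks and fall back to the model's stored dtype.
vae.disable_layerwise_upcasting()
```

Because only the module currently executing (or an `_always_upcast_modules` block) is held in the upcast dtype at any time, peak GPU memory should stay close to the float8 footprint rather than the float32 one, which is what the new `test_layerwise_upcasting` test asserts.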