From c5ce24f960641e5f82d8655fd510341560b30aa4 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Tue, 21 Jan 2025 11:51:47 +0000 Subject: [PATCH 01/11] rewrite memory count without implicitly using dimensions by @ic-synth --- src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 941b3eb07f10..a797ad6aed6c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -41,9 +41,7 @@ class CogVideoXSafeConv3d(nn.Conv3d): """ def forward(self, input: torch.Tensor) -> torch.Tensor: - memory_count = ( - (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3 - ) + memory_count = torch.prod(torch.tensor(input.shape)) * 2 / 1024**3 # Set to 2GB, suitable for CuDNN if memory_count > 2: From c586b4ba949d1f33c68548f6919317ef58d7d4dd Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Tue, 21 Jan 2025 11:52:58 +0000 Subject: [PATCH 02/11] replace F.pad by built-in padding in Conv3D --- .../models/autoencoders/autoencoder_kl_cogvideox.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index a797ad6aed6c..e796edf07bdb 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -103,6 +103,7 @@ def __init__( self.width_pad = width_pad self.time_pad = time_pad self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + self.const_padding_conv3d = (0, self.width_pad, self.height_pad) self.temporal_dim = 2 self.time_kernel_size = time_kernel_size @@ -115,6 +116,8 @@ def __init__( kernel_size=kernel_size, stride=stride, dilation=dilation, + padding = 0 if self.pad_mode == 'replicate' else self.const_padding_conv3d, + padding_mode = 'zeros', ) def fake_context_parallel_forward( @@ -135,9 +138,7 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non if self.pad_mode == "replicate": conv_cache = None else: - padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() - inputs = F.pad(inputs, padding_2d, mode="constant", value=0) output = self.conv(inputs) return output, conv_cache From 272537b32e4192128aead5c5d14ba2defecd0553 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Tue, 21 Jan 2025 11:53:29 +0000 Subject: [PATCH 03/11] in-place sums to reduce memory allocations --- src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index e796edf07bdb..1986afc6e630 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -303,7 +303,7 @@ def forward( hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) if temb is not None: - hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + hidden_states.add_(self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]) if zq is not None: hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2")) @@ -322,7 +322,7 @@ def forward( else: inputs = self.conv_shortcut(inputs) - hidden_states = hidden_states + inputs + hidden_states.add_(inputs) return hidden_states, new_conv_cache From b0c826f7df8ce281eede1b9a42b5f576ccb7efa2 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Tue, 21 Jan 2025 20:56:49 +0000 Subject: [PATCH 04/11] fixed trailing whitespace --- src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 1986afc6e630..62e9306d20bd 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -103,7 +103,7 @@ def __init__( self.width_pad = width_pad self.time_pad = time_pad self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) - self.const_padding_conv3d = (0, self.width_pad, self.height_pad) + self.const_padding_conv3d = (0, self.width_pad, self.height_pad) self.temporal_dim = 2 self.time_kernel_size = time_kernel_size From 249a8a275dda0f4bf29328dbebe11228ce69fe1c Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Tue, 21 Jan 2025 22:13:45 +0000 Subject: [PATCH 05/11] file reformatted --- .../models/autoencoders/autoencoder_kl_cogvideox.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 62e9306d20bd..9dec58552b63 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -103,7 +103,7 @@ def __init__( self.width_pad = width_pad self.time_pad = time_pad self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) - self.const_padding_conv3d = (0, self.width_pad, self.height_pad) + self.const_padding_conv3d = (0, self.width_pad, self.height_pad) self.temporal_dim = 2 self.time_kernel_size = time_kernel_size @@ -116,8 +116,8 @@ def __init__( kernel_size=kernel_size, stride=stride, dilation=dilation, - padding = 0 if self.pad_mode == 'replicate' else self.const_padding_conv3d, - padding_mode = 'zeros', + padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d, + padding_mode="zeros", ) def fake_context_parallel_forward( From a0c5ab238343f5e93318b8aaa99c7177b650d1e9 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Fri, 24 Jan 2025 16:22:35 +0000 Subject: [PATCH 06/11] in-place sums --- .../models/transformers/cogvideox_transformer_3d.py | 8 ++++---- .../models/transformers/transformer_cogview3plus.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 51634780692d..32a5d11bf629 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -138,8 +138,8 @@ def forward( **attention_kwargs, ) - hidden_states = hidden_states + gate_msa * attn_hidden_states - encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + hidden_states.add_(gate_msa * attn_hidden_states) + encoder_hidden_states.add_(enc_gate_msa * attn_encoder_hidden_states) # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( @@ -150,8 +150,8 @@ def forward( norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] - encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + hidden_states.add_(gate_ff * ff_output[:, text_seq_length:]) + encoder_hidden_states.add_(enc_gate_ff * ff_output[:, :text_seq_length]) return hidden_states, encoder_hidden_states diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 369509a3a35e..57a6f33306a1 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -106,8 +106,8 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states ) - hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states - encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states + hidden_states.add_(gate_msa.unsqueeze(1) * attn_hidden_states) + encoder_hidden_states.add_(c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states) # norm & modulate norm_hidden_states = self.norm2(hidden_states) @@ -120,8 +120,8 @@ def forward( norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] + hidden_states.add_(gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]) + encoder_hidden_states.add_(c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]) if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) From 774700789e4aca2a451012307c98adcb4c94eec7 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Sat, 25 Jan 2025 11:14:22 +0000 Subject: [PATCH 07/11] simpler in-place expressions --- .../models/autoencoders/autoencoder_kl_cogvideox.py | 4 ++-- .../models/transformers/cogvideox_transformer_3d.py | 8 ++++---- .../models/transformers/transformer_cogview3plus.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 9dec58552b63..a78b8a712fe8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -303,7 +303,7 @@ def forward( hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) if temb is not None: - hidden_states.add_(self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]) + hidden_states += self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] if zq is not None: hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2")) @@ -322,7 +322,7 @@ def forward( else: inputs = self.conv_shortcut(inputs) - hidden_states.add_(inputs) + hidden_states += inputs return hidden_states, new_conv_cache diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 32a5d11bf629..a71feb0a5cf3 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -138,8 +138,8 @@ def forward( **attention_kwargs, ) - hidden_states.add_(gate_msa * attn_hidden_states) - encoder_hidden_states.add_(enc_gate_msa * attn_encoder_hidden_states) + hidden_states += gate_msa * attn_hidden_states + encoder_hidden_states += enc_gate_msa * attn_encoder_hidden_states # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( @@ -150,8 +150,8 @@ def forward( norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) - hidden_states.add_(gate_ff * ff_output[:, text_seq_length:]) - encoder_hidden_states.add_(enc_gate_ff * ff_output[:, :text_seq_length]) + hidden_states += gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states += enc_gate_ff * ff_output[:, :text_seq_length] return hidden_states, encoder_hidden_states diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 57a6f33306a1..b0696daab62c 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -106,8 +106,8 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states ) - hidden_states.add_(gate_msa.unsqueeze(1) * attn_hidden_states) - encoder_hidden_states.add_(c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states) + hidden_states += gate_msa.unsqueeze(1) * attn_hidden_states + encoder_hidden_states += c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states # norm & modulate norm_hidden_states = self.norm2(hidden_states) @@ -120,8 +120,8 @@ def forward( norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) - hidden_states.add_(gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]) - encoder_hidden_states.add_(c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]) + hidden_states += gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] + encoder_hidden_states += c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) From 4bd9e991849a7b4b91b963da297956d1a01f2534 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Sat, 25 Jan 2025 11:22:36 +0000 Subject: [PATCH 08/11] removed in-place sum, may affect backward propagation logic --- .../models/transformers/cogvideox_transformer_3d.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index a71feb0a5cf3..51634780692d 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -138,8 +138,8 @@ def forward( **attention_kwargs, ) - hidden_states += gate_msa * attn_hidden_states - encoder_hidden_states += enc_gate_msa * attn_encoder_hidden_states + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( @@ -150,8 +150,8 @@ def forward( norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) - hidden_states += gate_ff * ff_output[:, text_seq_length:] - encoder_hidden_states += enc_gate_ff * ff_output[:, :text_seq_length] + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] return hidden_states, encoder_hidden_states From 2ce09f8a7f38b4476e618d8d69b9950905170073 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Sat, 25 Jan 2025 11:23:25 +0000 Subject: [PATCH 09/11] removed in-place sum, may affect backward propagation logic --- .../models/transformers/transformer_cogview3plus.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index b0696daab62c..369509a3a35e 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -106,8 +106,8 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states ) - hidden_states += gate_msa.unsqueeze(1) * attn_hidden_states - encoder_hidden_states += c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states # norm & modulate norm_hidden_states = self.norm2(hidden_states) @@ -120,8 +120,8 @@ def forward( norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) - hidden_states += gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] - encoder_hidden_states += c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) From 8db178f79f570487d5796a7520bf48528d72208b Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Sat, 25 Jan 2025 11:24:02 +0000 Subject: [PATCH 10/11] removed in-place sum, may affect backward propagation logic --- src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index a78b8a712fe8..9d0a71fd79f5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -303,7 +303,7 @@ def forward( hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) if temb is not None: - hidden_states += self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] if zq is not None: hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2")) @@ -322,7 +322,7 @@ def forward( else: inputs = self.conv_shortcut(inputs) - hidden_states += inputs + hidden_states = hidden_states + inputs return hidden_states, new_conv_cache From 0e1a22a33c042e55d6ae0f22a644768103824355 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Thu, 13 Mar 2025 01:32:11 +0000 Subject: [PATCH 11/11] reverted change --- src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index fea316687756..e2b26396899f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d): """ def forward(self, input: torch.Tensor) -> torch.Tensor: - memory_count = torch.prod(torch.tensor(input.shape)) * 2 / 1024**3 + memory_count = ( + (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3 + ) # Set to 2GB, suitable for CuDNN if memory_count > 2: