Skip to content

Commit fe2b397

Browse files
authored
remove unnecessary call to F.pad (#10620)
* rewrite memory count without implicitly using dimensions (by @ic-synth)
* replace F.pad by built-in padding in Conv3D
* in-place sums to reduce memory allocations
* fixed trailing whitespace
* file reformatted
* in-place sums
* simpler in-place expressions
* removed in-place sum, may affect backward propagation logic
* removed in-place sum, may affect backward propagation logic
* removed in-place sum, may affect backward propagation logic
* reverted change
1 parent be0b7f5 commit fe2b397

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def __init__(
105105
self.width_pad = width_pad
106106
self.time_pad = time_pad
107107
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
108+
self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
108109

109110
self.temporal_dim = 2
110111
self.time_kernel_size = time_kernel_size
@@ -117,6 +118,8 @@ def __init__(
117118
kernel_size=kernel_size,
118119
stride=stride,
119120
dilation=dilation,
121+
padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
122+
padding_mode="zeros",
120123
)
121124

122125
def fake_context_parallel_forward(
@@ -137,9 +140,7 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non
137140
if self.pad_mode == "replicate":
138141
conv_cache = None
139142
else:
140-
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
141143
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
142-
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
143144

144145
output = self.conv(inputs)
145146
return output, conv_cache

0 commit comments

Comments (0)