Commit 5878e2f

implement tiled encode/decode

1 parent edd7880 · commit 5878e2f
File tree: 2 files changed (+263 −39 lines)

src/diffusers/models/autoencoders/autoencoder_kl_wan.py (+217 −38)
@@ -677,42 +677,7 @@ def __init__(
         attn_scales: List[float] = [],
         temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
-        latents_mean: List[float] = [
-            -0.7571,
-            -0.7089,
-            -0.9113,
-            0.1075,
-            -0.1745,
-            0.9653,
-            -0.1517,
-            1.5508,
-            0.4134,
-            -0.0715,
-            0.5517,
-            -0.3632,
-            -0.1922,
-            -0.9497,
-            0.2503,
-            -0.2921,
-        ],
-        latents_std: List[float] = [
-            2.8184,
-            1.4541,
-            2.3275,
-            2.6558,
-            1.2196,
-            1.7708,
-            2.6052,
-            2.0743,
-            3.2687,
-            2.1526,
-            2.8652,
-            1.5579,
-            1.6382,
-            1.1253,
-            2.8251,
-            1.9160,
-        ],
+        spatial_compression_ratio: int = 8,
     ) -> None:
         super().__init__()
 
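The constructor drops the hard-coded `latents_mean` / `latents_std` defaults and gains a `spatial_compression_ratio` argument. As a quick orientation for the tiling logic added below, here is a minimal arithmetic sketch (my own, assuming the default ratio of 8 and the default tile sizes introduced in the next hunk) of how sample-space tile sizes map to latent space:

# Sketch only: assumes the defaults used elsewhere in this commit.
spatial_compression_ratio = 8
tile_sample_min = 256     # smallest sample-space tile edge (height or width)
tile_sample_stride = 192  # sample-space stride between neighbouring tiles

tile_latent_min = tile_sample_min // spatial_compression_ratio        # 32 latent pixels per tile side
tile_latent_stride = tile_sample_stride // spatial_compression_ratio  # 24 latent pixels of stride
latent_blend = tile_latent_min - tile_latent_stride                   # 8 latent rows/cols blended by tiled_encode
sample_blend = tile_sample_min - tile_sample_stride                   # 64 sample rows/cols blended by tiled_decode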

@@ -730,6 +695,58 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )
 
+        self.spatial_compression_ratio = spatial_compression_ratio
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
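For callers, `enable_tiling` / `disable_tiling` follow the same pattern as the other diffusers VAEs. A usage sketch, assuming `vae` is an already loaded `AutoencoderKLWan` instance and `video` is a `(batch, channels, frames, height, width)` tensor (both placeholders, not part of this diff):

import torch

# `vae` is assumed to be an AutoencoderKLWan loaded elsewhere; `video` is a
# (B, C, T, H, W) float tensor in the range the VAE expects.
vae.enable_tiling(
    tile_sample_min_height=256,     # tiling only kicks in above this sample size
    tile_sample_min_width=256,
    tile_sample_stride_height=192,  # smaller stride: more overlap and compute, fewer seams
    tile_sample_stride_width=192,
)

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    reconstruction = vae.decode(latents).sample

vae.disable_tiling()  # back to single-pass encode/decode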
@@ -785,7 +802,11 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        _, _, _, height, width = x.shape
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            h = self.tiled_encode(x)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
         if not return_dict:
             return (posterior,)
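With the defaults, the new branch in `encode` only engages when the input is larger than a single tile. A small sketch of that predicate (illustrative numbers, not from the diff):

# Dispatch condition used by encode(), assuming the default 256-pixel tiles.
use_tiling = True
tile_sample_min_height = tile_sample_min_width = 256

height, width = 640, 480   # larger than one tile -> tiled_encode() path
print(use_tiling and (width > tile_sample_min_width or height > tile_sample_min_height))  # True

height, width = 256, 256   # fits in a single tile -> plain _encode() path
print(use_tiling and (width > tile_sample_min_width or height > tile_sample_min_height))  # False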
@@ -826,12 +847,170 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        _, _, _, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            decoded = self.tiled_decode(z).sample
+        else:
+            decoded = self._decode(z).sample
         if not return_dict:
             return (decoded,)
 
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
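The `blend_h` / `blend_v` / `blend_t` helpers implement a linear crossfade over the overlap between neighbouring tiles, which is what hides the seams. A self-contained sketch of the same idea (mirroring `blend_h` above on a toy tensor; not part of the diff):

import torch


def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Crossfade the last `blend_extent` columns of tile `a` into the first
    # `blend_extent` columns of tile `b` for (B, C, T, H, W) tensors.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
            x / blend_extent
        )
    return b


left = torch.zeros(1, 1, 1, 2, 8)   # left tile, all zeros
right = torch.ones(1, 1, 1, 2, 8)   # right tile, all ones
blended = blend_h(left, right, blend_extent=4)
print(blended[0, 0, 0, 0])          # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])

After blending, only the stride-sized top-left crop of each tile is appended (`tile[..., :stride_height, :stride_width]`), so the concatenated result contains each output pixel exactly once before the final crop to the target height and width.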

tests/models/autoencoders/test_models_autoencoder_wan.py (+46 −1)
@@ -15,6 +15,8 @@
 
 import unittest
 
+import torch
+
 from diffusers import AutoencoderKLWan
 from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
 
@@ -44,9 +46,16 @@ def dummy_input(self):
         num_frames = 9
         num_channels = 3
         sizes = (16, 16)
-
         image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+        return {"sample": image}
 
+    @property
+    def dummy_input_tiling(self):
+        batch_size = 2
+        num_frames = 9
+        num_channels = 3
+        sizes = (640, 480)
+        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
         return {"sample": image}
 
     @property
@@ -62,6 +71,42 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    def prepare_init_args_and_inputs_for_tiling(self):
+        init_dict = self.get_autoencoder_kl_wan_config()
+        inputs_dict = self.dummy_input_tiling
+        return init_dict, inputs_dict
+
+    def test_enable_disable_tiling(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict).to(torch_device)
+
+        inputs_dict.update({"return_dict": False})
+
+        torch.manual_seed(0)
+        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        torch.manual_seed(0)
+        model.enable_tiling()
+        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertLess(
+            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+            0.5,
+            "VAE tiling should not affect the inference results",
+        )
+
+        torch.manual_seed(0)
+        model.disable_tiling()
+        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertEqual(
+            output_without_tiling.detach().cpu().numpy().all(),
+            output_without_tiling_2.detach().cpu().numpy().all(),
+            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+        )
+
     @unittest.skip("Gradient checkpointing has not been implemented yet")
     def test_gradient_checkpointing_is_applied(self):
        pass
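The new test runs the model on a 640x480 input with and without tiling and only requires the two outputs to stay within a loose 0.5 maximum difference, since the crossfade necessarily perturbs values near the seams. For a sense of the work involved, a rough tile count for that input under the default 192-pixel stride (my own arithmetic, not part of the test):

import math

height, width = 640, 480
stride = 192
tiles_h = math.ceil(height / stride)  # 4 tile rows
tiles_w = math.ceil(width / stride)   # 3 tile columns
print(tiles_h * tiles_w)              # 12 overlapping spatial tiles per encode/decode pass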
