@@ -677,42 +677,7 @@ def __init__(
         attn_scales: List[float] = [],
         temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
-        latents_mean: List[float] = [
-            -0.7571,
-            -0.7089,
-            -0.9113,
-            0.1075,
-            -0.1745,
-            0.9653,
-            -0.1517,
-            1.5508,
-            0.4134,
-            -0.0715,
-            0.5517,
-            -0.3632,
-            -0.1922,
-            -0.9497,
-            0.2503,
-            -0.2921,
-        ],
-        latents_std: List[float] = [
-            2.8184,
-            1.4541,
-            2.3275,
-            2.6558,
-            1.2196,
-            1.7708,
-            2.6052,
-            2.0743,
-            3.2687,
-            2.1526,
-            2.8652,
-            1.5579,
-            1.6382,
-            1.1253,
-            2.8251,
-            1.9160,
-        ],
+        spatial_compression_ratio: int = 8,
     ) -> None:
         super().__init__()

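For orientation (not part of the commit), a hedged arithmetic sketch of how the new spatial_compression_ratio argument, with the defaults this diff introduces, converts the pixel-space tile settings into the latent-space sizes used further down:

# Hedged sketch: relation between sample-space tile settings and latent-space tile sizes.
spatial_compression_ratio = 8          # default added to __init__ in this diff
tile_sample_min_height = 256           # default tile height in pixel space
tile_sample_stride_height = 192        # default tile stride in pixel space

tile_latent_min_height = tile_sample_min_height // spatial_compression_ratio        # 32 latent rows per tile
tile_latent_stride_height = tile_sample_stride_height // spatial_compression_ratio  # 24 latent rows per step
blend_height = tile_latent_min_height - tile_latent_stride_height                   # 8 overlapping latent rows to blend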
@@ -730,6 +695,58 @@ def __init__(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )

+        self.spatial_compression_ratio = spatial_compression_ratio
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[int] = None,
+        tile_sample_stride_width: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
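As a usage sketch (not part of this diff; the checkpoint id and tensor shape below are placeholders), enabling and disabling the new tiling path from user code might look like this:

import torch
from diffusers import AutoencoderKLWan

# Placeholder repo id; any Wan VAE checkpoint loaded through this class would work the same way.
vae = AutoencoderKLWan.from_pretrained("<wan-repo-id>", subfolder="vae", torch_dtype=torch.float32)

# The arguments are optional; omitting them keeps the defaults set in __init__ (256 / 192).
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)

latents = torch.randn(1, 16, 5, 60, 104)  # (B, C, T, H, W) latent tensor, hypothetical shape
with torch.no_grad():
    video = vae.decode(latents).sample    # routed through tiled_decode, since 60 and 104 exceed 256 // 8

vae.disable_tiling()                      # back to single-pass decoding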
@@ -785,7 +802,11 @@ def encode(
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self._encode(x)
+        _, _, _, height, width = x.shape
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            h = self.tiled_encode(x)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
         if not return_dict:
             return (posterior,)
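Continuing the sketch above on the encode side (input shape again hypothetical): note that encode compares the pixel-space height and width against tile_sample_min_height / tile_sample_min_width, not their latent-space counterparts.

video = torch.randn(1, 3, 17, 480, 832)        # (B, C, T, H, W) in pixel space, hypothetical shape
with torch.no_grad():
    posterior = vae.encode(video).latent_dist  # 480 > 256 and 832 > 256, so tiled_encode is taken
latents = posterior.sample()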
@@ -826,12 +847,170 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-        decoded = self._decode(z).sample
+        _, _, _, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            decoded = self.tiled_decode(z).sample
+        else:
+            decoded = self._decode(z).sample
         if not return_dict:
             return (decoded,)

         return DecoderOutput(sample=decoded)

+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of videos using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
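Finally, a small standalone sketch (toy tensors, not part of the commit) of the linear crossfade that the new blend_h / blend_v helpers apply over the overlapping region between neighbouring tiles:

import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Same arithmetic as the method added above: fade a's right edge into b's left edge.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        b[..., x] = a[..., -blend_extent + x] * (1 - x / blend_extent) + b[..., x] * (x / blend_extent)
    return b

left = torch.zeros(1, 1, 1, 1, 4)   # right edge of the already-decoded left tile
right = torch.ones(1, 1, 1, 1, 4)   # left edge of the current tile
print(blend_h(left, right, blend_extent=4))
# tensor([[[[[0.0000, 0.2500, 0.5000, 0.7500]]]]]) -- a linear ramp across the 4 overlapping columns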