@@ -459,7 +459,8 @@ def sample(
459 459 latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
460 460 if save_intermediates:
461 461 latent_intermediates = [
462 - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
462 + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
463 + for l in latent_intermediates
463 464 ]
464 465
465 466 decode = autoencoder_model.decode_stage_2_outputs
@@ -592,13 +593,15 @@ def __call__(
592 593 raise NotImplementedError(f"{mode} condition is not supported")
593 594
594 595 noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
595 - down_block_res_samples, mid_block_res_sample = controlnet(
596 - x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
597 - )
596 +
598 597 if mode == "concat":
599 598 noisy_image = torch.cat([noisy_image, condition], dim=1)
600 599 condition = None
601 600
601 + down_block_res_samples, mid_block_res_sample = controlnet(
602 + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
603 + )
604 +
602 605 diffuse = diffusion_model
603 606 if isinstance(diffusion_model, SPADEDiffusionModelUNet):
604 607 diffuse = partial(diffusion_model, seg=seg)
@@ -654,32 +657,32 @@ def sample(
654 657 progress_bar = iter(scheduler.timesteps)
655 658 intermediates = []
656 659 for t in progress_bar:
660 + if mode == "concat":
661 + model_input = torch.cat([image, conditioning], dim=1)
662 + context_ = None
663 + else:
664 + model_input = image
665 + context_ = conditioning
666 +
657 667 # 1. ControlNet forward
658 668 down_block_res_samples, mid_block_res_sample = controlnet(
659 - x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
669 + x=model_input,
670 + timesteps=torch.Tensor((t,)).to(input_noise.device),
671 + controlnet_cond=cn_cond,
672 + context=context_,
660 673 )
661 674 # 2. predict noise model_output
662 675 diffuse = diffusion_model
663 676 if isinstance(diffusion_model, SPADEDiffusionModelUNet):
664 677 diffuse = partial(diffusion_model, seg=seg)
665 678
666 - if mode == "concat":
667 - model_input = torch.cat([image, conditioning], dim=1)
668 - model_output = diffuse(
669 - model_input,
670 - timesteps=torch.Tensor((t,)).to(input_noise.device),
671 - context=None,
672 - down_block_additional_residuals=down_block_res_samples,
673 - mid_block_additional_residual=mid_block_res_sample,
674 - )
675 - else:
676 - model_output = diffuse(
677 - image,
678 - timesteps=torch.Tensor((t,)).to(input_noise.device),
679 - context=conditioning,
680 - down_block_additional_residuals=down_block_res_samples,
681 - mid_block_additional_residual=mid_block_res_sample,
682 - )
679 + model_output = diffuse(
680 + model_input,
681 + timesteps=torch.Tensor((t,)).to(input_noise.device),
682 + context=context_,
683 + down_block_additional_residuals=down_block_res_samples,
684 + mid_block_additional_residual=mid_block_res_sample,
685 + )
683 686
684 687 # 3. compute previous image: x_t -> x_t-1
685 688 image, _ = scheduler.step(model_output, t, image)
@@ -743,31 +746,30 @@ def get_likelihood(
743 746 for t in progress_bar:
744 747 timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
745 748 noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
749 +
750 + if mode == "concat":
751 + noisy_image = torch.cat([noisy_image, conditioning], dim=1)
752 + conditioning = None
753 +
746 754 down_block_res_samples, mid_block_res_sample = controlnet(
747 - x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
755 + x=noisy_image,
756 + timesteps=torch.Tensor((t,)).to(inputs.device),
757 + controlnet_cond=cn_cond,
758 + context=conditioning,
748 759 )
749 760
750 761 diffuse = diffusion_model
751 762 if isinstance(diffusion_model, SPADEDiffusionModelUNet):
752 763 diffuse = partial(diffusion_model, seg=seg)
753 764
754 - if mode == "concat":
755 - noisy_image = torch.cat([noisy_image, conditioning], dim=1)
756 - model_output = diffuse(
757 - noisy_image,
758 - timesteps=timesteps,
759 - context=None,
760 - down_block_additional_residuals=down_block_res_samples,
761 - mid_block_additional_residual=mid_block_res_sample,
762 - )
763 - else:
764 - model_output = diffuse(
765 - x=noisy_image,
766 - timesteps=timesteps,
767 - context=conditioning,
768 - down_block_additional_residuals=down_block_res_samples,
769 - mid_block_additional_residual=mid_block_res_sample,
770 - )
765 + model_output = diffuse(
766 + noisy_image,
767 + timesteps=timesteps,
768 + context=conditioning,
769 + down_block_additional_residuals=down_block_res_samples,
770 + mid_block_additional_residual=mid_block_res_sample,
771 + )
772 +
771 773 # get the model's predicted mean, and variance if it is predicted
772 774 if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
773 775 model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
@@ -994,7 +996,8 @@ def sample(
994 996 latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
995 997 if save_intermediates:
996 998 latent_intermediates = [
997 - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
999 + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1000 + for l in latent_intermediates
998 1001 ]
999 1002
1000 1003 decode = autoencoder_model.decode_stage_2_outputs
0 commit comments