This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 866e282

virginiafdez authored

Modify controlnet inferer to pass the same conditioning as the one th… (#477)

* Modify the controlnet inferer to pass the same conditioning as the one the diffusion model is getting, and modify the tests accordingly.
* Uncomment the controlnet inferer tests and fix them; these should be running now.
* Re-format the test script and fix naming issues.

Co-authored-by: virginiafdez <[email protected]>
1 parent 269d61d · commit 866e282

File tree

4 files changed: +94 -77 lines changed

generative/inferers/inferer.py

Lines changed: 44 additions & 41 deletions

@@ -459,7 +459,8 @@ def sample(
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
             if save_intermediates:
                 latent_intermediates = [
-                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
                 ]
 
         decode = autoencoder_model.decode_stage_2_outputs

@@ -592,13 +593,15 @@ def __call__(
             raise NotImplementedError(f"{mode} condition is not supported")
 
         noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-        down_block_res_samples, mid_block_res_sample = controlnet(
-            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
-        )
+
         if mode == "concat":
             noisy_image = torch.cat([noisy_image, condition], dim=1)
             condition = None
 
+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
+        )
+
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
             diffuse = partial(diffusion_model, seg=seg)

@@ -654,32 +657,32 @@ def sample(
         progress_bar = iter(scheduler.timesteps)
         intermediates = []
         for t in progress_bar:
+            if mode == "concat":
+                model_input = torch.cat([image, conditioning], dim=1)
+                context_ = None
+            else:
+                model_input = image
+                context_ = conditioning
+
             # 1. ControlNet forward
             down_block_res_samples, mid_block_res_sample = controlnet(
-                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
+                x=model_input,
+                timesteps=torch.Tensor((t,)).to(input_noise.device),
+                controlnet_cond=cn_cond,
+                context=context_,
             )
             # 2. predict noise model_output
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
 
-            if mode == "concat":
-                model_input = torch.cat([image, conditioning], dim=1)
-                model_output = diffuse(
-                    model_input,
-                    timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=None,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                )
-            else:
-                model_output = diffuse(
-                    image,
-                    timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=conditioning,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                )
+            model_output = diffuse(
+                model_input,
+                timesteps=torch.Tensor((t,)).to(input_noise.device),
+                context=context_,
+                down_block_additional_residuals=down_block_res_samples,
+                mid_block_additional_residual=mid_block_res_sample,
+            )
 
             # 3. compute previous image: x_t -> x_t-1
             image, _ = scheduler.step(model_output, t, image)

@@ -743,31 +746,30 @@ def get_likelihood(
         for t in progress_bar:
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+
+            if mode == "concat":
+                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                conditioning = None
+
             down_block_res_samples, mid_block_res_sample = controlnet(
-                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
+                x=noisy_image,
+                timesteps=torch.Tensor((t,)).to(inputs.device),
+                controlnet_cond=cn_cond,
+                context=conditioning,
             )
 
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
 
-            if mode == "concat":
-                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
-                model_output = diffuse(
-                    noisy_image,
-                    timesteps=timesteps,
-                    context=None,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                )
-            else:
-                model_output = diffuse(
-                    x=noisy_image,
-                    timesteps=timesteps,
-                    context=conditioning,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                )
+            model_output = diffuse(
+                noisy_image,
+                timesteps=timesteps,
+                context=conditioning,
+                down_block_additional_residuals=down_block_res_samples,
+                mid_block_additional_residual=mid_block_res_sample,
+            )
+
             # get the model's predicted mean, and variance if it is predicted
             if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
                 model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)

@@ -994,7 +996,8 @@ def sample(
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
             if save_intermediates:
                 latent_intermediates = [
-                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
                 ]
 
         decode = autoencoder_model.decode_stage_2_outputs
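Taken together, the inferer changes make __call__, sample, and get_likelihood hand the ControlNet the same model input and the same cross-attention context that the diffusion model receives, rather than always calling it unconditioned on the raw noisy image. Below is a minimal sketch of a call to the updated inferer. It assumes the generative.* import paths of this repository; the constructor arguments follow the shape of the repo's test cases, but the exact values, shapes, and the conditioning_embedding_num_channels argument are illustrative assumptions, not something this commit prescribes.

import torch

from generative.inferers import ControlNetDiffusionInferer
from generative.networks.nets import ControlNet, DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

# Shared settings: both networks must now agree on conditioning, since the
# ControlNet receives the same cross-attention context as the diffusion model.
common = dict(
    spatial_dims=2,
    in_channels=1,
    num_channels=[8],
    norm_num_groups=8,
    attention_levels=[True],
    num_res_blocks=1,
    num_head_channels=8,
    with_conditioning=True,
    cross_attention_dim=3,
)
model = DiffusionModelUNet(out_channels=1, **common)
controlnet = ControlNet(conditioning_embedding_num_channels=[16], **common)

inferer = ControlNetDiffusionInferer(scheduler=DDPMScheduler(num_train_timesteps=1000))

images = torch.randn(2, 1, 16, 16)      # training batch
noise = torch.randn_like(images)
masks = torch.randn(2, 1, 16, 16)       # ControlNet spatial conditioning (cn_cond)
context = torch.randn(2, 1, 3)          # cross-attention context: (batch, seq, cross_attention_dim)
timesteps = torch.randint(0, 1000, (2,)).long()

# After this commit, `condition` is forwarded to the ControlNet as `context` too.
prediction = inferer(
    inputs=images,
    diffusion_model=model,
    controlnet=controlnet,
    noise=noise,
    timesteps=timesteps,
    cn_cond=masks,
    condition=context,
    mode="crossattn",
)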

generative/networks/schedulers/ddim.py

Lines changed: 3 additions & 1 deletion

@@ -257,7 +257,9 @@ def reversed_step(
 
         # 2. compute alphas, betas at timestep t+1
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod
+        alpha_prod_t_next = (
+            self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod
+        )
 
         beta_prod_t = 1 - alpha_prod_t

tests/test_controlnet_inferers.py

Lines changed: 10 additions & 4 deletions

@@ -537,8 +537,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
 
     @parameterized.expand(CNDM_TEST_CASES)
     def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
-        model_params["with_conditioning"] = True
-        model_params["cross_attention_dim"] = 3
+        model_params["with_conditioning"] = controlnet_params["with_conditioning"] = True
+        model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = 3
         model = DiffusionModelUNet(**model_params)
         controlnet = ControlNet(**controlnet_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -603,10 +603,12 @@ def test_normal_cdf(self):
     def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape):
         # copy the model_params dict to prevent from modifying test cases
         model_params = model_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 2
         model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
-        model_params["cross_attention_dim"] = None
-        model_params["with_conditioning"] = False
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
+        model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = None
+        model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False
         model = DiffusionModelUNet(**model_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
         model.to(device)

@@ -986,8 +988,10 @@ def test_prediction_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:

@@ -1066,8 +1070,10 @@ def test_sample_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
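All four touched tests follow the same pattern: any channel or conditioning change applied to the diffusion UNet must now be mirrored on the ControlNet, because both networks consume the same (possibly concatenated) input. A condensed, self-contained sketch of that pairing follows; the base config is an illustrative stand-in for the repo's CNDM test cases, not copied from them.

# The ControlNet now receives the concatenated model input, so every channel
# and conditioning change made to the UNet params must be mirrored on it.
base_params = {
    "spatial_dims": 2,
    "in_channels": 1,
    "num_channels": [8],
    "norm_num_groups": 8,
    "attention_levels": [True],
    "num_res_blocks": 1,
    "num_head_channels": 8,
}

model_params = dict(base_params, out_channels=1)
controlnet_params = dict(base_params)

n_concat_channel = 2
model_params["in_channels"] += n_concat_channel
controlnet_params["in_channels"] += n_concat_channel
# Cross-attention settings must also stay in sync between the two networks:
model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = None
model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False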

tutorials/generative/2d_controlnet/2d_controlnet.py

Lines changed: 37 additions & 31 deletions

@@ -211,7 +211,6 @@
 inferer = DiffusionInferer(scheduler)
 
 
-
 # %% [markdown]
 # ### Run training
 #

@@ -348,10 +347,14 @@
                 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
             ).long()
 
-            noise_pred = controlnet_inferer(inputs = images, diffusion_model = model,
-                                            controlnet = controlnet, noise = noise,
-                                            timesteps = timesteps,
-                                            cn_cond = masks, )
+            noise_pred = controlnet_inferer(
+                inputs=images,
+                diffusion_model=model,
+                controlnet=controlnet,
+                noise=noise,
+                timesteps=timesteps,
+                cn_cond=masks,
+            )
 
         loss = F.mse_loss(noise_pred.float(), noise.float())
 

@@ -378,13 +381,16 @@
                     0, controlnet_inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
                 ).long()
 
-                noise_pred = controlnet_inferer(inputs = images, diffusion_model = model,
-                                                controlnet = controlnet, noise = noise,
-                                                timesteps = timesteps,
-                                                cn_cond = masks, )
+                noise_pred = controlnet_inferer(
+                    inputs=images,
+                    diffusion_model=model,
+                    controlnet=controlnet,
+                    noise=noise,
+                    timesteps=timesteps,
+                    cn_cond=masks,
+                )
             val_loss = F.mse_loss(noise_pred.float(), noise.float())
 
-
             val_epoch_loss += val_loss.item()
 
             progress_bar.set_postfix({"val_loss": val_epoch_loss / (step + 1)})

@@ -398,30 +404,30 @@
     with autocast(enabled=True):
         noise = torch.randn((1, 1, 64, 64)).to(device)
         sample = controlnet_inferer.sample(
-            input_noise = noise,
-            diffusion_model = model,
-            controlnet = controlnet,
-            cn_cond = masks[0, None, ...],
-            scheduler = scheduler,
+            input_noise=noise,
+            diffusion_model=model,
+            controlnet=controlnet,
+            cn_cond=masks[0, None, ...],
+            scheduler=scheduler,
         )
 
         # Without using an inferer:
-        # progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110)
-        # progress_bar_sampling.set_description("sampling...")
-        # sample = torch.randn((1, 1, 64, 64)).to(device)
-        # for t in progress_bar_sampling:
-        #     with torch.no_grad():
-        #         with autocast(enabled=True):
-        #             down_block_res_samples, mid_block_res_sample = controlnet(
-        #                 x=sample, timesteps=torch.Tensor((t,)).to(device).long(), controlnet_cond=masks[0, None, ...]
-        #             )
-        #             noise_pred = model(
-        #                 sample,
-        #                 timesteps=torch.Tensor((t,)).to(device),
-        #                 down_block_additional_residuals=down_block_res_samples,
-        #                 mid_block_additional_residual=mid_block_res_sample,
-        #             )
-        #             sample, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=sample)
+        # progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110)
+        # progress_bar_sampling.set_description("sampling...")
+        # sample = torch.randn((1, 1, 64, 64)).to(device)
+        # for t in progress_bar_sampling:
+        #     with torch.no_grad():
+        #         with autocast(enabled=True):
+        #             down_block_res_samples, mid_block_res_sample = controlnet(
+        #                 x=sample, timesteps=torch.Tensor((t,)).to(device).long(), controlnet_cond=masks[0, None, ...]
+        #             )
+        #             noise_pred = model(
+        #                 sample,
+        #                 timesteps=torch.Tensor((t,)).to(device),
+        #                 down_block_additional_residuals=down_block_res_samples,
+        #                 mid_block_additional_residual=mid_block_res_sample,
+        #             )
+        #             sample, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=sample)
 
 plt.subplots(1, 2, figsize=(4, 2))
 plt.subplot(1, 2, 1)
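The tutorial trains an unconditioned model, so its calls only change formatting here. For a conditioned setup, the updated sample signature forwards conditioning and mode to both networks. A hypothetical sketch, reusing the tutorial's model, controlnet, masks, scheduler, and device variables and assuming both networks were instead constructed with with_conditioning=True and cross_attention_dim=3:

# Hypothetical sketch, not part of the tutorial: sampling with cross-attention
# conditioning after this commit. `context` is a placeholder tensor.
with torch.no_grad():
    with autocast(enabled=True):
        noise = torch.randn((1, 1, 64, 64)).to(device)
        context = torch.randn((1, 1, 3)).to(device)  # (batch, seq, cross_attention_dim)
        sample = controlnet_inferer.sample(
            input_noise=noise,
            diffusion_model=model,
            controlnet=controlnet,
            cn_cond=masks[0, None, ...],
            scheduler=scheduler,
            conditioning=context,  # now reaches the ControlNet as context, too
            mode="crossattn",
        )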
