diff --git a/cyto_dl/models/im2im/diffusion_autoencoder.py b/cyto_dl/models/im2im/diffusion_autoencoder.py index 6d3c4da2..8478af20 100644 --- a/cyto_dl/models/im2im/diffusion_autoencoder.py +++ b/cyto_dl/models/im2im/diffusion_autoencoder.py @@ -156,15 +156,21 @@ def forward(self, x_cond, x_diff): ) loss_weight = self._get_loss_weight(timesteps) - # latent is B x C x 1 - latent = self.semantic_encoder(x_cond).unsqueeze(2) + # Encode condition: (B, lat_dim) + latent = self.semantic_encoder(x_cond) + + # (B, 1, lat_dim) for cross-attention + # AdaGN will internally .squeeze(1) + condition = latent.unsqueeze(1) + noise_pred = self.inferer( inputs=x_diff, diffusion_model=self.autoencoder, noise=noise, timesteps=timesteps, - condition=latent, + condition=condition, ) + return noise, noise_pred, latent, loss_weight def _generate_image(self, noise, cond): @@ -188,7 +194,7 @@ def _generate_image(self, noise, cond): def save_example(self, stage, cond_img, diff_img): """Save the sequence of denoising steps.""" with torch.no_grad(): - cond = self.semantic_encoder(cond_img).unsqueeze(2) + cond = self.semantic_encoder(cond_img).unsqueeze(1) noise = torch.randn_like(diff_img, device=self.device) sample = self._generate_image(noise, cond) @@ -263,7 +269,7 @@ def generate_from_latent( sample = torch.cat( [ self._generate_image( - noise[start:stop], cond[start:stop].unsqueeze(2) + noise[start:stop], cond[start:stop].unsqueeze(1) ).squeeze(1) for start, stop in tqdm.tqdm(batch_indices, desc="Generating batch") ],