This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 0db685f

virginiafdez authored
Addition of non_quantized flag to enable non-quantised encoding for LDM training on the VQVAE (#474)

* Addition of non_quantized flag to enable non-quantised encoding for LDM training on the VQVAE.
* non_quantized > quantized. Default True.

Co-authored-by: virginiafdez <[email protected]>
1 parent a473b5f commit 0db685f

File tree

1 file changed: +4 −2 lines changed

generative/networks/nets/vqvae.py

Lines changed: 4 additions & 2 deletions
@@ -442,10 +442,12 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         return reconstruction, quantization_losses
 
-    def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+    def encode_stage_2_inputs(self, x: torch.Tensor, quantized: bool = True) -> torch.Tensor:
         z = self.encode(x)
         e, _ = self.quantize(z)
-        return e
+        if quantized:
+            return e
+        return z
 
     def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
         e, _ = self.quantize(z)
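The branching introduced by this commit can be illustrated with a minimal stand-in. `ToyVQVAE`, its scalar `codebook`, and the `encode`/`quantize` placeholders below are hypothetical simplifications (not the MONAI VQVAE, which operates on `torch.Tensor`s); only the `quantized`-flag control flow mirrors the diff.

```python
class ToyVQVAE:
    """Toy model mirroring the encode_stage_2_inputs branching from the diff."""

    def __init__(self, codebook):
        self.codebook = codebook  # list of scalar code values

    def encode(self, x):
        # Placeholder "encoder": identity mapping to a continuous latent.
        return list(x)

    def quantize(self, z):
        # Snap each latent value to its nearest codebook entry.
        e = [min(self.codebook, key=lambda c: abs(c - v)) for v in z]
        loss = sum((a - b) ** 2 for a, b in zip(e, z))
        return e, loss

    def encode_stage_2_inputs(self, x, quantized=True):
        # Same control flow as the commit: quantize as before, but when
        # quantized=False hand back the continuous latent z instead, which
        # is what non-quantised LDM training consumes.
        z = self.encode(x)
        e, _ = self.quantize(z)
        if quantized:
            return e
        return z
```

With `quantized=True` (the default, preserving the old behaviour) the method returns codebook values; with `quantized=False` it returns the raw encoder output:

```python
model = ToyVQVAE(codebook=[0.0, 1.0])
model.encode_stage_2_inputs([0.2, 0.9])                   # → [0.0, 1.0]
model.encode_stage_2_inputs([0.2, 0.9], quantized=False)  # → [0.2, 0.9]
```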

0 commit comments
