Skip to content

Commit 2c07d53

Browse files
author
tibuch
authored
Fix real-loss for bin_factor 1.
1 parent 46de611 commit 2c07d53

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

fit/modules/TRecTransformerModule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _real_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
9090
dst_flatten_coords=self.dst_flatten_coords, img_shape=self.hparams.img_shape)
9191
y_hat = torch.roll(torch.fft.irfftn(dft_pred, dim=[1, 2], s=2 * (self.hparams.img_shape,)),
9292
2 * (self.hparams.img_shape // 2,), (1, 2))
93-
return F.mse_loss(self.mask * y_hat, self.mask * target_real)
93+
return F.mse_loss(y_hat, target_real)
9494
else:
9595
dft_pred = convert_to_dft(fc=pred_fc, mag_min=mag_min, mag_max=mag_max,
9696
dst_flatten_coords=self.dst_flatten_coords, img_shape=self.hparams.img_shape)
@@ -206,7 +206,7 @@ def validation_epoch_end(self, outputs):
206206
bin_mse = [o['bin_mse'] for o in outputs]
207207
mean_val_mse = torch.mean(torch.stack(val_mse))
208208
mean_bin_mse = torch.mean(torch.stack(bin_mse))
209-
if self.bin_count > self.hparams.bin_factor_cd and mean_val_mse < (self.hparams.alpha * mean_bin_mse):
209+
if self.bin_count > self.hparams.bin_factor_cd and mean_val_mse < (self.hparams.alpha * mean_bin_mse) and self.bin_factor > 1:
210210
self.bin_count = 0
211211
self.bin_factor = max(1, self.bin_factor - 1)
212212
self.register_buffer('mask', psfft(self.bin_factor, pixel_res=self.hparams.img_shape).to(self.device))

0 commit comments

Comments
 (0)