@@ -90,7 +90,7 @@ def _real_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
9090 dst_flatten_coords = self .dst_flatten_coords , img_shape = self .hparams .img_shape )
9191 y_hat = torch .roll (torch .fft .irfftn (dft_pred , dim = [1 , 2 ], s = 2 * (self .hparams .img_shape ,)),
9292 2 * (self .hparams .img_shape // 2 ,), (1 , 2 ))
93- return F .mse_loss (self . mask * y_hat , self . mask * target_real )
93+ return F .mse_loss (y_hat , target_real )
9494 else :
9595 dft_pred = convert_to_dft (fc = pred_fc , mag_min = mag_min , mag_max = mag_max ,
9696 dst_flatten_coords = self .dst_flatten_coords , img_shape = self .hparams .img_shape )
@@ -206,7 +206,7 @@ def validation_epoch_end(self, outputs):
206206 bin_mse = [o ['bin_mse' ] for o in outputs ]
207207 mean_val_mse = torch .mean (torch .stack (val_mse ))
208208 mean_bin_mse = torch .mean (torch .stack (bin_mse ))
209- if self .bin_count > self .hparams .bin_factor_cd and mean_val_mse < (self .hparams .alpha * mean_bin_mse ):
209+ if self .bin_count > self .hparams .bin_factor_cd and mean_val_mse < (self .hparams .alpha * mean_bin_mse ) and self . bin_factor > 1 :
210210 self .bin_count = 0
211211 self .bin_factor = max (1 , self .bin_factor - 1 )
212212 self .register_buffer ('mask' , psfft (self .bin_factor , pixel_res = self .hparams .img_shape ).to (self .device ))
0 commit comments