Skip to content

Commit ece9211

Browse files
author
Alex Damian
committed
Added epsilon cutoff
PULSE only returns an image when the downloss is within epsilon
1 parent a413e9e commit ece9211

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

PULSE.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def forward(self, ref_im,
124124
loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()
125125

126126
min_loss = np.inf
127+
min_l2 = np.inf
127128
best_summary = ""
128129
start_t = time.time()
129130
gen_im = None
@@ -156,6 +157,11 @@ def forward(self, ref_im,
156157
[f'{x}: {y:.4f}' for x, y in loss_dict.items()])
157158
best_im = gen_im.clone()
158159

160+
loss_l2 = loss_dict['L2']
161+
162+
if(loss_l2 < min_l2):
163+
min_l2 = loss_l2
164+
159165
# Save intermediate HR and LR images
160166
if(save_intermediate):
161167
yield (best_im.cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1))
@@ -167,5 +173,7 @@ def forward(self, ref_im,
167173
total_t = time.time()-start_t
168174
current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
169175
if self.verbose: print(best_summary+current_info)
170-
171-
yield (gen_im.clone().cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1))
176+
if(min_l2 <= eps):
177+
yield (gen_im.clone().cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1))
178+
else:
179+
print("Could not find a face that downscales correctly within epsilon")

run.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __getitem__(self, idx):
3636
#PULSE arguments
3737
parser.add_argument('-seed', type=int, help='manual seed to use')
3838
parser.add_argument('-loss_str', type=str, default="100*L2+0.05*GEOCROSS", help='Loss function to use')
39-
parser.add_argument('-eps', type=float, default=1e-3, help='Target for downscaling loss (L2)')
39+
parser.add_argument('-eps', type=float, default=2e-3, help='Target for downscaling loss (L2)')
4040
parser.add_argument('-noise_type', type=str, default='trainable', help='zero, fixed, or trainable')
4141
parser.add_argument('-num_trainable_noise_layers', type=int, default=5, help='Number of noise layers to optimize')
4242
parser.add_argument('-tile_latent', action='store_true', help='Whether to forcibly tile the same latent 18 times')

0 commit comments

Comments
 (0)