Skip to content

Commit

Permalink
Merge pull request #28 from yilmazkorkmaz1/main
Browse files Browse the repository at this point in the history
Fix the sorting of ground-truth and generated images; also fix the local description loading.
  • Loading branch information
drogozhang authored Feb 20, 2025
2 parents c26fd44 + d058335 commit fcfc892
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions evaluation/image_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,14 @@ def encode(image, model, transform):
for img_pair in tqdm(image_pairs):
gen_img_path = img_pair[0]
gt_img_path = img_pair[1]

gen_img = Image.open(gen_img_path).convert("RGB")
gt_img = Image.open(gt_img_path).convert("RGB")
gt_img_name = gt_img_path.split('/')[-1]
gt_caption = caption_dict[gt_img_name]
generated_features = encode(Image.open(gen_img_path).convert('RGB'), model, transform)
gt_features = encode(Image.open(gt_img_path).convert('RGB'), model, transform)
gt_caption = caption_dict[gt_img_path.split('/')[-2]][gt_img_name]

generated_features = encode(gen_img, model, transform)
gt_features = encode(gt_img, model, transform)
# get text CLIP embedding
text_features = clip.tokenize(gt_caption).to(args.device)
with torch.no_grad():
Expand Down Expand Up @@ -185,6 +189,8 @@ def get_final_turn_img(dir, img_type):
raise ValueError(f"Only one image but wrongly named, image name: {os.path.join(dir,img_name_list[0])}\n img_name_list: {img_name_list}")
else:
gen_list = [s for s in img_name_list if 'iter_' in s]
gen_list.sort(key=lambda x: int(''.join(filter(str.isdigit, x))))

if len(gen_list) == 0:
raise ValueError(f"Empty gen list, image name: {dir}\n img_name_list: {img_name_list}")
return os.path.join(dir, gen_list[-1])
Expand All @@ -193,6 +199,7 @@ def get_final_turn_img(dir, img_type):
gt_list = [s for s in img_name_list if 'output' in s]
if len(gt_list) == 0:
raise ValueError(f"Empty gt list, image name: {dir}\n img_name_list: {img_name_list}")
gt_list.sort(key=lambda x: int(''.join(filter(str.isdigit, x))))
return os.path.join(dir, gt_list[-1])
else:
raise ValueError(f"Image type is not expected, only support 'generated' and 'gt', but got {img_type}")
Expand All @@ -215,6 +222,7 @@ def get_all_turn_imgs(dir, img_type):
img_name_list = [s for s in img_name_list if '.png' in s or '.jpg' in s]
# remove mask
img_name_list = [s for s in img_name_list if 'mask' not in s]

if len(img_name_list) == 0:
raise ValueError(f"Empty img list, image name: {dir}\n img_name_list: {img_name_list}")
# all turn included the first edit
Expand All @@ -227,13 +235,16 @@ def get_all_turn_imgs(dir, img_type):
else:
gen_list = [os.path.join(dir, s) for s in img_name_list if '_1' in s]
gen_list.extend([os.path.join(dir, s) for s in img_name_list if 'inde_' in s])
gen_list.sort(key=lambda x: int(''.join(filter(str.isdigit, x))))

if len(gen_list) == 0:
raise ValueError(f"Empty gen list, image name: {dir}\n img_name_list: {img_name_list}")
return gen_list
elif img_type == 'gt':
gt_list = [os.path.join(dir, s) for s in img_name_list if 'output' in s]
if len(gt_list) == 0:
raise ValueError(f"Empty gt list, image name: {dir}\n img_name_list: {img_name_list}")
gt_list.sort(key=lambda x: int(''.join(filter(str.isdigit, x))))
return gt_list
else:
raise ValueError(f"Image type is not expected, only support 'generated' and 'gt', but got {img_type}")
Expand All @@ -256,7 +267,6 @@ def load_data(args):
if not os.path.isfile(os.path.join(args.gt_path, gt_img_id)):
gt_img_id_list.append(gt_img_id)

# 3. check if the directory names are same. (orders are not important)
if set(gen_img_id_list) != set(gt_img_id_list):
# print the difference
print("The directory names under generated path and gt path are not same!")
Expand All @@ -271,7 +281,6 @@ def load_data(args):
gen_img_dir = os.path.join(args.generated_path, img_id)
gt_img_dir = os.path.join(args.gt_path, img_id)
final_turn_pairs.append((get_final_turn_img(gen_img_dir, 'generated'), get_final_turn_img(gt_img_dir, 'gt')))

generated_all_turns = get_all_turn_imgs(gen_img_dir, 'generated')
gt_all_turns = get_all_turn_imgs(gt_img_dir, 'gt')
if len(generated_all_turns) != len(gt_all_turns):
Expand Down Expand Up @@ -341,6 +350,7 @@ def load_data(args):
with open(args.caption_path, 'r') as f:
caption_dict = json.load(f)


print(f"No. of all turn pairs: {len(all_turn_pairs)}")
print(f"No. of final turn pairs: {len(final_turn_pairs)}")
# check turns
Expand Down Expand Up @@ -405,9 +415,7 @@ def load_data(args):

if 'clip-t' in args.metric:
# eval_clip_i(args, all_turn_pairs, final_turn_pairs)
model, transform = clip.load("ViT-B/32", args.device)
# print("CLIP-T model loaded: ", model)
# print("Model device: ", args.device)
model, transform = clip.load("ViT-B/32", args.device)
final_turn_eval_score, final_turn_oracle_score = eval_clip_t(args, final_turn_pairs, model, transform)
print(f"Final turn CLIP-T: {final_turn_eval_score}")
print(f"Final turn CLIP-T Oracle: {final_turn_oracle_score}")
Expand All @@ -419,6 +427,7 @@ def load_data(args):
evaluated_metrics_dict['all_turn']['clip-t'] = all_turn_eval_score
evaluated_metrics_dict['all_turn']['clip-t_oracle'] = all_turn_oracle_score


print(evaluated_metrics_dict)
# separately print the final-turn and all-turn results.
for turn in ['final_turn', 'all_turn']:
Expand All @@ -437,7 +446,7 @@ def load_data(args):
print()
print('#'*11*len(metrics))

# check if args.save_path exists
# Load existing metrics if file exists
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)

Expand Down

0 comments on commit fcfc892

Please sign in to comment.