infer_group_label.py
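"""Infer pseudo group (bias) labels with zero-shot CLIP.

Builds prompt-ensembled CLIP text embeddings, classifies each image in the
training split of Waterbirds or CelebA, flags samples whose prediction
conflicts with the class label as predicted minority-group samples, prints a
classification report against the ground-truth minority-group membership, and
saves the raw per-sample predictions as pseudo bias labels.
"""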
import argparse
import os

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import clip
from sklearn.metrics import classification_report
from tqdm import tqdm

from data.celeba import CelebA
from data.waterbirds import Waterbirds
import celeba_templates
import waterbirds_templates


def main(args):
    model, preprocess = clip.load('RN50', 'cuda', jit=False)  # RN50, RN101, RN50x4, ViT-B/32
    crop = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])
    transform = transforms.Compose([crop, preprocess])

    if args.dataset == 'waterbirds':
        data_dir = os.path.join(args.data_dir, 'waterbird_complete95_forest2water2')
        train_dataset = Waterbirds(data_dir=data_dir, split='train', transform=transform)
        templates = waterbirds_templates.templates
        class_templates = waterbirds_templates.class_templates
        class_keywords_all = waterbirds_templates.class_keywords_all
    elif args.dataset == 'celeba':
        data_dir = os.path.join(args.data_dir, 'celeba')
        train_dataset = CelebA(data_dir=data_dir, split='train', transform=transform)
        templates = celeba_templates.templates
        class_templates = celeba_templates.class_templates
        class_keywords_all = celeba_templates.class_keywords_all
    else:
        raise NotImplementedError

    train_dataloader = DataLoader(train_dataset, batch_size=256, num_workers=4, drop_last=False)
    temperature = 0.02  # redundant: only rescales logits, so argmax predictions are unaffected

    # Build the zero-shot classifier: embed every template/keyword combination per class,
    # average the normalized embeddings, and re-normalize the result.
    with torch.no_grad():
        zeroshot_weights = []
        for class_keywords in class_keywords_all:
            texts = [template.format(class_template.format(class_keyword))
                     for template in templates
                     for class_template in class_templates
                     for class_keyword in class_keywords]
            texts = clip.tokenize(texts).cuda()
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()

    preds_minor, preds, targets_minor = [], [], []
    with torch.no_grad():
        for (image, (target, target_g, target_s), _) in tqdm(train_dataloader):
            image = image.cuda()
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = image_features @ zeroshot_weights / temperature
            probs = logits.softmax(dim=-1).cpu()
            conf, pred = torch.max(probs, dim=1)

            if args.dataset == 'waterbirds':
                # Minority groups:
                # (target, target_s) == (0, 1): landbird on water background
                # (target, target_s) == (1, 0): waterbird on land background
                is_minor_pred = (((target == 0) & (pred == 1)) | ((target == 1) & (pred == 0))).long()
                is_minor = (((target == 0) & (target_s == 1)) | ((target == 1) & (target_s == 0))).long()
            elif args.dataset == 'celeba':
                # Minority group:
                # (target, target_s) == (1, 1): blond man
                is_minor_pred = ((target == 1) & (pred == 1)).long()
                is_minor = ((target == 1) & (target_s == 1)).long()

            preds_minor.append(is_minor_pred)
            preds.append(pred)
            targets_minor.append(is_minor)

    preds_minor, preds, targets_minor = torch.cat(preds_minor), torch.cat(preds), torch.cat(targets_minor)
    print(classification_report(targets_minor, preds_minor))

    # Save pseudo labels
    os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
    torch.save(preds, args.save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='celeba', choices=['celeba', 'waterbirds'])
    parser.add_argument('--data_dir', default='/data')
    parser.add_argument('--save_path', default='./pseudo_bias/celeba.pt')
    args = parser.parse_args()
    main(args)
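# Example usage (a sketch; the data and output paths below are illustrative and
# depend on your local setup):
#
#   python infer_group_label.py --dataset waterbirds \
#       --data_dir /data --save_path ./pseudo_bias/waterbirds.pt
#
#   python infer_group_label.py --dataset celeba \
#       --data_dir /data --save_path ./pseudo_bias/celeba.pt
#
# The saved tensor can later be reloaded with torch.load(args.save_path) to use
# the per-sample predictions as pseudo bias labels in downstream training.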