main.py
import os
import torch
import torchvision
import argparse
from torch.utils.tensorboard import SummaryWriter

# NVIDIA apex is optional; mixed-precision (fp16) training is used only when it is
# installed and args.fp16 is set.
apex = False
try:
    from apex import amp

    apex = True
except ImportError:
    print(
        "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training"
    )

from model import load_model, save_model
from modules import NT_Xent
from modules.transformations import TransformsSimCLR
from utils import mask_correlated_samples, post_config_hook

#### pass configuration
from experiment import ex
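
# SimCLR pre-training: TransformsSimCLR yields two augmented views of every image,
# both views are encoded to a representation h and a projection z, and the NT-Xent
# criterion maximizes agreement between the projections of the same image.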

def train(args, train_loader, model, criterion, optimizer, writer):
    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.to(args.device)
        x_j = x_j.to(args.device)

        # positive pair, with encoding: h is the encoder output, z the projection head output
        h_i, z_i = model(x_i)
        h_j, z_j = model(x_j)

        # NT-Xent contrastive loss between the two projections
        loss = criterion(z_i, z_j)

        if apex and args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        if step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")

        writer.add_scalar("Loss/train_epoch", loss.item(), args.global_step)
        loss_epoch += loss.item()
        args.global_step += 1
    return loss_epoch
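

# For reference, a minimal sketch of what an NT-Xent criterion computes, assuming z_i and
# z_j are (batch_size, dim) projection tensors. The actual loss used above is modules.NT_Xent;
# this helper is illustrative only and is not called anywhere in this script.
def _nt_xent_sketch(z_i, z_j, temperature=0.5):
    import torch.nn.functional as F

    n = z_i.size(0)
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # (2N, dim), unit-norm rows
    sim = z @ z.T / temperature  # cosine-similarity logits
    # exclude self-similarities on the diagonal
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # row i's positive is the other view of the same image: i + N for i < N, i - N otherwise
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)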


@ex.automain
def main(_run, _log):
    # Sacred supplies the configuration; expose it as an argparse-style namespace.
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    root = "./datasets"
    train_sampler = None

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            root, split="unlabeled", download=True, transform=TransformsSimCLR()
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            root, download=True, transform=TransformsSimCLR()
        )
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    model, optimizer, scheduler = load_model(args, train_loader)

    tb_dir = os.path.join(args.out_dir, _run.experiment_info["name"])
    os.makedirs(tb_dir)
    writer = SummaryWriter(log_dir=tb_dir)

    # Mask marking which similarity-matrix entries count as negatives for NT-Xent
    # (each sample's own entry and its paired augmentation are excluded).
    mask = mask_correlated_samples(args)
    criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)

    args.global_step = 0
    args.current_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        lr = optimizer.param_groups[0]["lr"]
        loss_epoch = train(args, train_loader, model, criterion, optimizer, writer)

        if scheduler:
            scheduler.step()

        if epoch % 10 == 0:
            save_model(args, model, optimizer)

        writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
        writer.add_scalar("Misc/learning_rate", lr, epoch)
        print(
            f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
        )
        args.current_epoch += 1

    ## end training
    save_model(args, model, optimizer)
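
# Usage sketch (hypothetical invocation, assuming the sacred experiment in experiment.py
# defines these config keys, which the code above reads from _run.config), e.g.:
#   python main.py with dataset=CIFAR10 batch_size=128 epochs=100 fp16=False
# `with key=value` is sacred's standard command-line override syntax.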