pretrain.py
import os
import random
import argparse

import numpy as np
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
import ignite.distributed as idist
from ignite.engine import Engine, Events, State
from ignite.utils import convert_tensor

import utils
import models
import data
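

# Entry point for one pre-training worker: seed the process, build the data
# pipeline, model, and optimizer, then run the training loop with an Ignite Engine.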
def main(local_rank, args):
    seed = args.seed + local_rank
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True
    cudnn.benchmark = True

    device = idist.device()
    logger, tb_logger = utils.get_logger(args)

    dataset = data.get_dataset(args.dataset, args.datadir, mode='pretrain')
    loader = data.get_loader(args, dataset, mode='pretrain')
    args.num_epochs = args.num_iters // len(loader['train']) + 1
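
    # Build the model and wrap it for distributed training: idist.auto_model
    # handles DDP wrapping and SyncBN conversion, idist.auto_optim the optimizer.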
    model = models.get_model(args,
                             input_shape=dataset['input_shape'],
                             patch_size=dataset['patch_size'])
    model = idist.auto_model(model, sync_bn=True)

    optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad],
                            lr=args.lr, weight_decay=args.wd)
    optimizer = idist.auto_optim(optimizer)
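
    # One optimization step per batch; the model is expected to return a dict
    # with a scalar 'loss' entry alongside any other outputs used by the handlers.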
    def training_step(engine, batch):
        model.train()
        batch = convert_tensor(batch, device=device, non_blocking=True)
        outputs = model(batch)
        optimizer.zero_grad()
        outputs['loss'].backward()
        optimizer.step()
        return outputs
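
    # The Ignite Engine drives the loop; logging, checkpointing, and termination
    # are attached as ITERATION_COMPLETED event handlers.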
    trainer = Engine(training_step)

    if logger is not None:
        trainer.logger = logger
        trainer.tb_logger = tb_logger
    trainer.add_event_handler(Events.ITERATION_COMPLETED, utils.log)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=args.save_freq),
                              utils.save_checkpoint, args,
                              model=model, optimizer=optimizer)

    @trainer.on(Events.ITERATION_COMPLETED(once=args.num_iters + 1000))  # for stable termination
    def terminate(engine):
        print(f"-> terminate at iteration: {engine.state.iteration}")
        engine.terminate()

    trainer.run(loader['train'], max_epochs=args.num_epochs)

    if tb_logger is not None:
        tb_logger.close()
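

# Example single-node launch (hypothetical paths; only --logdir is required):
#   python pretrain.py --logdir ./logs/metamae --dataset mscoco --datadir /data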
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir', type=str, required=True)
    parser.add_argument('--dataset', type=str, default='mscoco')
    parser.add_argument('--datadir', type=str, default='/data')
    parser.add_argument('--num-iters', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--wd', type=float, default=1e-4)
    parser.add_argument('--num-workers', type=int, default=6)
    parser.add_argument('--model', type=str, default='metamae')
    parser.add_argument('--backbone', type=str, default='dabs')
    parser.add_argument('--mask-ratio', type=float, default=0.5)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embed-dim-dec', type=int, default=128)
    parser.add_argument('--num-layer-dec', type=int, default=4)
    parser.add_argument('--inner-lr', type=float, default=0.5)
    parser.add_argument('--reg-weight', type=float, default=0.1)
    parser.add_argument('--s-ratio', type=float, default=0.1)
    parser.add_argument('--use-first-order', action='store_true')
    parser.add_argument('--save-freq', type=int, default=10000)
    parser.add_argument('--master-port', type=int, default=2223)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()
    utils.setup_config(args)
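
    # Spawn one worker per visible GPU; a single-GPU run stays in the current
    # process, while multi-GPU runs use the NCCL backend via ignite.distributed.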
    n = torch.cuda.device_count()
    if n == 1:
        with idist.Parallel() as parallel:
            parallel.run(main, args)
    else:
        with idist.Parallel(backend='nccl', nproc_per_node=n,
                            master_port=os.environ.get('MASTER_PORT', args.master_port)) as parallel:
            parallel.run(main, args)