-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconfig.py
101 lines (88 loc) · 4.21 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.optim as optim
from deepvac import AttrDict, new
from data.dataloader import DBTrainDataset, DBTrainCocoDataset, DBTestDataset
from modules.model_db import Resnet18DB, Mobilenetv3LargeDB
from modules.loss import DBLoss
# Build the deepvac config object; every setting in this file is attached to
# it (mostly under config.core.DBNetTrain, plus config.cast / config.datasets).
config = new("DBNetTrain")
## ------------------ common ------------------
# Use CUDA when a GPU is visible, otherwise fall back to the CPU.
config.core.DBNetTrain.device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
# Base directory for run artifacts; the test checkpoint path further down
# (output/disable_git/...) sits beneath it.
config.core.DBNetTrain.output_dir = "output"
# Logging interval used by deepvac (presumably in training steps — confirm).
config.core.DBNetTrain.log_every = 100
# deepvac framework switches; their exact semantics live in deepvac — confirm
# there before changing them.
config.core.DBNetTrain.disable_git = True
config.core.DBNetTrain.model_reinterpret_cast = True
config.core.DBNetTrain.cast_state_dict_strict = True
# Optional TorchScript model path (disabled; presumably used in place of the
# eager net when set — confirm in deepvac).
# config.core.DBNetTrain.jit_model_path = "./output/script.pt"
## -------------------- training ------------------
## train runtime
# Number of training epochs. The poly LR schedule in the optimizer section
# divides by this value, so it also fixes the decay horizon.
config.core.DBNetTrain.epoch_num = int(200)
# Checkpoint-saving setting consumed by deepvac (exact meaning — confirm there).
config.core.DBNetTrain.save_num = int(1)
## -------------------- tensorboard ------------------
# TensorBoard endpoint settings; uncomment to set them explicitly.
# config.core.DBNetTrain.tensorboard_port = "6007"
# config.core.DBNetTrain.tensorboard_ip = None
## -------------------- script and quantize ------------------
# deepvac ScriptCast settings (presumably TorchScript export, per the section
# name); model_dir is the output file.
config.cast.ScriptCast = AttrDict()
config.cast.ScriptCast.model_dir = "./script.pt"
# Static quantization stays disabled: nn.ConvTranspose2d is unsupported for it
# for now. Dynamic quantization is likewise left off.
# config.cast.ScriptCast.static_quantize_dir = "./script.sq"
# config.cast.ScriptCast.dynamic_quantize_dir = "./quantize.sq"
## -------------------- net and criterion ------------------
# Backbone selector: "resnet18" -> Resnet18DB, "mv3large" -> Mobilenetv3LargeDB.
config.arch = "resnet18"
if config.arch == "resnet18":
    config.core.DBNetTrain.net = Resnet18DB()
elif config.arch == "mv3large":
    config.core.DBNetTrain.net = Mobilenetv3LargeDB()
else:
    # ValueError is a subclass of Exception, so existing handlers still catch
    # it; the message now also tells the user which values are accepted.
    raise ValueError(
        "Architecture {} is not supported! Expected one of: resnet18, mv3large".format(config.arch)
    )
# The loss is handed the whole config object (presumably it reads its own
# hyperparameters from it — confirm in modules.loss.DBLoss).
config.core.DBNetTrain.criterion = DBLoss(config)
## -------------------- optimizer and scheduler ------------------
# Adam over the network parameters. Base lr 5e-4 is what LambdaLR scales;
# Adam's weight_decay is an L2 penalty added to the gradient.
config.core.DBNetTrain.optimizer = optim.Adam(
    config.core.DBNetTrain.net.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=5e-4,
)


def lambda_lr(epoch):
    """Return the poly LR factor ``(1 - epoch / epoch_num) ** 0.9``, rounded to 8 places.

    ``epoch_num`` is read from ``config.core.DBNetTrain`` at call time, so a
    later override of that value is honored.

    The base is clamped at 0. Without the clamp, a scheduler stepped past
    ``epoch_num`` would raise a negative float to the power 0.9. In Python 3
    that yields a complex number, which ``round`` rejects with TypeError.
    With the clamp, the factor simply stays at 0.0.
    """
    remaining = max(0.0, 1 - epoch / config.core.DBNetTrain.epoch_num)
    return round(remaining ** 0.9, 8)


config.core.DBNetTrain.scheduler = optim.lr_scheduler.LambdaLR(
    config.core.DBNetTrain.optimizer, lr_lambda=lambda_lr
)
## -------------------- loader ------------------
# Placeholders: point these at your training images and COCO annotation file.
config.sample_path = "your train images dir"
config.label_path = "your train coco json path"
config.img_size = 640
# Label-generation hyperparameters, presumably read by DBTrainCocoDataset from
# the config object it receives — confirm in data.dataloader.
config.datasets.DBTrainCocoDataset = AttrDict()
config.datasets.DBTrainCocoDataset.shrink_ratio = 0.4
config.datasets.DBTrainCocoDataset.thresh_min = 0.3
config.datasets.DBTrainCocoDataset.thresh_max = 0.7
config.core.DBNetTrain.batch_size = 8
config.core.DBNetTrain.num_workers = 4
# NOTE(review): sample_path / label_path / img_size are reassigned by the val
# and test sections below. The values are passed explicitly here, but the
# dataset also receives ``config`` — confirm it does not re-read them later.
config.core.DBNetTrain.train_dataset = DBTrainCocoDataset(
    config, config.sample_path, config.label_path, config.img_size
)
# Shuffled, multi-worker loader for training batches.
config.core.DBNetTrain.train_loader = torch.utils.data.DataLoader(
    dataset=config.core.DBNetTrain.train_dataset,
    batch_size=config.core.DBNetTrain.batch_size,
    shuffle=True,
    num_workers=config.core.DBNetTrain.num_workers,
    pin_memory=True,
    sampler=None,
)
## -------------------- val ------------------
# Placeholders for the validation split (these overwrite the train values set above).
config.sample_path = "your val images dir"
config.label_path = "your val coco json path"
config.img_size = 640
# Validation reuses the training dataset class.
config.core.DBNetTrain.val_dataset = DBTrainCocoDataset(
    config, config.sample_path, config.label_path, config.img_size
)
# One sample per batch, fixed order, loaded in the main process (num_workers=0).
config.core.DBNetTrain.val_loader = torch.utils.data.DataLoader(
    dataset=config.core.DBNetTrain.val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)
## -------------------- test ------------------
# The test entry starts as a clone of the training entry and then receives
# test-only settings. The clone runs before the DDP section below, so (if
# clone() is a snapshot) dist_url / world_size are not carried over.
config.core.DBNetTest = config.core.DBNetTrain.clone()
# NOTE(review): this is one specific checkpoint from an earlier run — replace
# it with the path of the model you want to evaluate.
config.core.DBNetTest.model_path = "output/disable_git/model__2021-06-30-04-55__acc_0__epoch_34__step_182__lr_0.00042280517.pth"
# config.core.DBNetTest.jit_model_path = 'your torchscript model path'
config.core.DBNetTest.is_output_polygon = True
# Placeholder: directory of images to run detection on.
config.sample_path = "your test images path"
# long_size=1280 — presumably the target length of the longer image side;
# confirm in DBTestDataset.
config.core.DBNetTest.test_dataset = DBTestDataset(
    config, config.sample_path, long_size=1280
)
config.core.DBNetTest.test_loader = torch.utils.data.DataLoader(
    dataset=config.core.DBNetTest.test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)
## ------------------------- DDP ------------------
# DistributedDataParallel rendezvous address and total process count
# (world_size=1: a single process).
config.core.DBNetTrain.dist_url = "tcp://localhost:27030"
config.core.DBNetTrain.world_size = int(1)