-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain.py
71 lines (59 loc) · 2.44 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import sys
import os
import math
import time
import numpy as np
import cv2
import torch
from torch import nn
from torch import optim
from deepvac import LOG, DeepvacTrain
from modules.utils_IOU_eval import IOUEval
class LiteHRNetTrain(DeepvacTrain):
    """Deepvac trainer for LiteHRNet, evaluated with mIOU, over several loaders.

    ``train()`` runs the base training loop once per loader in
    ``config.train_loader_list``. Checkpoint saving, scheduler stepping and
    loss/mIOU bookkeeping are skipped unless the active loader has
    ``is_last_loader`` set, so only the final loader produces saved models,
    scheduler steps and logged metrics.
    """

    def __init__(self, deepvac_config):
        super().__init__(deepvac_config)
        # Batch losses recorded during the current epoch; cleared in preEpoch().
        self.config.epoch_loss = []

    def _phaseIOUEval(self):
        """Return the IOU evaluator for the current phase.

        Phase 'TRAIN' maps to ``iou_eval_train``; any other phase maps to
        ``iou_eval_val``.
        """
        if self.config.phase == 'TRAIN':
            return self.iou_eval_train
        return self.iou_eval_val

    def train(self):
        """Run the base training loop once for every loader in the list.

        The active loader is stored as ``config.train_loader``; the hooks
        below read its ``is_last_loader`` flag (and the base loop presumably
        iterates it -- confirm against DeepvacTrain).
        """
        # Created up front so the attributes exist even if a hook runs before
        # the first preEpoch().
        self.iou_eval_val = IOUEval(self.config.cls_num)
        self.iou_eval_train = IOUEval(self.config.cls_num)
        for loader in self.config.train_loader_list:
            self.config.train_loader = loader
            super().train()

    def doSave(self):
        """Save a checkpoint only while the last loader is active."""
        if not self.config.train_loader.is_last_loader:
            return
        super().doSave()

    def postIter(self):
        """Record this batch's loss and predictions (last loader only)."""
        if not self.config.train_loader.is_last_loader:
            return
        # doLoss() computes the loss only when is_train is set; otherwise
        # config.loss still holds the value from the last training batch, and
        # recording it would report a stale number as this phase's loss.
        if self.config.is_train:
            self.config.epoch_loss.append(self.config.loss.item())
        # Argmax over dim 1 (presumably the class-channel axis of the logits).
        prediction = self.config.output.max(1)[1].data.cpu().numpy()
        target = self.config.target.data.cpu().numpy()
        self._phaseIOUEval().addBatch(prediction, target)

    def preEpoch(self):
        """Reset the epoch's loss history and the current phase's IOU evaluator."""
        self.config.epoch_loss = []
        # IOUEval.addBatch() presumably accumulates statistics; without a
        # fresh evaluator per epoch, the logged mIOU (and config.acc) would
        # aggregate every epoch seen so far instead of the current one.
        if self.config.phase == 'TRAIN':
            self.iou_eval_train = IOUEval(self.config.cls_num)
        else:
            self.iou_eval_val = IOUEval(self.config.cls_num)

    def postEpoch(self):
        """Log the epoch's mean loss and mIOU, and publish mIOU as ``config.acc``."""
        if not self.config.train_loader.is_last_loader:
            return
        losses = self.config.epoch_loss
        # No losses are recorded outside training (see postIter), and a loader
        # may yield no batches: report NaN rather than dividing by zero.
        average_epoch_loss = sum(losses) / len(losses) if losses else float('nan')
        _, _, _, mIOU = self._phaseIOUEval().getMetric()
        self.config.acc = mIOU
        LOG.logI("Epoch : {} Details".format(self.config.epoch))
        LOG.logI("\nEpoch No.: %d\t%s Loss = %.4f\t %s mIOU = %.4f\t" % (self.config.epoch, self.config.phase, average_epoch_loss, self.config.phase, mIOU))

    def doSchedule(self):
        """Step the LR scheduler only while the last loader is active."""
        if not self.config.train_loader.is_last_loader:
            return
        self.config.scheduler.step()

    def doLoss(self):
        """Compute ``config.loss`` when ``is_train`` is set; otherwise leave it untouched."""
        if not self.config.is_train:
            return
        self.config.loss = self.config.criterion(self.config.output, self.config.target)
if __name__ == "__main__":
    # Imported here so the project config module is only loaded when this
    # file is executed as a script.
    from config import config as train_config

    trainer = LiteHRNetTrain(train_config)
    trainer()