-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
39 lines (30 loc) · 1.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from deepvac import DeepvacTrain, LOG
from modules.utils import cal_text_score, runningScore
class PSENetTrain(DeepvacTrain):
    """DeepVAC trainer for PSENet.

    Overrides DeepvacTrain hooks to move each batch to ``config.device``,
    compute the loss during training, and accumulate a text score during
    validation. The accumulated score's 'Mean Acc' is published as
    ``self.accuracy`` at the end of each validation epoch.
    """

    def __init__(self, deepvac_config):
        """Forward the DeepVAC config object to the base trainer."""
        super().__init__(deepvac_config)

    def doFeedData2Device(self):
        """Move ``config.sample`` and, when present, every element of ``config.target`` to ``config.device``."""
        self.config.sample = self.config.sample.to(self.config.device)
        if self.config.target is not None:
            # target is treated as a sequence of tensors; each one is moved
            # individually and the result is stored back as a list.
            self.config.target = [tar.to(self.config.device) for tar in self.config.target]

    def doLoss(self):
        """Compute ``config.loss`` with ``config.criterion``; no-op outside training mode."""
        if not self.config.is_train:
            return
        self.config.loss = self.config.criterion(self.config.output, self.config.target)

    def postIter(self):
        """Feed the current validation batch into the running text metric (validation only)."""
        if self.config.is_train:
            return
        # NOTE(review): output is indexed as a 4-D tensor and channel 0 is scored
        # against target[0] with target[2] passed as the extra argument; these are
        # presumably the text map, text ground truth and training mask — confirm.
        self.score_text = cal_text_score(
            self.config.output[:, 0, :, :],
            self.config.target[0],
            self.config.target[2],
            self.running_metric_text,
        )

    def preEpoch(self):
        """Reset validation metric state at the start of a validation epoch."""
        if self.config.is_train:
            return
        # runningScore(2): presumably a 2-class (text / background) metric — confirm.
        self.running_metric_text = runningScore(2)
        # Drop the previous epoch's score so postEpoch can never report a stale value
        # when this validation pass yields no iterations.
        self.score_text = None

    def postEpoch(self):
        """Publish the validation 'Mean Acc' as ``self.accuracy`` and log it.

        Raises:
            RuntimeError: if no validation iteration produced a score this epoch
                (e.g. the validation loader was empty).
        """
        if self.config.is_train:
            return
        # getattr: preEpoch may not have run in validation mode before this call.
        score_text = getattr(self, 'score_text', None)
        if score_text is None:
            raise RuntimeError(
                'PSENetTrain.postEpoch: no validation score was computed this epoch; '
                'is the validation data loader empty?'
            )
        self.accuracy = score_text['Mean Acc']
        LOG.logI('Test accuracy: {:.4f}'.format(self.accuracy))
if __name__ == '__main__':
    # Script entry point: the project-level config module is imported lazily so
    # that importing this file as a module has no side effects.
    from config import config as deepvac_config
    trainer = PSENetTrain(deepvac_config)
    # Calling the trainer instance starts the run.
    trainer()