-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
30 lines (24 loc) · 871 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
from deepvac import LOG, DeepvacTrain
class MobileNetv3Train(DeepvacTrain):
    """DeepvacTrain subclass that measures top-1 accuracy during validation.

    All three hooks below return immediately when ``self.config.is_train`` is
    truthy. Otherwise they reset a correct-prediction counter, add the number
    of argmax matches for each batch, and log ``n_correct / len(val_dataset)``
    at the end of the epoch. The hooks are presumably invoked by
    ``DeepvacTrain`` around each epoch/iteration -- confirm against deepvac.
    """

    def postIter(self):
        """Add the number of correct predictions in the current batch.

        Reads ``self.config.output`` and ``self.config.target``. The output is
        assumed to be a (batch, num_classes) score tensor and the target a
        1-D tensor of class indices (TODO confirm). Stores the per-sample
        argmax in ``self.prediction`` and increments ``self.n_correct``.
        """
        if self.config.is_train:
            return
        # Move to CPU before handing to NumPy; argmax over axis 1 is assumed
        # to pick the highest-scoring class for each sample.
        self.prediction = np.argmax(self.config.output.cpu().data, axis=1)
        targets = self.config.target.cpu()
        # Count matches directly; the previous per-sample print() was a
        # debugging leftover that flooded stdout during validation.
        self.n_correct += sum(
            1 for pred, target in zip(self.prediction, targets) if pred == target
        )

    def preEpoch(self):
        """Reset the correct-prediction counter before a validation epoch."""
        if self.config.is_train:
            return
        self.n_correct = 0

    def postEpoch(self):
        """Compute ``self.accuracy`` for the finished epoch and log it.

        Accuracy is ``self.n_correct`` divided by the number of samples in
        ``self.config.val_dataset``. An empty dataset yields 0.0 instead of
        raising ``ZeroDivisionError``.
        """
        if self.config.is_train:
            return
        total = len(self.config.val_dataset)
        self.accuracy = self.n_correct / total if total else 0.0
        LOG.logI('Test accuracy: {:.4f}'.format(self.accuracy))
if __name__ == '__main__':
    # The config import lives inside the main guard, so importing this module
    # elsewhere does not import the `config` module.
    from config import config as deepvac_config

    # Build the trainer from the project configuration, then run it.
    trainer = MobileNetv3Train(deepvac_config)
    trainer()