-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
22 lines (19 loc) · 881 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
from deepvac import Deepvac, ClassifierReport, LOG
class MobileNetv3Test(Deepvac):
    """Evaluate the configured classifier on ``config.test_loader``.

    Runs ``config.net`` over every batch from the test loader and feeds each
    (ground-truth, predicted) class pair into a ``ClassifierReport``. The
    report is emitted once the loader is exhausted.
    """

    def doTest(self):
        """Run the test loop and emit the classification report.

        Reads from ``self.config``: ``test_loader``, ``device``, ``net`` and
        ``cls_num``. After the loop, ``self.config.sample`` is set to the last
        (device-resident) input batch. If the loader yields no batches, it is
        left untouched instead of raising ``NameError``.
        """
        # NOTE(review): total_num is the loader's length. For a torch
        # DataLoader that is the number of batches, not samples. If
        # ClassifierReport expects a sample count, this is only right for
        # batch_size == 1 -- confirm against deepvac.
        cls_report = ClassifierReport(
            ds_name="cls_dataset",
            total_num=len(self.config.test_loader),
            cls_num=self.config.cls_num,
            threshold=0.99,
        )
        last_inp = None
        seen = 0
        for i, (inp, labels) in enumerate(self.config.test_loader):
            inp = inp.to(self.config.device)
            # Bring the network output back to host memory so numpy can
            # take the per-row argmax (predicted class index).
            prediction = self.config.net(inp).cpu()
            prediction = np.argmax(prediction.data, axis=1)
            gts = np.array(labels)
            for gt, pred in zip(gts, np.array(prediction)):
                cls_report.add(gt, pred)
            seen += len(gts)
            if i % 100 == 0:
                # `i` is a batch index; log the running sample count so the
                # message matches its wording.
                LOG.logI('Process {} samples. '.format(seen))
            last_inp = inp
        cls_report()
        # Keep the last batch around (e.g. as an example input) only if the
        # loader actually produced one.
        if last_inp is not None:
            self.config.sample = last_inp
if __name__ == '__main__':
    # Script entry point: load the project config and run the test pipeline.
    from config import config as deepvac_config
    MobileNetv3Test(deepvac_config)()