-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathcalculate_log.py
113 lines (103 loc) · 4.11 KB
/
calculate_log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
## Measure the detection performance - Kibok Lee
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
from scipy import misc
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def get_curve(dir_name, stypes=['Baseline', 'Gaussian_LDA']):
    """Build threshold-sweep TP/FP count curves from saved confidence scores.

    For every score type ``stype`` two text files are read:

    * ``<dir_name>/confidence_<stype>_In.txt``  -- scores of the "known" samples
    * ``<dir_name>/confidence_<stype>_Out.txt`` -- scores of the "novel" samples

    Each file holds plain numbers separated by whitespace/newlines.

    The scores of both files are walked in ascending order.  Entry ``i`` of a
    curve describes the state after the ``i`` lowest scores have been passed:
    ``tp[stype][i]`` is the number of known scores not yet passed and
    ``fp[stype][i]`` the number of novel scores not yet passed.  When a known
    and a novel score are equal, the known one is passed first.

    Args:
        dir_name: directory containing the confidence files.
        stypes: iterable of score-type names used in the file names.  The
            default list is only read, never modified.

    Returns:
        A tuple ``(tp, fp, tnr_at_tpr95)`` of dicts keyed by score type.
        ``tp`` / ``fp`` map to integer arrays of length
        ``num_known + num_novel + 1`` starting at ``num_known`` / ``num_novel``
        and ending at 0.  ``tnr_at_tpr95`` maps to ``1 - fp / num_novel`` taken
        at the first curve index whose ``tp / num_known`` is closest to 0.95.

    Raises:
        ValueError: if one of the score files contains no values.
    """
    tp, fp = dict(), dict()
    tnr_at_tpr95 = dict()
    for stype in stypes:
        in_path = '{}/confidence_{}_In.txt'.format(dir_name, stype)
        out_path = '{}/confidence_{}_Out.txt'.format(dir_name, stype)
        # No explicit delimiter: newer NumPy releases reject delimiter='\n',
        # and the default whitespace splitting reads one-value-per-line files
        # identically.  ndmin=1 keeps a single-value file one-dimensional.
        known = np.sort(np.loadtxt(in_path, ndmin=1).ravel())
        novel = np.sort(np.loadtxt(out_path, ndmin=1).ravel())
        num_k = known.shape[0]
        num_n = novel.shape[0]
        if num_k == 0 or num_n == 0:
            raise ValueError(
                'no confidence scores found for {!r} in {!r}'.format(stype, dir_name))

        # Stable sort of [known, novel]: on equal scores the known entries
        # (which come first in the concatenation) stay ahead of the novel ones.
        merged_order = np.argsort(np.concatenate([known, novel]), kind='mergesort')
        from_known = merged_order < num_k

        total = num_k + num_n
        tp[stype] = np.empty([total + 1], dtype=int)
        fp[stype] = np.empty([total + 1], dtype=int)
        tp[stype][0], fp[stype][0] = num_k, num_n
        # Each passed score removes one sample from the matching count.
        tp[stype][1:] = num_k - np.cumsum(from_known)
        fp[stype][1:] = num_n - np.cumsum(~from_known)

        # float() keeps these true divisions regardless of interpreter version.
        tpr95_pos = np.abs(tp[stype] / float(num_k) - .95).argmin()
        tnr_at_tpr95[stype] = 1. - fp[stype][tpr95_pos] / float(num_n)
    return tp, fp, tnr_at_tpr95
def metric(dir_name, stypes=['Bas', 'Gau'], verbose=False):
    """Compute detection metrics for each score type from its TP/FP curves.

    The curves come from ``get_curve(dir_name, stypes)``.  With
    ``tpr = tp / tp[0]`` and ``fpr = fp / fp[0]`` (each extended by the end
    points 1 and 0), the following values are stored per score type:

    * ``'TNR'``   -- the ``tnr_at_tpr95`` value returned by ``get_curve``.
    * ``'AUROC'`` -- trapezoidal area under the (fpr, tpr) curve.
    * ``'DTACC'`` -- maximum over the curve of ``0.5 * (tpr + 1 - fpr)``.
    * ``'AUIN'``  -- trapezoidal area of ``tp / (tp + fp)`` against ``tpr``
      (precision vs. recall with the known samples as positives).
    * ``'AUOUT'`` -- trapezoidal area of
      ``(fp[0] - fp) / ((tp[0] - tp) + (fp[0] - fp))`` against ``1 - fpr``
      (the same with the novel samples as positives).

    Args:
        dir_name: directory passed through to ``get_curve``.
        stypes: iterable of score-type names.  The default list is only read,
            never modified.
        verbose: when true, print a table with one row per score type and the
            metrics in percent.

    Returns:
        ``dict`` mapping score type -> ``dict`` mapping metric name -> value
        (as a fraction, not percent).
    """
    tp, fp, tnr_at_tpr95 = get_curve(dir_name, stypes)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use
    # the new name when it exists and fall back on older releases.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    results = dict()
    mtypes = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']
    if verbose:
        # Pad the header with the same width as the row label below so the
        # metric names sit above their columns.
        print('{pad:5s} '.format(pad=''), end='')
        for mtype in mtypes:
            print(' {mtype:6s}'.format(mtype=mtype), end='')
        print('')

    for stype in stypes:
        # Work on float copies so every ratio below is a true division.
        tp_cnt = tp[stype].astype(float)
        fp_cnt = fp[stype].astype(float)
        num_in, num_out = tp_cnt[0], fp_cnt[0]
        tpr = np.concatenate([[1.], tp_cnt / num_in, [0.]])
        fpr = np.concatenate([[1.], fp_cnt / num_out, [0.]])

        scores = dict()

        # TNR
        scores['TNR'] = tnr_at_tpr95[stype]

        # AUROC -- tpr is decreasing along the curve, hence the sign flip.
        scores['AUROC'] = -trapezoid(1. - fpr, tpr)

        # DTACC
        scores['DTACC'] = .5 * (tp_cnt / num_in + 1. - fp_cnt / num_out).max()

        # AUIN -- points with an empty denominator are marked with -1 and
        # then excluded from the integration via the boolean index.
        denom = tp_cnt + fp_cnt
        denom[denom == 0.] = -1.
        pin_ind = np.concatenate([[True], denom > 0., [True]])
        pin = np.concatenate([[.5], tp_cnt / denom, [0.]])
        scores['AUIN'] = -trapezoid(pin[pin_ind], tpr[pin_ind])

        # AUOUT -- same masking scheme as AUIN.
        denom = num_in - tp_cnt + num_out - fp_cnt
        denom[denom == 0.] = -1.
        pout_ind = np.concatenate([[True], denom > 0., [True]])
        pout = np.concatenate([[0.], (num_out - fp_cnt) / denom, [.5]])
        scores['AUOUT'] = trapezoid(pout[pout_ind], 1. - fpr[pout_ind])

        results[stype] = scores

        if verbose:
            print('{stype:5s} '.format(stype=stype), end='')
            for mtype in mtypes:
                print(' {val:6.3f}'.format(val=100. * scores[mtype]), end='')
            print('')
    return results