Skip to content

Commit cd163d9

Browse files
committed
Add roc_curve example
1 parent 6d41ae6 commit cd163d9

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,5 @@ venv.bak/
102102

103103
# mypy
104104
.mypy_cache/
105+
.idea/*
106+
.vscode/*

metrics.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
import sklearn.metrics as skm
22
import numpy as np
33

4-
54
# #binary problem
65

76
y_true = np.array([0, 0, 1, 0, 0, 1, 0, 0, 1, 1], dtype=float)
8-
#y_pred = [0, 0, 1, 1, 0, 1, 0, 1, 0, 0]
7+
# y_pred = [0, 0, 1, 1, 0, 1, 0, 1, 0, 0]
98
y_pred = np.array([0.01, 0.12, 0.89, .99, .05, .76, .14, .87, .44, .32])
109
y_pred = y_pred > 0.5
1110

12-
1311
y_pred = np.array([0, 0, 1, 1])
1412
y_true = np.array([0, 0, 1, 0])
1513

@@ -49,7 +47,6 @@
4947
print(skm.classification_report(y_true, y_pred,
5048
target_names=['label1==1', 'label2==1']))
5149

52-
5350
# 3 label binary problem
5451
y_true = np.array([
5552
[0, 0, 0], [1, 1, 1], [0, 0, 0], [1, 1, 1]
@@ -64,6 +61,7 @@
6461
target_names=['label1==1', 'label2==1', 'label3==1']))
6562

6663

64+
6765
# 2 label multiclass problem
6866
# !!!!! multi label multi output not supported
6967
# y_true = np.array([

roc_curve.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
from sklearn import metrics
3+
import matplotlib.pyplot as plt
4+
5+
6+
y = np.array([0, 0, 1, 1,1])
7+
scores = np.array([0.1, 0.4, 0.35, 0.6, 0.8])
8+
fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=1)
9+
10+
print(f'\nfpr={fpr}\ntpr={tpr}\ntre={thresholds}')
11+
12+
plt.plot(fpr, tpr, lw=1)
13+
plt.scatter(fpr,tpr)
14+
15+
plt.xlim([-0.05, 1.05])
16+
plt.ylim([-0.05, 1.05])
17+
plt.xlabel('False Positive Rate')
18+
plt.ylabel('True Positive Rate')
19+
plt.title('Receiver operating characteristic example')
20+
#plt.legend(loc="lower right")
21+
plt.show()

0 commit comments

Comments
 (0)