-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_robustification.py
95 lines (74 loc) · 2.85 KB
/
plot_robustification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from groot.toolbox import Model
from groot.visualization import plot_estimator
from sklearn.datasets import make_moons, make_circles
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import check_random_state

from robust_relabeling import relabel_model
def make_tree(n_samples, random_state=None):
    """Generate a 2-D dataset labelled by a depth-2 axis-aligned decision tree.

    Points are drawn with ``uniform(size=(n_samples, 2))``, i.e. from the unit
    square. A point is labelled 1 when exactly one of its two coordinates is
    at most 0.5, and 0 otherwise (an XOR pattern over the four quadrants).

    Parameters
    ----------
    n_samples : int
        Number of points to draw. ``0`` yields empty arrays.
    random_state : int, RandomState instance or None, default=None
        Seed or generator, resolved through
        ``sklearn.utils.check_random_state``.

    Returns
    -------
    X : ndarray of shape (n_samples, 2)
        Sampled points.
    y : ndarray of shape (n_samples,)
        Integer class labels in {0, 1}.
    """
    rng = check_random_state(random_state)
    X = rng.uniform(size=(n_samples, 2))

    # Ground-truth tree: split on feature 0 at 0.5, then on feature 1 at 0.5,
    # with leaves (low, low) -> 0, (low, high) -> 1, (high, low) -> 1,
    # (high, high) -> 0.  That is the XOR of the two threshold tests, which
    # can be evaluated for all rows at once and also works when X is empty.
    first_low = X[:, 0] <= 0.5
    second_low = X[:, 1] <= 0.5
    y = np.logical_xor(first_low, second_low).astype(int)
    return X, y
def _save_figure(X, y, model, stem):
    """Plot ``model`` over the samples and save it as ``<stem>.png`` and ``<stem>.pdf``.

    ``stem`` is the output path without a file extension. The current
    matplotlib figure is closed afterwards so figures do not accumulate
    across the many (dataset, classifier) combinations.
    """
    plot_estimator(X, y, model, steps=500)
    # Axis ticks are hidden in the saved figures.
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()
    plt.savefig(f"{stem}.png", dpi=200)
    plt.savefig(f"{stem}.pdf")
    plt.close()


parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str, default="out/")
parser.add_argument("--epsilon", type=float, default=0.04)
args = parser.parse_args()

# Join paths through pathlib so that "--output_dir out" and "--output_dir out/"
# behave the same, and make sure the directory exists before saving into it.
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

sns.set_theme(style="whitegrid", palette="colorblind")

# (name, (X, y)) pairs; every dataset has 200 samples and a fixed seed.
datasets = (
    ("moons", make_moons(n_samples=200, noise=0.3, random_state=1)),
    ("circles", make_circles(n_samples=200, factor=0.5, noise=0.2, random_state=1)),
    ("tree", make_tree(n_samples=200, random_state=1)),
)

for name, (X, y) in datasets:
    X = MinMaxScaler().fit_transform(X)

    # Label noise: flip the label wherever a uniform draw exceeds 0.95.  The
    # generator is re-seeded with 1 for every dataset, so the same sample
    # positions are flipped in each of them.
    flip_rng = check_random_state(1)
    y = np.where(flip_rng.rand(len(y)) > 0.95, 1 - y, y)

    # Fresh, unfitted classifiers for every dataset.
    classifiers = (
        ("tree", DecisionTreeClassifier(max_depth=5, min_samples_leaf=3, random_state=1)),
        ("forest", RandomForestClassifier(n_estimators=100, random_state=1)),
        ("boosting", GradientBoostingClassifier(n_estimators=100, max_depth=5, random_state=1)),
    )

    for model_name, classifier in classifiers:
        classifier.fit(X, y)
        model = Model.from_sklearn(classifier)

        stem = output_dir / f"{name}_{model_name}"
        _save_figure(X, y, model, stem)

        # NOTE(review): the return value is unused and the same ``model`` is
        # plotted again, so relabel_model presumably updates it in place —
        # confirm against robust_relabeling.
        relabel_model(model, X, y, args.epsilon)
        _save_figure(X, y, model, f"{stem}_relabeled")