forked from AmadeusBugProject/artifact_detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRQ1_plot_learning_curve.py
70 lines (53 loc) · 2.44 KB
/
RQ1_plot_learning_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pandas
from matplotlib import pyplot as plt
from artifact_detection_model.utils.Logger import Logger
from datasets.constants import LANGUAGES
from file_anchor import root_dir
# Project logger instance; it is not referenced anywhere else in the visible
# part of this file.
log = Logger()

# Directory that main() reads the per-language summary CSVs from and writes
# the rendered figure to. Built by plain string concatenation, so root_dir()
# presumably returns a string ending in a path separator -- TODO confirm.
OUT_PATH = root_dir() + 'evaluation/out/learning_curve/'

# Display names used for subplot titles and legend entries, keyed by the
# language identifiers that main() also uses to build file and column names.
language_labels = {
    'cpp': 'C++',
    'java': 'Java',
    'javascript': 'JavaScript',
    'php': 'PHP',
    'python': 'Python',
}
def main():
    """Render the learning-curve figure for all languages and save it as a PDF.

    For every language in ``LANGUAGES`` the summary CSV
    ``<OUT_PATH><lang>_artifact_detection_summary.csv`` is read, restricted to
    rows with ``train_samples <= 800000`` and grouped by training-set size
    (rescaled to thousands). Each language gets its own subplot showing the
    mean and a +/- one standard deviation band of the columns
    ``roc-auc_<lang>_researcher_1`` and ``roc-auc_<lang>_researcher_2``. The
    sixth subplot overlays the ``model_size`` column of all languages.

    The figure is written to
    ``<OUT_PATH>roc-auc_validation_set_learning_curve.pdf``.

    NOTE(review): the layout offers five per-language panels and five
    color/marker pairs, so this presumably expects ``LANGUAGES`` to contain at
    most five entries -- confirm against ``datasets.constants``.
    """
    fig = plt.figure(figsize=(8, 10))

    # 3x2 grid: panels 1-5 are the per-language ROC-AUC plots and share both
    # axes with the first panel; panel 6 (model size) shares only the x axis
    # because its y values are on a different scale.
    first_ax = fig.add_subplot(3, 2, 1)
    axes = [first_ax]
    for position in range(2, 6):
        axes.append(fig.add_subplot(3, 2, position, sharex=first_ax, sharey=first_ax))
    axes.append(fig.add_subplot(3, 2, 6, sharex=first_ax))
    model_size_ax = axes[5]

    # Per-language color and marker for the shared model-size panel. These do
    # not depend on the loop variable, so they are defined once up front.
    colors = ['red', 'purple', 'violet', 'k', 'c']
    styles = ['p-', '*-', 'v-', 'D-', 'X-']

    for i, lang in enumerate(LANGUAGES):
        print(lang)
        df = pandas.read_csv(OUT_PATH + lang + '_artifact_detection_summary.csv')
        # Take an explicit copy of the filtered rows: assigning a column on
        # the result of a boolean-mask selection is the chained-assignment
        # pattern that triggers pandas' SettingWithCopyWarning.
        df = df[df['train_samples'] <= 800000].copy()
        # Express the training-set size in thousands to match the x label.
        df['train_samples'] = df['train_samples'] / 10**3

        ax = axes[i]
        gb = df.groupby(by='train_samples')
        plot_mean_and_fill_std(ax, gb, 'roc-auc_' + lang + '_researcher_1', 'g',
                               'Researcher 1 validation set', style='o-')
        plot_mean_and_fill_std(ax, gb, 'roc-auc_' + lang + '_researcher_2', 'b',
                               'Researcher 2 validation set', style='v-')
        ax.title.set_text(language_labels[lang])

        plot_mean_and_fill_std(model_size_ax, gb, 'model_size', colors[i],
                               language_labels[lang], style=styles[i])

    # Only the left column gets a y label and only the bottom row an x label,
    # since the ROC-AUC panels share their axes.
    axes[0].set_ylabel('ROC-AUC')
    axes[2].set_ylabel('ROC-AUC')
    axes[4].set_ylabel('ROC-AUC')
    axes[4].set_xlabel('Training set size (10^3 lines)')
    model_size_ax.set_xlabel('Training set size (10^3 lines)')

    # One legend on the first panel stands in for all ROC-AUC panels, which
    # use the same two series labels.
    axes[0].legend(loc='lower right')

    model_size_ax.set_ylabel('Model size (MiB)')
    model_size_ax.legend(loc='lower right')

    plt.tight_layout()
    plt.savefig(OUT_PATH + 'roc-auc_validation_set_learning_curve.pdf')
def plot_mean_and_fill_std(axes, gb, metric, color, label, style='o-'):
    """Plot the per-group mean of ``metric`` with a shaded +/- one std band.

    Args:
        axes: Target to draw on; its ``fill_between`` and ``plot`` methods are
            called (as provided by a matplotlib ``Axes``).
        gb: Grouped data whose ``gb[metric]`` supports ``mean()`` and
            ``std()``, e.g. the result of ``DataFrame.groupby``. The group
            keys become the x values.
        metric: Name of the column to aggregate and plot.
        color: Color used for both the mean line and the shaded band.
        label: Legend label attached to the mean line.
        style: Format string passed to ``axes.plot`` (marker and line style).
    """
    # Aggregate only the requested column, and only once per statistic.
    # Aggregating the whole frame repeatedly is wasted work, and on pandas
    # versions where ``numeric_only`` defaults to False it fails as soon as
    # the frame contains a non-numeric column.
    grouped_metric = gb[metric]
    mean = grouped_metric.mean()
    std = grouped_metric.std()

    # Band first, then the line.
    axes.fill_between(mean.index, mean - std, mean + std, alpha=0.1,
                      color=color)
    axes.plot(mean.index, mean, style, color=color, label=label)
# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()