Skip to content

Commit 4cadb2a

Browse files
Merge pull request #33 from BioGeMT/stef/clean_evalmetrics
removed confusion matrix eval metric as it is not calculated correctly
2 parents ec6bccb + c7940c1 commit 4cadb2a

2 files changed

Lines changed: 9 additions & 56 deletions

File tree

code/eval_metrics/example_output/AGO2_CLASH_Hejret2023_1_cm.tsv

Lines changed: 0 additions & 12 deletions
This file was deleted.

code/eval_metrics/get_metric.py

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from sklearn.metrics import roc_auc_score
55
from sklearn.metrics import auc
66
from sklearn.metrics import average_precision_score
7-
from sklearn.preprocessing import MinMaxScaler
8-
from sklearn.metrics import confusion_matrix
97
import argparse
108
import sys
119

@@ -44,36 +42,8 @@ def get_metric(data, predictors, metric):
4442
avg_p_score = average_precision_score(data['label'], data[predictor])
4543
metric_dict[predictor] = np.round(avg_p_score, 2)
4644

47-
elif metric == 'cm':
48-
if predictor =='TargetScanCnn_McGeary2019' or predictor == 'RNACofold':
49-
# Normalise the predictions to [0, 1]
50-
scaler = MinMaxScaler()
51-
y_pred_reshaped = data[predictor].values.reshape(-1, 1)
52-
y_pred_normalised = scaler.fit_transform(y_pred_reshaped)
53-
y_pred = y_pred_normalised.flatten()
54-
else:
55-
y_pred = data[predictor].tolist()
56-
57-
y_true = data['label'].tolist()
58-
59-
if predictor.startswith('Seed'):
60-
y_pred_bin = y_pred
61-
else:
62-
# Compute the binary predictions
63-
precision, recall, thresholds = precision_recall_curve(y_true, y_pred) # some recall values are 0
64-
np.seterr(invalid='ignore') # ignore division by zero warning
65-
fscore = (2 * precision * recall) / (precision + recall)
66-
fscore_max_index = np.argmax(fscore) # locate the index of the largest f score
67-
threshold = thresholds[fscore_max_index]
68-
y_pred_bin = [1 if p >= threshold else 0 for p in y_pred]
69-
70-
# Compute and extract TP, TN, FP, FN
71-
tn, fp, fn, tp = confusion_matrix(y_true, y_pred_bin).ravel()
72-
73-
metric_dict[predictor] = [int(tn), int(fp), int(fn), int(tp)]
74-
7545
else:
76-
raise ValueError(f"Invalid metric: {metric}. Please choose one of 'auc-pr', 'auc-roc', 'avg_p_score', or 'cm'.")
46+
raise ValueError(f"Invalid metric: {metric}. Please choose one of 'auc-pr', 'auc-roc', or 'avg_p_score'.")
7747

7848
return metric_dict
7949

@@ -84,7 +54,7 @@ def main():
8454
parser = argparse.ArgumentParser(description="Evaluate predictors.")
8555
parser.add_argument('--ifile', help="Input file containing the prediction scores in TSV format (default: STDIN)", default=None)
8656
parser.add_argument('--predictors', help="List of predictor names (default: all)", default=None)
87-
parser.add_argument('--metric', help="Evaluation metric to compute; auc_pr, auc_roc, avg_p_score, or cm.", default=None)
57+
parser.add_argument('--metric', help="Evaluation metric to compute; auc_pr, auc_roc, or avg_p_score.", default=None)
8858
parser.add_argument('--ofile', help="Output file (default: STDOUT)", default=None)
8959
args = parser.parse_args()
9060

@@ -115,24 +85,19 @@ def main():
11585

11686
# if metric is none, raise an error
11787
if args.metric is None:
118-
raise ValueError(f"Missing metric. Please choose one of 'auc_pr', 'auc_roc', 'avg_p_score', or 'cm'.")
88+
raise ValueError(f"Missing metric. Please choose one of 'auc_pr', 'auc_roc', or 'avg_p_score'.")
11989

12090
# get the metrics
12191
metric = get_metric(data, args.predictors, args.metric)
12292

12393
# write the results to the output file
12494
with open(args.ofile, 'w') as ofile:
125-
if args.metric == 'cm':
126-
ofile.write(f"Tool\tTN\tFP\tFN\tTP\n")
127-
for predictor in args.predictors:
128-
ofile.write(f"{predictor}\t{metric[predictor][0]}\t{metric[predictor][1]}\t{metric[predictor][2]}\t{metric[predictor][3]}\n")
129-
else:
130-
ofile.write(f"Tool\t{args.metric}\n")
131-
for predictor in args.predictors:
132-
if predictor.startswith('Seed'):
133-
continue
134-
else:
135-
ofile.write(f"{predictor}\t{metric[predictor]}\n")
95+
ofile.write(f"Tool\t{args.metric}\n")
96+
for predictor in args.predictors:
97+
if predictor.startswith('Seed'):
98+
continue
99+
else:
100+
ofile.write(f"{predictor}\t{metric[predictor]}\n")
136101

137102
if __name__ == "__main__":
138103
main()

0 commit comments

Comments
 (0)