44from sklearn .metrics import roc_auc_score
55from sklearn .metrics import auc
66from sklearn .metrics import average_precision_score
7- from sklearn .preprocessing import MinMaxScaler
8- from sklearn .metrics import confusion_matrix
97import argparse
108import sys
119
@@ -44,36 +42,8 @@ def get_metric(data, predictors, metric):
4442 avg_p_score = average_precision_score (data ['label' ], data [predictor ])
4543 metric_dict [predictor ] = np .round (avg_p_score , 2 )
4644
47- elif metric == 'cm' :
48- if predictor == 'TargetScanCnn_McGeary2019' or predictor == 'RNACofold' :
49- # Normalise the predictions to [0, 1]
50- scaler = MinMaxScaler ()
51- y_pred_reshaped = data [predictor ].values .reshape (- 1 , 1 )
52- y_pred_normalised = scaler .fit_transform (y_pred_reshaped )
53- y_pred = y_pred_normalised .flatten ()
54- else :
55- y_pred = data [predictor ].tolist ()
56-
57- y_true = data ['label' ].tolist ()
58-
59- if predictor .startswith ('Seed' ):
60- y_pred_bin = y_pred
61- else :
62- # Compute the binary predictions
63- precision , recall , thresholds = precision_recall_curve (y_true , y_pred ) # some recall values are 0
64- np .seterr (invalid = 'ignore' ) # ignore division by zero warning
65- fscore = (2 * precision * recall ) / (precision + recall )
66- fscore_max_index = np .argmax (fscore ) # locate the index of the largest f score
67- threshold = thresholds [fscore_max_index ]
68- y_pred_bin = [1 if p >= threshold else 0 for p in y_pred ]
69-
70- # Compute and extract TP, TN, FP, FN
71- tn , fp , fn , tp = confusion_matrix (y_true , y_pred_bin ).ravel ()
72-
73- metric_dict [predictor ] = [int (tn ), int (fp ), int (fn ), int (tp )]
74-
7545 else :
76- raise ValueError (f"Invalid metric: { metric } . Please choose one of 'auc-pr', 'auc-roc', 'avg_p_score', or 'cm '." )
46+ raise ValueError (f"Invalid metric: { metric } . Please choose one of 'auc-pr', 'auc-roc', or 'avg_p_score '." )
7747
7848 return metric_dict
7949
@@ -84,7 +54,7 @@ def main():
8454 parser = argparse .ArgumentParser (description = "Evaluate predictors." )
8555 parser .add_argument ('--ifile' , help = "Input file containing the prediction scores in TSV format (default: STDIN)" , default = None )
8656 parser .add_argument ('--predictors' , help = "List of predictor names (default: all)" , default = None )
87- parser .add_argument ('--metric' , help = "Evaluation metric to compute; auc_pr, auc_roc, avg_p_score, or cm ." , default = None )
57+ parser .add_argument ('--metric' , help = "Evaluation metric to compute; auc_pr, auc_roc, or avg_p_score ." , default = None )
8858 parser .add_argument ('--ofile' , help = "Output file (default: STDOUT)" , default = None )
8959 args = parser .parse_args ()
9060
@@ -115,24 +85,19 @@ def main():
11585
11686 # if metric is none, raise an error
11787 if args .metric is None :
118- raise ValueError (f"Missing metric. Please choose one of 'auc_pr', 'auc_roc', 'avg_p_score', or 'cm '." )
88+ raise ValueError (f"Missing metric. Please choose one of 'auc_pr', 'auc_roc', or 'avg_p_score '." )
11989
12090 # get the metrics
12191 metric = get_metric (data , args .predictors , args .metric )
12292
12393 # write the results to the output file
12494 with open (args .ofile , 'w' ) as ofile :
125- if args .metric == 'cm' :
126- ofile .write (f"Tool\t TN\t FP\t FN\t TP\n " )
127- for predictor in args .predictors :
128- ofile .write (f"{ predictor } \t { metric [predictor ][0 ]} \t { metric [predictor ][1 ]} \t { metric [predictor ][2 ]} \t { metric [predictor ][3 ]} \n " )
129- else :
130- ofile .write (f"Tool\t { args .metric } \n " )
131- for predictor in args .predictors :
132- if predictor .startswith ('Seed' ):
133- continue
134- else :
135- ofile .write (f"{ predictor } \t { metric [predictor ]} \n " )
95+ ofile .write (f"Tool\t { args .metric } \n " )
96+ for predictor in args .predictors :
97+ if predictor .startswith ('Seed' ):
98+ continue
99+ else :
100+ ofile .write (f"{ predictor } \t { metric [predictor ]} \n " )
136101
137102if __name__ == "__main__" :
138103 main ()
0 commit comments