Skip to content

Commit

Permalink
Update ROC curve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Bougeant committed Sep 8, 2024
1 parent e1b4c00 commit 64802c2
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/test_roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def test_plot_roc_curves_binary(self, binary_predictions):
fig = roc_curves.plot_roc_curves(y, y_prob, class_names=class_names)
assert isinstance(fig, go.Figure)
assert len(fig.data) == 3
assert "Class 1 (Model 1): AUC=" in fig.data[0]["name"]
assert "Class 1 (Model 2): AUC=" in fig.data[1]["name"]
assert fig.data[0]["name"] == "Class 1 (Model 1): AUC=1.00"
assert fig.data[1]["name"] == "Class 1 (Model 2): AUC=0.75"
assert fig.data[2]["name"] == "Random decision: AUC=0.50"

def test_plot_roc_curves_multi_class(self, multiclass_predictions):
Expand All @@ -26,12 +26,13 @@ def test_plot_roc_curves_multi_class(self, multiclass_predictions):
fig = roc_curves.plot_roc_curves(y, y_prob, class_names=class_names)
assert isinstance(fig, go.Figure)
assert len(fig.data) == 11
assert "Class 0 (Training): AUC=" in fig.data[0]["name"]
assert "Class 1 (Training): AUC=" in fig.data[1]["name"]
assert "Class 2 (Training): AUC=" in fig.data[2]["name"]
assert "Class 3 (Training): AUC=" in fig.data[3]["name"]
assert "Micro-average ROC curve (Training): AUC=" in fig.data[4]["name"]
assert "Class 0 (Test): AUC=" in fig.data[5]["name"]
assert fig.data[0]["name"] == "Class 0 (Training): AUC=1.00"
assert fig.data[1]["name"] == "Class 1 (Training): AUC=1.00"
assert fig.data[2]["name"] == "Class 2 (Training): AUC=1.00"
assert fig.data[3]["name"] == "Class 3 (Training): AUC=1.00"
assert fig.data[4]["name"] == "Micro-average ROC curve (Training): AUC=0.96"
assert fig.data[5]["name"] == "Class 0 (Test): AUC=0.81"
assert fig.data[-1]["name"] == "Random decision: AUC=0.50"

def test_plot_roc_curves_error_single_class(self, binary_predictions):
y, y_prob_1, y_prob_2 = binary_predictions
Expand Down Expand Up @@ -64,5 +65,5 @@ def test_plot_roc_curves_with_array(self, binary_predictions):
fig = roc_curves.plot_roc_curves(y, y_prob_1, class_names=class_names)
assert isinstance(fig, go.Figure)
assert len(fig.data) == 2
assert "Class 1 (Predictions): AUC=" in fig.data[0]["name"]
assert fig.data[0]["name"] == "Class 1 (Predictions): AUC=1.00"
assert fig.data[1]["name"] == "Random decision: AUC=0.50"

0 comments on commit 64802c2

Please sign in to comment.