Skip to content

Commit

Permalink
Update example notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
Bougeant committed Oct 18, 2024
1 parent 6d624a3 commit 81420e9
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions examples/binary_classifier.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

from ml_inspector import plot_precision_recall_curves, plot_roc_curves, plot_gain_curves
from ml_inspector import *
```

## Load classification dataset
Expand All @@ -45,7 +45,8 @@ y = pd.Series(dataset["target"])
```

```python
class_names = dataset["target_names"]
class_names = {i: str(c) for i, c in enumerate(dataset["target_names"])}
class_names
```

## Train binary classification model
Expand All @@ -60,6 +61,13 @@ rf.fit(X, y)

## Make predictions

```python
y_pred = {
"Training": rf.predict(X),
"Cross-Validation": cross_val_predict(rf, X, y, cv=5)
}
```

```python
y_prob = {
"Training": rf.predict_proba(X),
Expand All @@ -84,3 +92,37 @@ plot_precision_recall_curves(y, y_prob, class_names, decision_threshold=0.5)
```python
plot_gain_curves(y, y_prob, class_names, decision_threshold=0.5)
```

## Confusion matrix

```python
plot_confusion_matrix(y, y_pred["Cross-Validation"], class_names)
```

## Probability distributions

```python
plot_classification_predictions(y, y_prob["Cross-Validation"], class_names, decision_threshold=0.5, points="all")
```

## Class calibration curves

```python
plot_calibration_curves(y, y_prob["Cross-Validation"], class_names, n_bins=10)
```

## Learning curve

```python
plot_learning_curves(rf, X, y, scoring="roc_auc")
```

## Partial dependence

```python
plot_partial_dependence(rf, X, "worst concavity")
```

```python

```

0 comments on commit 81420e9

Please sign in to comment.