symmetrized confusion matrix #624
yanivboker started this conversation in Show and tell
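I made a symmetrized confusion matrix that adds up the confusion between each pair of labels in both directions (shirt->t-shirt and t-shirt->shirt). Each off-diagonal cell then shows how often two classes are mixed up, regardless of which one was the true label.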
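Written out, let C be the confusion matrix (rows are true labels, columns are predicted labels) and diag(C) the diagonal matrix holding C's diagonal entries. The symmetrized matrix S is then:

$$
S_{ij} =
\begin{cases}
C_{ij} + C_{ji}, & i \neq j \\
C_{ii}, & i = j
\end{cases}
\qquad\Longleftrightarrow\qquad
S = C + C^{\top} - \operatorname{diag}(C)
$$

Subtracting diag(C) is needed because adding C and its transpose counts every diagonal (correct) prediction twice.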
# 1. Import the libraries
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix
# 2. Set up the confusion matrix instance and compare predictions to targets
# (assumes y_pred_tensor, test_data and class_names are already defined at this point)
confmat = ConfusionMatrix(num_classes=len(class_names), task='multiclass')
confmat_tensor = confmat(preds=y_pred_tensor,
                         target=test_data.targets)
# 3. Plot the confusion matrix
fig, ax = plot_confusion_matrix(
    conf_mat=confmat_tensor.numpy(),  # matplotlib likes working with NumPy
    class_names=class_names,  # turn the row and column labels into class names
    figsize=(10, 7)
);
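# Function to symmetrize a confusion matrix:
# off-diagonal cells become C[i, j] + C[j, i]; the diagonal stays unchanged
def symmetrize_confusion_matrix(matrix):
    matrix = np.asarray(matrix)  # accept a CPU torch tensor or a NumPy array
    return matrix + matrix.T - np.diag(matrix.diagonal())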
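A quick sanity check of the symmetrization on a made-up 3-class matrix (the counts are invented for illustration; rows are true labels, columns are predictions). The function is repeated so the snippet runs on its own:

```python
import numpy as np

def symmetrize_confusion_matrix(matrix):
    matrix = np.asarray(matrix)
    # Add the transpose, then remove one copy of the double-counted diagonal
    return matrix + matrix.T - np.diag(matrix.diagonal())

# Made-up counts: rows = true label, columns = predicted label
conf = np.array([
    [50, 7, 1],
    [3, 45, 2],
    [0, 4, 48],
])

sym = symmetrize_confusion_matrix(conf)
print(sym)
```

Cell (0, 1) becomes 7 + 3 = 10, the diagonal keeps 50, 45 and 48, and the result equals its own transpose.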
# Function to plot the symmetrized confusion matrix
def plot_symmetrized_confusion_matrix(matrix, class_names):
    plt.figure(figsize=(10, 10))
    plt.imshow(matrix, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Symmetrized Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45, ha='right')
    plt.yticks(tick_marks, class_names)
    plt.tight_layout()
    plt.show()
# Class names for Fashion-MNIST
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# Create the symmetrized confusion matrix and plot it
symmetrized_conf_matrix = symmetrize_confusion_matrix(confmat_tensor.numpy())
plot_symmetrized_confusion_matrix(symmetrized_conf_matrix, class_names)
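Because the symmetrized matrix is symmetric, every label pair shows up twice. Reading only the upper triangle gives each pair once, which makes it easy to rank the most-confused pairs. A minimal sketch: `most_confused_pairs` is a hypothetical helper (not part of torchmetrics or mlxtend), and the toy matrix and labels below are made up:

```python
import numpy as np

def most_confused_pairs(sym_matrix, class_names, top_k=3):
    """Hypothetical helper: return the top_k (count, class_a, class_b) pairs
    from a symmetrized confusion matrix, largest count first."""
    sym_matrix = np.asarray(sym_matrix)
    # Upper triangle without the diagonal: each unordered pair appears exactly once
    rows, cols = np.triu_indices(sym_matrix.shape[0], k=1)
    counts = sym_matrix[rows, cols]
    # Sort by count, descending; "stable" keeps earlier pairs first on ties
    order = np.argsort(-counts, kind="stable")[:top_k]
    return [(int(counts[i]), class_names[rows[i]], class_names[cols[i]]) for i in order]

# Made-up symmetrized matrix and labels, for illustration only
toy = np.array([
    [50, 10, 1],
    [10, 45, 6],
    [1, 6, 48],
])
pairs = most_confused_pairs(toy, ["A", "B", "C"], top_k=2)
print(pairs)
```

With the objects from the post, the call would be `most_confused_pairs(symmetrized_conf_matrix, class_names, top_k=5)`.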