visualization
metrics
¶
plot_confusion_matrix(y_true, y_pred, display_labels=None, include_values=True, title='Confusion Matrix', cmap=None)
¶
Plots a confusion matrix, reporting accuracy and error in the x-axis label.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
y_true |
Union[numpy.ndarray, List] |
Sequence of ground-truth values |
required |
y_pred |
Union[numpy.ndarray, List] |
Sequence of predictions |
required |
display_labels |
|
Labels to display on the plot; inferred from the data when None |
None |
include_values |
bool |
Whether to write each cell's count inside the matrix |
True |
title |
str |
Title of the plot |
'Confusion Matrix' |
cmap |
str |
Matplotlib colormap; 'Blues' is used when None |
None |
Source code in chitra/visualization/metrics.py
def plot_confusion_matrix(
    y_true: Union[np.ndarray, List],
    y_pred: Union[np.ndarray, List],
    display_labels=None,
    include_values: bool = True,
    title: str = "Confusion Matrix",
    cmap: str = None,
):
    """Plots confusion matrix

    Rows of the matrix are true labels and columns are predicted labels. The
    x-axis label also reports overall accuracy and error (1 - accuracy).

    Args:
        y_true: Sequence of ground truth values
        y_pred: Sequence of predictions
        display_labels: Labels to display on the plot. When None, the sorted
            union of the labels found in ``y_true`` and ``y_pred`` is used,
            matching the row/column order of the computed matrix.
        include_values: Whether to write each cell's count inside the matrix
        title: Title of Matrix image
        cmap: Matplotlib colormap; "Blues" is used when None
    """
    if detect_multilabel(y_true):
        logger.warning("You might want to use multi-label version!")

    cm = confusion_matrix(y_true, y_pred)
    # Size everything from the matrix itself: confusion_matrix indexes the
    # union of labels present in y_true and y_pred, so counting the classes
    # of y_true alone can miss classes that only appear in the predictions.
    n_classes = cm.shape[0]
    if display_labels is None:
        display_labels = np.union1d(np.asarray(y_true), np.asarray(y_pred))
    tick_marks = np.arange(n_classes)
    if cmap is None:
        cmap = plt.get_cmap("Blues")

    accuracy = cm_accuracy(cm)
    error = 1 - accuracy

    plt.imshow(cm, cmap=cmap)
    if include_values:
        for i, j in product(range(n_classes), range(n_classes)):
            # imshow draws row i at y=i and column j at x=j, so the count for
            # cm[i, j] belongs at (x=j, y=i), centred in its cell.
            plt.text(j, i, "{:,}".format(cm[i, j]), ha="center", va="center")
    plt.xticks(tick_marks, display_labels, rotation=45)
    plt.yticks(tick_marks, display_labels)
    plt.title(title)
    plt.xlabel(f"Predicted Label\nAccuracy={accuracy:0.4f}; Error={error:0.4f}")
    plt.ylabel("True Label")
    plt.show()
Last update:
November 27, 2021