Skip to content

visualization

metrics

plot_confusion_matrix(y_true, y_pred, display_labels=None, include_values=True, title='Confusion Matrix', cmap=None)

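Plots a confusion matrix (rows: true labels, columns: predicted labels) and reports accuracy and error in the x-axis label.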

Parameters:

Name Type Description Default
y_true Union[numpy.ndarray, List]

Sequence of ground-truth values

required
y_pred Union[numpy.ndarray, List]

Sequence of predictions

required
display_labels

Labels to display on the plot

None
include_values bool

Whether to write the count inside each matrix cell

True
title str

Title of the plot

'Confusion Matrix'
cmap str

Colormap used to render the matrix (defaults to "Blues")

None
Source code in chitra/visualization/metrics.py
def plot_confusion_matrix(
    y_true: Union[np.ndarray, List],
    y_pred: Union[np.ndarray, List],
    display_labels=None,
    include_values: bool = True,
    title: str = "Confusion Matrix",
    cmap: str = None,
):
    """Plots confusion matrix.

    The matrix is drawn with ``plt.imshow``. Rows are shown along the y-axis
    (labelled "True Label") and columns along the x-axis (labelled
    "Predicted Label"). Overall accuracy and error are appended to the
    x-axis label.

    Args:
        y_true: Sequence of ground truth values.
        y_pred: Sequence of predictions.
        display_labels: Labels to display on the plot, one per matrix
            row/column. Defaults to the sorted unique labels found in
            ``y_true`` and ``y_pred`` combined.
        include_values: Whether to write the count inside each matrix cell.
        title: Title of the plot.
        cmap: Colormap passed to ``plt.imshow``; defaults to "Blues".
    """
    if detect_multilabel(y_true):
        logger.warning("You might want to use multi-label version!")

    if display_labels is None:
        # Use the union of both sequences so a class that only appears in
        # y_pred still gets a tick label. Presumably this matches the label
        # order of `confusion_matrix` (sklearn sorts the union of labels).
        display_labels = np.unique(
            np.concatenate([np.ravel(y_true), np.ravel(y_pred)])
        )

    n_classes = len(display_labels)
    tick_marks = np.arange(n_classes)

    if cmap is None:
        cmap = plt.get_cmap("Blues")

    cm = confusion_matrix(y_true, y_pred)
    accuracy = cm_accuracy(cm)
    error = 1 - accuracy

    plt.imshow(cm, cmap=cmap)

    if include_values:
        n_rows, n_cols = cm.shape
        for row, col in product(range(n_rows), range(n_cols)):
            # imshow draws cm[row, col] at x=col, y=row, so the text must be
            # placed at (col, row); (row, col) would transpose the values.
            plt.text(
                col,
                row,
                "{:,}".format(cm[row, col]),
                ha="center",
                va="center",
            )

    plt.xticks(tick_marks, display_labels, rotation=45)
    plt.yticks(tick_marks, display_labels)
    plt.title(title)
    plt.xlabel(f"Predicted Label\nAccuracy={accuracy:0.4f}; Error={error:0.4f}")
    plt.ylabel("True Label")

    plt.show()

Last update: November 27, 2021