from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np


def _tick_labels_for(y_true, y_pred, classes, num_classes):
    """Return the display name for each row/column of the confusion matrix.

    When the labels occurring in the data are integers that are valid
    indices into ``classes``, each one is mapped to ``classes[label]`` so a
    class that is missing from the data cannot shift the names of the
    classes after it. In every other situation the positional slice
    ``classes[:num_classes]`` is returned.
    """
    positional = classes[:num_classes]
    try:
        present = np.unique(
            np.concatenate(
                [np.asarray(y_true).ravel(), np.asarray(y_pred).ravel()]
            )
        )
    except (TypeError, ValueError):
        return positional

    # Only trust the index mapping when it lines up one-to-one with the
    # matrix and every label is a usable integer index into ``classes``.
    if present.size == 0 or present.size != num_classes:
        return positional
    if not np.issubdtype(present.dtype, np.integer):
        return positional
    # ``present`` is sorted, so checking both ends bounds every label.
    if present[0] < 0 or present[-1] >= len(classes):
        return positional
    return [classes[int(label)] for label in present]


def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch):
    """Draw a confusion-matrix heatmap and hand it to ``writer.add_figure``.

    Args:
        y_true: Ground-truth labels, passed to ``confusion_matrix``.
        y_pred: Predicted labels, passed to ``confusion_matrix``.
        classes: Sequence of class display names used for the tick labels.
            Presumably indexed by integer class id -- confirm against callers.
        writer: Object exposing ``add_figure(tag, figure, step)``.
        epoch: Step value forwarded as the third argument of ``add_figure``.

    Returns:
        None. The figure is logged under the tag ``"Confusion Matrix"``.
    """
    cm = confusion_matrix(y_true, y_pred)
    num_classes = cm.shape[0]
    tick_labels = _tick_labels_for(y_true, y_pred, classes, num_classes)

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.figure.colorbar(im, ax=ax)

        ticks = np.arange(num_classes)
        ax.set(
            xticks=ticks,
            yticks=ticks,
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            ylabel='True label',
            xlabel='Predicted label',
        )

        # Cells darker than half the maximum count get white text so the
        # number stays readable against the Blues colormap.
        thresh = cm.max() / 2.
        for (row, col), count in np.ndenumerate(cm):
            ax.text(
                col,
                row,
                format(count, 'd'),
                ha="center",
                va="center",
                color="white" if count > thresh else "black",
            )

        fig.tight_layout()
        writer.add_figure("Confusion Matrix", fig, epoch)
    finally:
        # Release the figure even if plotting or logging fails; closing a
        # figure that the writer already closed is a no-op.
        plt.close(fig)