Spaces:
Running
on
Zero
Running
on
Zero
| from sklearn.metrics import confusion_matrix | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def _index_labels(y_true, y_pred, num_names):
    """Return ``np.arange(num_names)`` if every label indexes ``classes``, else ``None``."""
    # Same set sklearn would use by default: sorted union of observed labels.
    observed = np.union1d(np.ravel(y_true), np.ravel(y_pred))
    if (observed.size
            and np.issubdtype(observed.dtype, np.integer)
            and observed.min() >= 0
            and observed.max() < num_names):
        return np.arange(num_names)
    # Non-index labels (strings, floats, out-of-range ints, empty input):
    # let sklearn infer labels exactly as before.
    return None


def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch,
                          tag="Confusion Matrix"):
    """Draw a confusion matrix and log it to ``writer`` as a figure.

    Args:
        y_true: Ground-truth labels (any array-like accepted by
            ``sklearn.metrics.confusion_matrix``).
        y_pred: Predicted labels, aligned with ``y_true``.
        classes: Sequence of display names. When every label is a
            non-negative integer smaller than ``len(classes)``, labels are
            treated as indices into ``classes``. Each class then gets its
            own row/column, even if it does not occur in this call.
            Otherwise the first ``n`` names are used positionally for the
            ``n`` labels sklearn finds.
        writer: Object exposing ``add_figure(tag, figure, step)`` --
            presumably a TensorBoard ``SummaryWriter``; confirm at call site.
        epoch: Step value forwarded to ``writer.add_figure``.
        tag: Tag the figure is logged under (default keeps the previous
            hard-coded value).
    """
    # Pinning the label set prevents classes that never appear in this
    # batch from silently shifting every axis name after them.
    labels = _index_labels(y_true, y_pred, len(classes))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]

    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    # Use white text on the darker half of the colour scale so counts stay
    # readable against the Blues colormap.
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    # NOTE(review): the figure is never closed here. torch's SummaryWriter
    # closes it by default (add_figure(..., close=True)); confirm for any
    # other writer, or figures will accumulate once per epoch.
    writer.add_figure(tag, fig, epoch)