|
import itertools |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
# Global matplotlib font configuration.
# NOTE(review): SimHei is presumably chosen so CJK text in labels renders;
# such fonts may lack the Unicode minus glyph, hence unicode_minus=False makes
# matplotlib draw negative numbers with the ASCII hyphen-minus. Confirm intent.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
|
|
|
|
|
|
|
|
|
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print and plot a confusion matrix on the current pyplot figure.

    Rows are drawn along the "True Label" axis and columns along the
    "Predicted Label" axis. Every cell is annotated with its value; the text
    is white on cells above half the maximum value and black otherwise.

    Parameters
    ----------
    cm : array_like
        Square matrix of counts (anything ``np.asarray`` accepts).
    classes : sequence of str
        Tick labels for both axes, in row/column order.
    normalize : bool, default False
        If True, divide each row by its sum so cells show per-row fractions.
        Rows that sum to zero are shown as 0 rather than NaN.
    title : str or None, default 'Confusion matrix'
        Axes title; ``None`` leaves the title unset.
    cmap : matplotlib colormap, default ``plt.cm.Blues``
        Colormap passed to ``plt.imshow``.

    The (possibly normalized) matrix is printed to stdout, and the function
    finishes by calling ``plt.show()``.
    """
    cm = np.asarray(cm)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Row-wise division; an all-zero row would give 0/0 = NaN, so those
        # cells are left at 0 via `where`/`out`.
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    # Honour the caller's colormap (it was previously hard-coded to Blues).
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    if title is not None:
        plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, fontsize=16)
    plt.yticks(tick_marks, classes, fontsize=16)

    # Integer counts print as integers; fractions (normalized or float input)
    # use two decimals -- format(<float>, 'd') would raise ValueError.
    if normalize or not np.issubdtype(cm.dtype, np.integer):
        fmt = '.2f'
    else:
        fmt = 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    # Compute the layout only after all labels exist so none are clipped.
    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
|
# Confusion-matrix counts: row i is the true class and column j the predicted
# class (plot_confusion_matrix draws rows on its "True Label" axis), both in
# the order given by class_names.
cnf_matrix = np.array([
    [299,   6,   5,    3,   1,   4,  11],
    [  9,  51,   0,    2,   8,   2,   2],
    [  2,   1, 120,    6,  13,   9,   9],
    [  5,   1,   7, 1148,   2,   4,  18],
    [  0,   0,   9,    4, 442,   1,  22],
    [  2,   0,   7,    3,   0, 145,   5],
    [ 10,   0,   6,   11,  29,   0, 624],
])

# Abbreviated class labels.
# NOTE(review): presumably emotion categories (surprise, fear, anger,
# happiness, sadness, disgust, neutral) -- confirm against the dataset.
class_names = ['SU', 'FE', 'AN', 'HA', 'SA', 'DI', 'NE']

# Render at 200 dpi, showing row-normalized fractions and no title.
plt.figure(dpi=200)
plot_confusion_matrix(
    cnf_matrix,
    classes=class_names,
    normalize=True,
    title=None,
)
|
|