# TADBot / FER / models / matrix.py
# Last commit 0e084be by ryefoxlime: "Restructured the project files" (1.93 kB)
import itertools
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
# Global font configuration: prefer SimHei for sans-serif text (so CJK
# characters in titles/labels can render), and draw minus signs with the ASCII
# hyphen instead of U+2212 — presumably because SimHei lacks that glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot a confusion matrix on the current matplotlib figure.

    The matrix is printed to stdout, drawn with ``plt.imshow``, annotated
    with each cell's value, and finally displayed with ``plt.show()``.

    Args:
        cm: 2-D array-like of values. Rows are drawn against the
            "True Label" axis and columns against "Predicted Label".
        classes: Tick labels, one per row/column of ``cm``.
        normalize: If True, divide every row by its sum before plotting,
            so each non-empty row sums to 1. Rows that sum to zero are left
            as zeros rather than becoming NaN.
        title: Plot title; ``None`` gives an empty title.
        cmap: Matplotlib colormap used to colour the cells.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A row with no samples would otherwise divide by zero, yielding NaN
        # cells and a RuntimeWarning; leave such rows as zeros instead.
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, fontsize=16)
    plt.yticks(tick_marks, classes, fontsize=16)

    # Integer counts print with 'd'; normalized (or any float) input needs a
    # float format, because format(0.5, 'd') raises ValueError.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Cells above half of the maximum get white text so they stay legible on
    # the dark end of sequential colormaps such as Blues.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    # Compute the layout only after the axis labels exist so they are not
    # clipped at the figure edge.
    plt.tight_layout()
    plt.show()
# Example driver: plot the row-normalized confusion matrix of a 7-class model.
# NOTE(review): this runs at import time (there is no ``__main__`` guard), so
# importing this module to reuse plot_confusion_matrix opens a figure.
# NOTE(review): the row sums (329, 74, 160, 1185, 478, 162, 680) look like the
# RAF-DB test split, whose label order is surprise, fear, disgust, happiness,
# sadness, anger, neutral. If that is where these numbers come from, 'AN' and
# 'DI' below may be swapped — verify against the evaluation label mapping.
cnf_matrix = np.array([
    [299,   6,   5,    3,   1,   4,  11],
    [  9,  51,   0,    2,   8,   2,   2],
    [  2,   1, 120,    6,  13,   9,   9],
    [  5,   1,   7, 1148,   2,   4,  18],
    [  0,   0,   9,    4, 442,   1,  22],
    [  2,   0,   7,    3,   0, 145,   5],
    [ 10,   0,   6,   11,  29,   0, 624],
])
class_names = ["SU", "FE", "AN", "HA", "SA", "DI", "NE"]

plt.figure(dpi=200)
plot_confusion_matrix(
    cnf_matrix,
    classes=class_names,
    normalize=True,
    title=None,
)