混淆矩阵
混淆矩阵(Confusion Matrix)是机器学习中用来总结分类模型预测结果的一个分析表,是模式识别领域中的一种常用的表达形式。它以矩阵的形式描绘样本数据的真实属性和分类预测结果类型之间的关系,是用来评价分类器性能的一种常用方法。
我们可以通过一个简单的例子来直观理解混淆矩阵
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
#!/usr/bin/python3.5
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams[ 'font.sans-serif' ] = [ 'FangSong' ] #可显示中文字符
plt.rcParams[ 'axes.unicode_minus' ] = False
classes = [ 'a' , 'b' , 'c' , 'd' , 'e' , 'f' , 'g' ]
confusion_matrix = np.array([( 99 , 1 , 2 , 2 , 0 , 0 , 6 ),( 1 , 98 , 7 , 6 , 2 , 1 , 1 ),( 0 , 0 , 86 , 0 , 0 , 2 , 0 ),( 0 , 0 , 0 , 86 , 1 , 0 , 0 ),( 0 , 0 , 0 , 1 , 94 , 1 , 0 ),( 0 , 1 , 5 , 1 , 0 , 96 , 8 ),( 0 , 0 , 0 , 4 , 3 , 0 , 85 )],dtype = np.float64)
plt.imshow(confusion_matrix, interpolation = 'nearest' , cmap = plt.cm.Oranges) #按照像素显示出矩阵
plt.title( '混淆矩阵' )
plt.colorbar()
tick_marks = np.arange( len (classes))
plt.xticks(tick_marks, classes, rotation = - 45 )
plt.yticks(tick_marks, classes)
thresh = confusion_matrix. max () / 2.
#iters = [[i,j] for i in range(len(classes)) for j in range((classes))]
#ij配对,遍历矩阵迭代器
iters = np.reshape([[[i,j] for j in range ( 7 )] for i in range ( 7 )],(confusion_matrix.size, 2 ))
for i, j in iters:
plt.text(j, i, format (confusion_matrix[i, j]),fontsize = 7 ) #显示对应的数字
plt.ylabel( '真实类别' )
plt.xlabel( '预测类别' )
plt.tight_layout()
plt.show()
|
正确率曲线
1
2
3
4
5
6
7
8
9
10
|
fig ,ax = plt.subplots()
plt.plot(np.arange(iterations), fig_acc, 'b' )
plt.plot(np.arange(iterations), fig_realacc, 'r' )
ax.set_xlabel( '迭代次数' )
ax.set_ylabel( '正确率(%)' )
labels = [ "训练正确率" , "测试正确率" ]
# labels = [l.get_label() for l in lns]
plt.legend( labels, loc = 7 )
plt.show()
|
总结
到此这篇关于matplotlib画混淆矩阵与正确率曲线的文章就介绍到这了,更多相关matplotlib画混淆矩阵内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!
原文链接:https://blog.csdn.net/yuan0401yu/article/details/88730555