实习点滴(11)--TensorFlow快速计算“多分类问题”的混淆矩阵以及精确率、召回率、F1值、准确率

时间:2022-12-08 09:52:24

        在机器学习中,我们会利用一些指标(混淆矩阵、精确率、召回率、F1值、准确率)来判断我们模型的好坏,从而改进优化模型。下面介绍如何在TensorFlow下快速计算这些指标。

        1、混淆矩阵

        confusion_matrix = tf.contrib.metrics.confusion_matrix(labels_pred_all, labels_all, num_classes=None, dtype=tf.int32, name=None, weights=None)
confusion_matrix = sess.run(confusion_matrix)

        因为第一步所计算出来的混淆矩阵是一个Tensor,所以需要进行转换。

        具体api详解:

        https://haosdent.gitbooks.io/tensorflow-document/content/api_docs/python/contrib.metrics.html#confusion_matrix

        值得注意的是:所计算出来的混淆矩阵,列是真实值(也就是期望值),行是预测值

        2、四大指标:

        有了混淆矩阵,计算四大指标就好办了。

        accu = [0,0,0,0,0]
column = [0,0,0,0,0]
line = [0,0,0,0,0]
accuracy = 0
recall = 0
precision = 0
for i in range(0,5):
accu[i] = confusion_matrix[i][i]
for i in range(0,5):
for j in range(0,5):
column[i]+=confusion_matrix[j][i]
for i in range(0,5):
for j in range(0,5):
line[i]+=confusion_matrix[i][j]
for i in range(0,5):
accuracy += float(accu[i])/len_labels_all
for i in range(0,5):
if column[i] != 0:
recall+=float(accu[i])/column[i]
recall = recall / 5
for i in range(0,5):
if line[i] != 0:
precision+=float(accu[i])/line[i]
precision = precision / 5
f1_score = (2 * (precision * recall)) / (precision + recall)