Computing F-measure, Precision / Recall / F1 score and the Confusion Matrix in TensorFlow

Date: 2022-12-07 11:29:59
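To compute the F-measure, recall and similar metrics for a TensorFlow model, this approach lets TensorFlow produce the predictions and uses scikit-learn (sklearn) to compute the metrics. On Windows 7 64-bit you only need to run pip3 install scikit-learn (the PyPI package literally named sklearn is a deprecated placeholder). The example below comes from * (thanks to that site). Original post: https://*.com/questions/35365007/tensorflow-precision-recall-f1-score-and-confusion-matrix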
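For reference, these are the standard definitions behind the metrics (the F-measure usually means F1, i.e. F-beta with beta = 1). TP, FP and FN are true positives, false positives and false negatives; in a multi-class problem they are taken per class k from the confusion matrix C, using scikit-learn's convention that rows are true classes and columns are predicted classes:

```latex
\begin{aligned}
\text{Precision} &= \frac{TP}{TP + FP}, \qquad \text{Recall} = \frac{TP}{TP + FN} \\
F_\beta &= (1 + \beta^2)\,\frac{\text{Precision}\cdot\text{Recall}}{\beta^2\,\text{Precision} + \text{Recall}} \\
F_1 &= \frac{2\,\text{Precision}\cdot\text{Recall}}{\text{Precision} + \text{Recall}} = \frac{2\,TP}{2\,TP + FP + FN} \\
TP_k &= C_{kk}, \qquad FP_k = \sum_{i \neq k} C_{ik}, \qquad FN_k = \sum_{j \neq k} C_{kj}
\end{aligned}
```

With average='macro', scikit-learn takes the unweighted mean of these per-class values over the classes.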


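from sklearn import metrics  # provides confusion_matrix(y_true, y_pred), precision_score, recall_score, ...
import numpy as np
import tensorflow as tf  # TF 1.x graph API; on TF 2.x: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()

# multilayer_perceptron, weights, biases, the placeholders x / y, cost, train_step, total_batch,
# display_step and the train/test arrays are defined earlier in the original program
pred = multilayer_perceptron(x, weights, biases)
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(150):
        avg_cost = 0.0  # reset the running cost every epoch
        for i in range(total_batch):
            train_step.run(feed_dict={x: train_arrays, y: train_labels})
            avg_cost += sess.run(cost, feed_dict={x: train_arrays, y: train_labels}) / total_batch
        if epoch % display_step == 0:
            print("Epoch:", "%04d" % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

    # metrics: run the graph to get predicted class indices, then score them with scikit-learn
    y_p = tf.argmax(pred, 1)
    val_accuracy, y_pred = sess.run([accuracy, y_p], feed_dict={x: test_arrays, y: test_label})
    print("validation accuracy:", val_accuracy)
    y_true = np.argmax(test_label, 1)  # one-hot labels -> class indices
    # the default average='binary' assumes 2 classes; pass average='macro' (or 'micro', 'weighted') otherwise
    print("Precision", metrics.precision_score(y_true, y_pred))
    print("Recall", metrics.recall_score(y_true, y_pred))
    print("f1_score", metrics.f1_score(y_true, y_pred))
    print("confusion_matrix")
    print(metrics.confusion_matrix(y_true, y_pred))
    # roc_curve needs scores, not hard labels; this assumes pred holds the logits of a 2-class model
    y_score = sess.run(tf.nn.softmax(pred), feed_dict={x: test_arrays})[:, 1]
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)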
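To see what the scikit-learn calls compute, here is a minimal NumPy-only sketch that goes from one-hot labels and network scores to a confusion matrix and per-class precision, recall and F1. The arrays test_label and scores are toy values standing in for the real test labels and network output, and the helper names confusion_matrix_np and per_class_scores are hypothetical. A class with no predicted (or no true) samples gets 0, which matches scikit-learn's default zero_division behavior, except that scikit-learn also emits a warning.

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, n_classes):
    """C[i, j] counts samples whose true class is i and predicted class is j."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)  # unbuffered add, so repeated (i, j) pairs all count
    return cm


def per_class_scores(cm):
    """Per-class precision, recall and F1 from a confusion matrix (0 where undefined)."""
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp  # predicted as class k, but the true class differs
    fn = cm.sum(axis=1) - tp  # true class k, but predicted as another class
    precision = np.divide(tp, tp + fp, out=np.zeros_like(tp), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=np.zeros_like(tp), where=(tp + fn) > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1


# Toy one-hot labels and network scores (illustrative values only)
test_label = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]])
scores = np.array([[0.8, 0.1, 0.1],
                   [0.2, 0.7, 0.1],
                   [0.3, 0.3, 0.4],
                   [0.6, 0.3, 0.1],
                   [0.9, 0.05, 0.05]])

y_true = np.argmax(test_label, 1)  # same conversion as in the TensorFlow code
y_pred = np.argmax(scores, 1)
cm = confusion_matrix_np(y_true, y_pred, n_classes=3)
precision, recall, f1 = per_class_scores(cm)
print(cm)
print("macro precision:", precision.mean(), "macro recall:", recall.mean(), "macro F1:", f1.mean())
```

Taking the mean of each per-class array gives macro averages, which should agree with scikit-learn's average='macro' as long as every class appears in y_true or y_pred (scikit-learn only averages over labels it actually sees).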

Keras 1.2 also has these metrics built in (they were removed in Keras 2.0).

First, add them to the metrics argument of compile:

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy', 'fmeasure', 'precision', 'recall'])
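One caveat: Keras 1.x computes these metrics on each batch and logs the (batch-size weighted) average over the epoch, which is not the same as the F1 of the whole dataset. The NumPy sketch below is modeled on the formula the Keras 1.x fbeta_score metric used (clip, round, count, divide with a small epsilon); treat it as an approximation of that code rather than a copy. The function name batch_fmeasure and the toy binary labels and scores are hypothetical, chosen only to show the gap between the two numbers.

```python
import numpy as np

EPSILON = 1e-7  # assumed to match Keras' default K.epsilon()


def batch_fmeasure(y_true, y_pred, beta=1.0):
    """F-beta over one batch, in the style of the Keras 1.x metric (NumPy re-implementation)."""
    # Round clipped scores to 0/1 and count over the whole batch
    true_positives = np.sum(np.round(np.clip(y_true * y_pred, 0, 1)))
    predicted_positives = np.sum(np.round(np.clip(y_pred, 0, 1)))
    possible_positives = np.sum(np.round(np.clip(y_true, 0, 1)))
    if possible_positives == 0:
        return 0.0  # no positive labels in this batch: score fixed at 0
    p = true_positives / (predicted_positives + EPSILON)
    r = true_positives / (possible_positives + EPSILON)
    bb = beta ** 2
    return (1 + bb) * p * r / (bb * p + r + EPSILON)


# Toy binary labels and sigmoid outputs (illustrative values only), split into 2 batches of 4
y_true = np.array([1, 0, 0, 0, 1, 1, 1, 1], dtype=float)
y_pred = np.array([0.9, 0.2, 0.7, 0.6, 0.8, 0.7, 0.9, 0.6])
batch_size = 4

per_batch = [batch_fmeasure(y_true[i:i + batch_size], y_pred[i:i + batch_size])
             for i in range(0, len(y_true), batch_size)]
batch_avg = float(np.mean(per_batch))           # roughly what the epoch log reports
overall = float(batch_fmeasure(y_true, y_pred))  # F1 computed over all samples at once
print("mean of per-batch F1:", batch_avg)
print("F1 on all data:", overall)
```

Here the batch average comes out near 0.75 while the F1 over all eight samples is about 0.83, so for reporting final results it is safer to compute the metrics on the full prediction arrays, as the scikit-learn approach above does.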

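# Then plot them with plt. result is the History object returned by model.fit; result.history stores them under the keys:
# fmeasure, val_fmeasure, precision, val_precision, recall, val_recall, acc, val_acc, loss, val_loss
# (the val_ keys only appear when validation data is passed to fit)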

import matplotlib.pyplot as plt

plt.figure()
fig = plt.gcf()
fig.set_size_inches(18.5, 10.5)
plt.plot(result.epoch, result.history['fmeasure'], label="fmeasure")
plt.plot(result.epoch, result.history['val_fmeasure'], label="val_fmeasure")
plt.scatter(result.epoch, result.history['fmeasure'], marker='*')
plt.scatter(result.epoch, result.history['val_fmeasure'])
plt.title('Fmeasure')
plt.ylabel('fmeasure')
plt.xlabel('epoch')
plt.legend(loc='lower right')
plt.show()
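The same plot can be drawn for every metric at once. The sketch below assumes history is a dict shaped like result.history (metric name mapped to a list of per-epoch values); the function name plot_metrics and its default metric list are hypothetical. It uses one subplot per metric and draws the val_ curve whenever that key exists.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove this line to get interactive windows with plt.show()
import matplotlib.pyplot as plt


def plot_metrics(history, metrics=("fmeasure", "precision", "recall", "acc", "loss")):
    """One subplot per metric found in history, with its val_ counterpart if present."""
    present = [m for m in metrics if m in history]
    if not present:
        raise ValueError("none of the requested metrics are in history")
    fig, axes = plt.subplots(len(present), 1, figsize=(10, 3 * len(present)), squeeze=False)
    for ax, name in zip(axes[:, 0], present):
        epochs = range(len(history[name]))  # 0-based, like result.epoch
        ax.plot(epochs, history[name], marker="*", label=name)
        val_name = "val_" + name
        if val_name in history:
            ax.plot(epochs, history[val_name], marker="o", label=val_name)
        ax.set_title(name)
        ax.set_xlabel("epoch")
        ax.set_ylabel(name)
        ax.legend(loc="lower right")
    fig.tight_layout()
    return fig
```

Usage: after result = model.fit(...), call fig = plot_metrics(result.history) and then fig.savefig("metrics.png"), or drop the Agg line and call plt.show(). Passing marker to plot replaces the separate plt.scatter calls in the original.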