使用 TensorFlow 训练模型后,如果想计算 F-measure(F1)、精确率、召回率等指标,需要先安装 scikit-learn(导入名为 sklearn)。
在 64 位 Windows 7 下,只需在命令行执行 pip3 install scikit-learn 即可(PyPI 上的 sklearn 只是一个已弃用的别名包)。
下面是示例:
示例来自 Stack Overflow,感谢这个网站:
原帖:https://stackoverflow.com/questions/35365007/tensorflow-precision-recall-f1-score-and-confusion-matrix
# Simplest case: scikit-learn builds the confusion matrix directly from the
# true labels and the predicted labels.
from sklearn.metrics import confusion_matrix

confusion_matrix(y_true, y_pred)

# Full example: train a TensorFlow 1.x model, then score it with scikit-learn.
# The snippet relies on `import numpy as np`, `import sklearn as sk` (with
# `sklearn.metrics` loaded, as the import above does) and
# `import tensorflow as tf`. x, y, weights, biases, multilayer_perceptron,
# cost, train_step, total_batch, display_step and the train/test arrays are
# expected to come from the model-building code, which is not shown here.
pred = multilayer_perceptron(x, weights, biases)
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

with tf.Session() as sess:
    # NOTE: tf.initialize_all_variables() is deprecated in later TF 1.x
    # releases in favour of tf.global_variables_initializer().
    init = tf.initialize_all_variables()
    sess.run(init)

    for epoch in range(150):
        # Restart the running average so the printed cost describes this
        # epoch only instead of accumulating over the whole run.
        avg_cost = 0.0
        for _ in range(total_batch):
            train_step.run(feed_dict={x: train_arrays, y: train_labels})
            batch_cost = sess.run(
                cost, feed_dict={x: train_arrays, y: train_labels})
            avg_cost += batch_cost / total_batch
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1),
                  "cost=", "{:.9f}".format(avg_cost))

    # metrics -- evaluated inside the `with` block because they need `sess`.
    y_p = tf.argmax(pred, 1)
    val_accuracy, y_pred = sess.run(
        [accuracy, y_p], feed_dict={x: test_arrays, y: test_label})
    print("validation accuracy:", val_accuracy)

    # argmax over axis 1 turns each row of test_label into a class index,
    # matching what tf.argmax(pred, 1) produces for the predictions.
    y_true = np.argmax(test_label, 1)

    # NOTE(review): precision/recall/f1 default to average='binary' and
    # roc_curve is a binary-classification tool; with more than two classes
    # an explicit `average=` is presumably needed -- confirm for your data.
    print("Precision", sk.metrics.precision_score(y_true, y_pred))
    print("Recall", sk.metrics.recall_score(y_true, y_pred))
    print("f1_score", sk.metrics.f1_score(y_true, y_pred))
    print("confusion_matrix")
    print(sk.metrics.confusion_matrix(y_true, y_pred))
    fpr, tpr, tresholds = sk.metrics.roc_curve(y_true, y_pred)
Keras 1.2 版本也内置了这些指标(注意:precision、recall、fmeasure 这几个内置指标在 Keras 2.0 中已被移除)。
首先在 compile 的 metrics 参数中加入这些指标:
# Compile the model so that, besides accuracy, the F-score, precision and
# recall metrics are tracked during training.
# NOTE(review): the plotting code reads the 'fmeasure' history key, so
# 'f1score' is presumably an alias of 'fmeasure' in Keras 1.2 -- confirm.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy', 'f1score', 'precision', 'recall'],
)
# Then plot the curves with plt. The keys stored in result.history are:
# precision, val_precision, recall, val_recall, acc, val_acc, loss, val_loss
# (this snippet plots the fmeasure / val_fmeasure pair).
# `plt` is assumed to be matplotlib.pyplot and `result` the object returned
# by model.fit(...) -- TODO confirm against the training code.
plt.figure()
fig = plt.gcf()
fig.set_size_inches(18.5, 10.5)

# Line plus scatter markers for the training and validation F-measure.
plt.plot(result.epoch, result.history['fmeasure'], label="fmeasure")
plt.plot(result.epoch, result.history['val_fmeasure'], label="val_fmeasure")
plt.scatter(result.epoch, result.history['fmeasure'], marker='*')
plt.scatter(result.epoch, result.history['val_fmeasure'])

plt.title('Fmeasure')
plt.ylabel('fmeasure')
# Backslash escaped explicitly: '\ ' is an invalid escape sequence; the
# rendered label is unchanged.
plt.xlabel('epoch \\ times')
# 'under right' is not a valid legend location; 'lower right' is.
plt.legend(loc='lower right')
plt.show()