TensorFlow函数:tf.nn.in_top_k()

时间:2021-01-03 13:48:02

       tf.nn.in_top_k()函数的参数如下:

in_top_k(predictions, targets, k, name=None)

       predictions:预测的结果,预测矩阵大小为样本数×标注的label类的个数的二维矩阵。

       targets:实际的标签,大小为样本数。

       k:每个样本的预测结果的前k个最大的数里面是否包含targets预测中的标签,一般都是取1,即取预测最大概率的索引与标签对比。

       name:名字。

       假设有10个样本,标注为5类,10个样本实际标签均是第一类,代码如下:

import tensorflow as tf

logits = tf.Variable(tf.truncated_normal(shape=[10,5],stddev=1.0))
labels = tf.constant([0,0,0,0,0,0,0,0,0,0])

top_1_op = tf.nn.in_top_k(logits,labels,1)
top_2_op = tf.nn.in_top_k(logits,labels,2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(logits.eval())
    print(labels.eval())
    print(top_1_op.eval())
    print(top_2_op.eval())

       运行结果如下:

[[-0.01835343 -1.68495178 -0.67901242 -0.20486258 -0.22725371]
 [ 1.84425163 -1.25509632  0.07132829 -1.81082523 -0.44123012]
 [-0.4354656   0.1805554   0.81912154  0.04202025 -1.99823892]
 [ 0.53393573  0.91522688 -1.88455033 -0.44571343  0.07805539]
 [ 0.01253182  0.16593859  0.0918197   0.8079409   0.13442524]
 [ 0.08205117 -0.26857412  0.02542082  0.38249066 -0.01555154]
 [-1.02280331  0.18952899  0.49389341  0.58559865  0.80859423]
 [ 0.35019293 -1.17765355  0.66553122  1.91787696  0.5998978 ]
 [ 0.81723028  0.92895705  0.86031818  1.57651412  0.94040418]
 [-0.83766556 -1.75260925  0.13499574 -0.06683849 -0.99427927]]
[0 0 0 0 0 0 0 0 0 0]
[ True  True False False False False False False False False]
[ True  True False  True False  True False False False False]

       top_1_op为True的地方top_2_op一定为True,top_1_op取样本的最大预测概率的索引与实际标签对比,top_2_op取样本的最大和仅次最大的两个预测概率与实际标签对比,如果实际标签在其中则为True,否则为False。其他k的取值可以类推。