tf.nn.in_top_k的用法


原型:

tf.nn.in_top_k(predictions, targets, k, name=None)

'''
    predictions: 你的預測結果(一般也就是你的網絡輸出值)大小是預測樣本的數量乘以輸出的維度
    target:      實際樣本類別的標簽,大小是樣本數量的個數
    k:           每個樣本中前K個最大的數里面(序號)是否包含對應target中的值
    
'''
import tensorflow as tf
A = tf.Variable([[0.8, 0.4, 0.5, 0.6],[0.1, 0.9, 0.2, 0.4],[0.1, 0.9, 0.4, 0.2]])
B = tf.Variable([1, 1, 2])
result = tf.nn.in_top_k(A, B, 2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(A))
    print(sess.run(B))
    print(sess.run(result))
#   k=1 [False True False]
# k=2 [False True True]
''' 解釋: k取1的時候: 因為A中第一個元素的最大值為0.8,索引(序號)是0,而B是1,不包含B,所以返回False. A中第二個元素的最大值為0.9,索引(序號)是1,而B是1,包含B,所以返回True. A中第三個元素的最大值為0.9,索引(序號)是1,而B是2,不包含B,所以返回False. k取2的時候: 因為A中前兩個元素的最大值為0.8,0.6,索引(序號)是0,3,而B是1,不包含B,所以返回False. A中前兩個元素的最大值為0.9,0.4,索引(序號)是1,3,而B是1,包含B,所以返回True. A中前兩個元素的最大值為0.9,0.4,索引(序號)是1,2,而B是2,包含B,所以返回True. '''

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM