tf.argmax()函數作用


tf.argmax()函數原型:

def argmax(input,
           axis=None,
           name=None,
           dimension=None,
           output_type=dtypes.int64)

作用是返回每列/行的最大值的索引。

input是一個張量,

axis是0或1,0返回各列最大值索引,1返回各行最大值索引。

其他3個參數不常用,常用寫法是 a = tf.argmax(tensor, 1)。

 

import tensorflow as tf
sess = tf.InteractiveSession()

a = tf.constant([[12, 3, 9],
                 [3, 6, 13]]) 

b_1 = tf.argmax(a, 0)   # 返回ndarry,元素是每列的最大值索引
b_2 = tf.argmax(a, 1)

print(b_1)   # >>array([0, 1, 1], dtype=int64)
print(b_2)   # >>array([0, 2], dtype=int64)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM