tf.argmax()函数作用


tf.argmax()函数原型:

def argmax(input,
           axis=None,
           name=None,
           dimension=None,
           output_type=dtypes.int64)

作用是返回每列/行的最大值的索引。

input是一个张量,

axis是0或1,0返回各列最大值索引,1返回各行最大值索引。

其他3个参数不常用,常用写法是 a = tf.argmax(tensor, 1)。

 

import tensorflow as tf
sess = tf.InteractiveSession()

a = tf.constant([[12, 3, 9],
                 [3, 6, 13]]) 

b_1 = tf.argmax(a, 0)   # 返回ndarry,元素是每列的最大值索引
b_2 = tf.argmax(a, 1)

print(b_1)   # >>array([0, 1, 1], dtype=int64)
print(b_2)   # >>array([0, 2], dtype=int64)

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM