tensorflow中tensor,從每行取指定索引元素


實驗有需求,需要對tensor中每一行取一個不同的索引的元素,其中tensor為2維(本文方法適合任意維),因此本文以2維tensor為例。

# 二維tensor
g = tf.constant([[1,2,3,4,5,6,7,8],[9,8,7,6,5,4,3,2]])
# 每一行取的index,在本例中,正確取值為[3, 2],即第一行index=2的元素和第二行index=7的元素
h_index = np.array([2, 7]).reshape(-1, 1)

# 構建一個numpy的arange列表,其長度為tensor的行數
line = np.arange(2).reshape(-1, 1)

# 注意上面兩個numpy數組的格式都是(-1, 1)
# 將h_index和line合並
index = np.hstack((line, h_index))

# 使用tf.gather_nd來取值
result = tf.gather_nd(g, index)

如上即可,返回仍為tensor


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM