[轉]tensorflow中的gather


原文鏈接
tensorflow中取下標的函數包括:tf.gather , tf.gather_nd 和 tf.batch_gather。

1.tf.gather(params,indices,validate_indices=None,name=None,axis=0)

indices必須是一維張量
主要參數:

  • params:被索引的張量
  • indices:一維索引張量
  • name:返回張量名稱

返回值:通過indices獲取params下標的張量。
例子:

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([1,2,0],dtype=tf.int32)
tensor_c = tf.Variable([0,0],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather(tensor_a,tensor_b)))
    print(sess.run(tf.gather(tensor_a,tensor_c)))

上個例子tf.gather(tensor_a,tensor_b) 的值為[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值為[[1,2,3],[1,2,3]]

對於tensor_a,其第1個元素為[4,5,6],第2個元素為[7,8,9],第0個元素為[1,2,3],所以以[1,2,0]為索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],同樣的,以[0,0]為索引的值為[[1,2,3],[1,2,3]]。

https://www.tensorflow.org/api_docs/python/tf/gather

2.tf.gather_nd(params,indices,name=None)

功能和參數與tf.gather類似,不同之處在於tf.gather_nd支持多維度索引,即indices可以使多維張量。
例子:

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32)
tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather_nd(tensor_a,tensor_b)))
    print(sess.run(tf.gather_nd(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值為[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值為[3,7].

對於tensor_a,下標[1,0]的元素為4,下標為[1,1]的元素為5,下標為[1,2]的元素為6,索引[1,0],[1,1],[1,2]]的返回值為[4,5,6],同樣的,索引[[0,2],[2,0]]的返回值為[3,7].

https://www.tensorflow.org/api_docs/python/tf/gather_nd

3.tf.batch_gather(params,indices,name=None)

支持對張量的批量索引,各參數意義見(1)中描述。注意因為是批處理,所以indices要有和params相同的第0個維度。

例子:

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
    print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值為[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值為[1,4,7].

tensor_a的三個元素[1,2,3],[4,5,6],[7,8,9]分別對應索引元素的第一,第二和第三個值。[1,2,3]的第0個元素為1,[4,5,6]的第1個元素為5,[7,8,9]的第2個元素為9,所以索引[[0],[1],[2]]的返回值為[1,5,9],同樣地,索引[[0],[0],[0]]的返回值為[1,4,7].

https://www.tensorflow.org/api_docs/python/tf/batch_gather

在深度學習的模型訓練中,有時候需要對一個batch的數據進行類似於tf.gather_nd的操作,但tensorflow中並沒有tf.batch_gather_nd之類的操作,此時需要tf.map_fn和tf.gather_nd結合來實現上述操作。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM