原文鏈接
tensorflow中取下標的函數包括:tf.gather , tf.gather_nd 和 tf.batch_gather。
1.tf.gather(params,indices,validate_indices=None,name=None,axis=0)
indices必須是一維張量
主要參數:
- params:被索引的張量
- indices:一維索引張量
- name:返回張量名稱
返回值:通過indices獲取params下標的張量。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([1,2,0],dtype=tf.int32)
tensor_c = tf.Variable([0,0],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather(tensor_a,tensor_b)))
print(sess.run(tf.gather(tensor_a,tensor_c)))
上個例子tf.gather(tensor_a,tensor_b) 的值為[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值為[[1,2,3],[1,2,3]]
對於tensor_a,其第1個元素為[4,5,6],第2個元素為[7,8,9],第0個元素為[1,2,3],所以以[1,2,0]為索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],同樣的,以[0,0]為索引的值為[[1,2,3],[1,2,3]]。
https://www.tensorflow.org/api_docs/python/tf/gather
2.tf.gather_nd(params,indices,name=None)
功能和參數與tf.gather類似,不同之處在於tf.gather_nd支持多維度索引,即indices可以使多維張量。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32)
tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather_nd(tensor_a,tensor_b)))
print(sess.run(tf.gather_nd(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值為[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值為[3,7].
對於tensor_a,下標[1,0]的元素為4,下標為[1,1]的元素為5,下標為[1,2]的元素為6,索引[1,0],[1,1],[1,2]]的返回值為[4,5,6],同樣的,索引[[0,2],[2,0]]的返回值為[3,7].
https://www.tensorflow.org/api_docs/python/tf/gather_nd
3.tf.batch_gather(params,indices,name=None)
支持對張量的批量索引,各參數意義見(1)中描述。注意因為是批處理,所以indices要有和params相同的第0個維度。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值為[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值為[1,4,7].
tensor_a的三個元素[1,2,3],[4,5,6],[7,8,9]分別對應索引元素的第一,第二和第三個值。[1,2,3]的第0個元素為1,[4,5,6]的第1個元素為5,[7,8,9]的第2個元素為9,所以索引[[0],[1],[2]]的返回值為[1,5,9],同樣地,索引[[0],[0],[0]]的返回值為[1,4,7].
https://www.tensorflow.org/api_docs/python/tf/batch_gather
在深度學習的模型訓練中,有時候需要對一個batch的數據進行類似於tf.gather_nd的操作,但tensorflow中並沒有tf.batch_gather_nd之類的操作,此時需要tf.map_fn和tf.gather_nd結合來實現上述操作。