tf.transpose函數解析
覺得有用的話,歡迎一起討論相互學習~
tf.transpose(a, perm = None, name = 'transpose')
解釋
- 將a進行轉置,並且根據perm參數重新排列輸出維度。這是對數據的維度的進行操作的形式。
Details
-
圖像處理時數據集中存儲數據的形式為[channel,image_height,image_width],在tensorflow中使用CNN時我們需要將其轉化為[image_height,image_width,channel]的形式,只需要使用tf.transpose(input_data,[1,2,0])
-
輸出數據tensor的第i維將根據perm[i]指定。比如,如果perm沒有給定,那么默認是perm = [n-1, n-2, ..., 0],其中rank(a) = n。
-
默認情況下,對於二維輸入數據,其實就是常規的矩陣轉置操作。
Example
input_data.dims = (1, 4, 3)
perm = [1, 2, 0]
因為 output_data.dims[0] = input_data.dims[ perm[0] ]
因為 output_data.dims[1] = input_data.dims[ perm[1] ]
因為 output_data.dims[2] = input_data.dims[ perm[2] ]
所以得到 output_data.dims = (4, 3, 1)
output_data.dims = (4, 3, 1)
代碼演示
import tensorflow as tf
sess = tf.Session()
input_data = tf.constant([[1, 2, 3], [4, 5, 6]])
print(sess.run(tf.transpose(input_data)))
# [[1 4]
# [2 5]
# [3 6]]
print(sess.run(input_data))
# [[1 2 3]
# [4 5 6]]
print(sess.run(tf.transpose(input_data, perm=[1, 0])))
# [[1 4]
# [2 5]
# [3 6]]
input_data = tf.constant([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]])
print('input_data shape: ', sess.run(tf.shape(input_data)))
# [1, 4, 3]
output_data = tf.transpose(input_data, perm=[1, 2, 0])
print('output_data shape: ', sess.run(tf.shape(output_data)))
# [4, 3, 1]
print(sess.run(output_data))
# [[[ 1]
# [ 2]
# [ 3]]
# [[ 4]
# [ 5]
# [ 6]]
#
# [[ 7]
# [ 8]
# [ 9]]
#
# [[10]
# [11]
# [12]]]
"""形式為:[[[],[],[]],[[],[],[]],[[],[],[]],[[],[],[]]]"""
"""輸入參數:
● a: 一個Tensor。
● perm: 一個對於a的維度的重排列組合。
● name:(可選)為這個操作取一個名字。
輸出參數:
● 一個經過翻轉的Tensor。"""
perm沒有指定的情況下transpose函數的結果
input_data = tf.constant([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]])
print('input_data shape: ', sess.run(tf.shape(input_data)))
# [1, 4, 3]
output_data = tf.transpose(input_data)
print('output_data shape: ', sess.run(tf.shape(output_data)))
# output_data shape: [3 4 1]
sess.close()




