有時我們會碰到升維或降維的需求,比如現在有一個圖像樣本,形狀是 [height, width, channels],我們需要把它輸入到已經訓練好的模型中做分類,而模型定義的輸入變量是一個batch,即形狀為 [batch_size, height, width, channels],這時就需要升維了。tensorflow提供了一個方便的升維函數:expand_dims,參數定義如下:
tf.expand_dims(input, axis=None, name=None, dim=None)
參數說明:
input:待升維的tensor
axis:插入新維度的索引位置
name:輸出tensor名稱
dim: 一般不用
import tensorflow as tf sess = tf.Session() t = tf.constant([1, 2, 3], dtype=tf.int32) t.get_shape() # TensorShape([Dimension(3)]) tf.expand_dims(t, 0).get_shape() # TensorShape([Dimension(1), Dimension(3)]) tf.expand_dims(t, 1).get_shape() # TensorShape([Dimension(3), Dimension(1)])
squeeze正好執行相反的操作:刪除大小是1的維度
tf.squeeze(input, squeeze_dims=None, name=None)
input: 待降維的張量
sequeeze_dims: list[int]類型,表示需要刪除的維度索引。默認為[],即刪除所以大小為1的維度
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] shape(squeeze(t)) ==> [2, 3] Or, to remove specific size 1 dimensions: # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
在處理tensor的時候合理使用這兩個函數,能極大的提高效率。例如處理輸入樣本、執行向量與矩陣的點乘等情況。
參考:https://blog.csdn.net/qq_31780525/article/details/72280284
