tensorflow expand_dims和squeeze


  有時我們會碰到升維或降維的需求,比如現在有一個圖像樣本,形狀是 [height, width, channels],我們需要把它輸入到已經訓練好的模型中做分類,而模型定義的輸入變量是一個batch,即形狀為 [batch_size, height, width, channels],這時就需要升維了。tensorflow提供了一個方便的升維函數:expand_dims,參數定義如下:

  tf.expand_dims(input, axis=None, name=None, dim=None)

  參數說明:

  input:待升維的tensor

  axis:插入新維度的索引位置

  name:輸出tensor名稱

  dim: 一般不用

  

import tensorflow as tf

sess = tf.Session()

t = tf.constant([1, 2, 3], dtype=tf.int32)

t.get_shape()
# TensorShape([Dimension(3)])

tf.expand_dims(t, 0).get_shape()
# TensorShape([Dimension(1), Dimension(3)])

tf.expand_dims(t, 1).get_shape()
# TensorShape([Dimension(3), Dimension(1)])

 

  squeeze正好執行相反的操作:刪除大小是1的維度

  tf.squeeze(input, squeeze_dims=None, name=None)

 

  input:  待降維的張量

  sequeeze_dims: list[int]類型,表示需要刪除的維度索引。默認為[],即刪除所以大小為1的維度

# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
shape(squeeze(t)) ==> [2, 3]
Or, to remove specific size 1 dimensions:
 
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]

 

  在處理tensor的時候合理使用這兩個函數,能極大的提高效率。例如處理輸入樣本、執行向量與矩陣的點乘等情況。

 

參考:https://blog.csdn.net/qq_31780525/article/details/72280284

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM