tensor維度變換


維度變換是tensorflow中的重要模塊之一,前面mnist實戰模塊我們使用了圖片數據的壓平操作,它就是維度變換的應用之一。

在詳解維度變換的方法之前,這里先介紹一下View(視圖)的概念。所謂View,簡單的可以理解成我們對一個tensor不同維度關系的認識。舉個例子,一個[ b,28,28,1 ]的tensor(可以理解為mnist數據集的一組圖片),對於這樣一組圖片,我們可以有一下幾種理解方式:

(1)按照物理設備儲存結構,即一整行的方式(28*28)儲存,這一行有連續的784個數據,這種理解方式可以用[ b,28*28 ]表示

(2)按照圖片原有結構儲存,即保留圖片的行列關系,以28行28列的數據理解,這種方式可以用[ b,28,28 ]表示

(3)將圖片分塊(比如上下兩部分),這種理解方式與第二種類似,只是將一張圖變為兩張,這種方式可以用[ b,2,14*28 ]表示

(4)增加channel通道,這種理解方式也與第二種類似,只是這種對rgb三色圖區別更明顯,可以用[ b,28 28,1 ]表示

通過維度的等價變換,就可以實現思維上View的轉換

維度變換的方式:

方式1:tf.reshape(可通過破壞維度之間的關系改變tensor的維度,但不會改變原有數據的存儲順序)

a = tf.random.normal([4,28,28,3])
print(a.shape)
print(tf.reshape(a,[4,784,3]).shape)
print(tf.reshape(a,[4,-1,3]).shape)
print(tf.reshape(a,[4,784*3]).shape)
print(tf.reshape(a,[4,-1]).shape)

 

但是reshape在恢復已經reshape的數據時會出現問題,比如[ 4,28,28,3 ]的數據reshape成[ 4,784,3 ]的數據要想再恢復成以前的樣子,就需要記錄下以前的content(內容)信息,如果記錄過程出現錯誤(如width和height維度記反或者數值記錯),就會導致恢復不成想要的樣子。

方式2:tf.transpose  (content的變換)

a = tf.random.normal([4,3,2,1])
print(a.shape)
print(tf.transpose(a).shape)
print(tf.transpose(a,perm=[0,1,3,2]).shape)

 

通過這種變換方式會徹底改變原來圖片數據的維度關系,在經過transpose之后,再用reshape變換得到的數據是基於新的content(transpose之后)進行的變換,所以reshape時要記錄新的content信息,不然會導致數據混亂甚至程序異常。

方式3:tf.expand_dims、tf.squeeze (增加和減少維度)

a = tf.random.normal([4,35,8])
# tf.expand_dims增加維度
# 若給定axis>0,則在給定軸前增加維度,若給定axis<0,則在給定軸后增加維度
print(tf.expand_dims(a,axis=0).shape)
print(tf.expand_dims(a,axis=3).shape)
print(tf.expand_dims(a,axis=-1).shape)
print(tf.expand_dims(a,axis=-4).shape)

# tf.squeeze用於減少維度
print(tf.squeeze(tf.zeros([1,2,1,1,3])).shape)
a = tf.zeros([1,2,1,3])
print(tf.squeeze(a,axis=0).shape)
print(tf.squeeze(a,axis=2).shape)
print(tf.squeeze(a,axis=-2).shape)
print(tf.squeeze(a,axis=-4).shape)

 

需要注意的是,squeeze只能減少維度值為1的維度,且axis必須為已存在的軸索引

當前主流的神經網絡之一SE-NET就通過巧妙的使用expand和squeeze模塊,使得模型准確率更上一個台階

SE-net的github源碼地址:https://github.com/hujie-frank/SENet


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM