Pytorch-tensor的維度變化


引言

本篇介紹tensor的維度變化。

維度變化改變的是數據的理解方式!

  • view/reshape:大小不變的條件下,轉變shape
  • squeeze/unsqueeze:減少/增加維度
  • transpose/t/permute:轉置,單次/多次交換
  • expand/repeat:維度擴展

view reshape

  • 在pytorch0.3的時候,默認是view .為了與numpy一致0.4以后增加了reshape。
  • 損失維度信息,如果不額外存儲/記憶的話,恢復時會出現問題。
  • 執行view/reshape是有一定的物理意義的,不然不會這樣做。
  • 保證tensor的size不變即可/numel()一致/元素個數不變。
  • 數據的存儲/維度順序非常非常非常重要
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
In[4]: a = torch.rand(4,1,28,28)
In[5]: a.shape
Out[5]: torch.Size([4, 1, 28, 28])
In[6]: a.view(4,28*28) # 4, 1*28*28 將后面的進行合並/合並通道,長寬,忽略了通道信息,上下左右的空間信息。適合全連接層。
Out[6]:
tensor([[0.1483, 0.6321, 0.8864, ..., 0.0646, 0.4791, 0.0892],
[0.5868, 0.5278, 0.8514, ..., 0.0682, 0.7815, 0.2724],
[0.4945, 0.4957, 0.0047, ..., 0.4253, 0.4135, 0.1234],
[0.0618, 0.4257, 0.1960, ..., 0.1377, 0.5776, 0.4071]])
In[7]: a.view(4,28*28).shape
Out[7]: torch.Size([4, 784])
In[8]: a.view(4*28, 28).shape # 合並batch,channel,行合並 放在一起為N [N,28] 每個N,剛好有28個像素點,只關心一行數據
Out[8]: torch.Size([112, 28])
In[9]: a.view(4*1,28,28).shape # 4張疊起來了
Out[9]: torch.Size([4, 28, 28])
In[10]: b = a.view(4,784) # a原來的維度信息是[b,c,h,w],但a這樣賦值后,它是恢復不到原來的
In[11]: b.view(4,28,28,1) # logic Bug # 語法上沒有問題,但邏輯上 [b h w c] 與以前是不對應的。
a.view(4,783)
RuntimeError: shape '[4, 783]' is invalid for input of size 3136

squeeze 與 unsqueeze

unsqueeze

  • unsqueeze(index) 拉伸(增加一個維度) (增加一個組別)
  • 參數的范圍是 [-a.dim()-1, a.dim()+1) 如下面例子中范圍是[-5,5)
  • -5 –> 0 … -1 –> 4 這樣的話,0表示在前面插入,-1表示在后面插入,正負會有些混亂,所以推薦用正數。
  • 0與正數,就是在xxx前面插入。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
In[17]: a.shape
Out[17]: torch.Size([4, 1, 28, 28])
In[18]: a.unsqueeze(0).shape # 在0的前面插入一個維度
Out[18]: torch.Size([1, 4, 1, 28, 28]) # 理解上就是在batch的基礎上增加了組。
In[19]: a.unsqueeze(-1).shape # 在-1之后插入一個維度
Out[19]: torch.Size([4, 1, 28, 28, 1]) # 理解上可能增加一個方差之類的
In[20]: a.unsqueeze(4).shape
Out[20]: torch.Size([4, 1, 28, 28, 1])
In[21]: a.unsqueeze(-4).shape
Out[21]: torch.Size([4, 1, 1, 28, 28])
In[22]: a.unsqueeze(-5).shape
Out[22]: torch.Size([1, 4, 1, 28, 28])
In[23]: a.unsqueeze(-6).shape
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got -6)

In[24]: a = torch.tensor([1.2,2.3])
In[27]: a.shape
Out[27]: torch.Size([2])
In[25]: a.unsqueeze(-1) # 維度變成 [2,1] 2行1列
Out[25]:
tensor([[1.2000],
[2.3000]])
In[26]: a.unsqueeze(0)
Out[26]: tensor([[1.2000, 2.3000]]) # 維度變成 [1,2] 1行2列

實際案例

給一個bias(偏置),bias相當於給每個channel上的所有像素增加一個偏置

為了做到 f+b 我們需要改變b的維度

1
2
3
4
5
In[28]: b = torch.rand(32)
In[29]: f = torch.rand(4,32,14,14)
In[30]: b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
In[31]: b.shape
Out[31]: torch.Size([1, 32, 1, 1])

后面進一步擴張到 [4,32,14,14]

queeze

  • squeeze(index) 當index對應的dim為1,就產生作用。
  • 不寫參數,會擠壓所有維度為1的。
1
2
3
4
5
6
7
8
9
10
11
12
In[38]: b.shape
Out[38]: torch.Size([1, 32, 1, 1])
In[39]: b.squeeze().shape # 默認將所有維度為1的進行擠壓 這32個channel,每個channel有一個值
Out[39]: torch.Size([32])
In[40]: b.squeeze(0).shape
Out[40]: torch.Size([32, 1, 1])
In[41]: b.squeeze(-1).shape
Out[41]: torch.Size([1, 32, 1])
In[42]: b.squeeze(1).shape
Out[42]: torch.Size([1, 32, 1, 1])
In[43]: b.squeeze(-4).shape
Out[43]: torch.Size([32, 1, 1])

expand / repeat

  • Expand:broadcasting (推薦)
    • 只是改變了理解方式,並沒有增加數據
    • 在需要的時候復制數據
  • Reapeat:memory copied
    • 會實實在在的增加數據

上面提到的b [1, 32, 1, 1] f[ 4, 32, 14, 14 ]

目標是將b的維度變成與f相同的維度。

expand

  • 擴展(expand)張量不會分配新的內存,只是在存在的張量上創建一個新的視圖(view)
1
2
3
4
5
6
7
8
9
In[44]: a = torch.rand(4,32,14,14)
In[45]: b.shape
Out[45]: torch.Size([1, 32, 1, 1]) # 只有1-->N才是可行的, 3 -> N 是需要規則的
In[46]: b.expand(4,32,14,14).shape
Out[46]: torch.Size([4, 32, 14, 14])
In[47]: b.expand(-1,32,-1,-1).shape # -1表示這個維度不變
Out[47]: torch.Size([1, 32, 1, 1])
In[48]: b.expand(-1,32,-1,-4).shape # -4這里是一個bug,沒有意義,最新版已經修復了
Out[48]: torch.Size([1, 32, 1, -4])

repeat

  • 主動復制原來的。
  • 參數表示的是要拷貝的次數/是原來維度的倍數
  • 沿着特定的維度重復這個張量,和expand()不同的是,這個函數拷貝張量的數據。
1
2
3
4
5
6
7
8
9
10
11
In[49]: b.shape
Out[49]: torch.Size([1, 32, 1, 1])
In[50]: b.repeat(4,32,1,1).shape
Out[50]: torch.Size([4, 1024, 1, 1])
In[51]: b.repeat(4,1,1,1).shape
Out[51]: torch.Size([4, 32, 1, 1])
In[52]: b.repeat(4,1,32,32)
In[53]: b.repeat(4,1,32,32).shape
Out[53]: torch.Size([4, 32, 32, 32])
In[55]: b.repeat(4,1,14,14).shape # 這樣就達到目標了
Out[55]: torch.Size([4, 32, 14, 14])

轉置

.t

轉置操作

  • .t 只針對 2維矩陣
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
a = torch.randn(3,4)
a.t().shape
Out[58]: torch.Size([4, 3])
In[60]: a
Out[60]:
tensor([[ 0.5629, -0.5085, -0.3371, 1.2387],
[ 0.2142, -1.7846, 0.2297, 1.7797],
[-0.3197, 0.6116, 0.3791, 0.9218]])
In[61]: a.t()
Out[61]:
tensor([[ 0.5629, 0.2142, -0.3197],
[-0.5085, -1.7846, 0.6116],
[-0.3371, 0.2297, 0.3791],
[ 1.2387, 1.7797, 0.9218]])
b.t()
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

transpose

  • 在結合view使用的時候,view會導致維度順序關系變模糊,所以需要人為跟蹤。
  • 錯誤的順序,會導致數據污染
  • 一次只能兩兩交換
  • contiguous
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 由於交換了1,3維度,就會變得不連續,所以需要用contiguous,來吧數據變得連續。
In[17]: a1 = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view()

In[8]: a = torch.randn(4,3,32,32)
In[9]: a.shape
Out[9]: torch.Size([4, 3, 32, 32])
In[10]: a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
#[b c h w] 交換1,3維度的數據 [b w h c],再把后面的三個連在一起,展開后變為 [b c w h] 導致和原來的順序不同,造成數據污染!!!
In[11]: a1.shape
Out[11]: torch.Size([4, 3, 32, 32])
In[12]: a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
# [b c h w] -> [b w h c] -> [b w h c] -> [b c h w] 和原來順序相同。
In[13]: a2.shape
Out[13]: torch.Size([4, 3, 32, 32])
# 驗證向量一致性
In[14]: torch.all(torch.eq(a,a1))
Out[14]: tensor(0, dtype=torch.uint8)
In[15]: torch.all(torch.eq(a,a2))
Out[15]: tensor(1, dtype=torch.uint8)

permute

  • 會打亂內存順序,待補充!!!
  • 由於transpose一次只能兩兩交換,所以變換后在變回去至少需要兩次操作,而permute一次就好。例如對於[b,h,w,c]
  • [b,h,w,c]是numpy存儲圖片的格式,需要這一步才能導出numpy
1
2
3
4
5
6
7
8
9
10
In[18]: a = torch.rand(4,3,28,28)
In[19]: a.transpose(1,3).shape # [b c h w] -> [b w h c] h與w的順序發生了變換,導致圖像發生了變化
Out[19]: torch.Size([4, 28, 28, 3])
In[20]: b = torch.rand(4,3,28,32)
In[21]: b.transpose(1,3).shape
Out[21]: torch.Size([4, 32, 28, 3])
In[22]: b.transpose(1,3).transpose(1,2).shape
Out[22]: torch.Size([4, 28, 32, 3]) # [b,h,w,c]是numpy存儲圖片的格式,需要這一步才能導出numpy
In[23]: b.permute(0,2,3,1).shape
Out[23]: torch.Size([4, 28, 32, 3])

Broadcast

自動擴展

  • 維度擴展,自動調用expand
  • without copying data ,不需要拷貝數據。

核心思想

  • 在前面插入1維
  • 將size 1 擴展成相同 size 的維度

例子:

  • 對於 feature maps : [4, 32, 14, 14],想給它添加一個偏置Bias
  • Bias:[32] –> [32, 1 , 1] (這里是手動的) => [1, 32, 1, 1] => [4, 32, 14, 14]
  • 目標:當Bias和feature maps的size一樣時,才能執行疊加操作!!!

Why broadcasting?

就像下圖表示的一樣:我們希望進行如下的幾種計算,但需要滿足數學上的約束(size相同),為了節省人們為滿足數學上的約束而手動復制的過程,而產生的Broadcast,它節省了大量的內容消耗。

Broadcast

  • 第二行數據中 [3] => [1, 3] => [4, 3] (行復制了4次)
  • 第三行數據中
    • [4,1] => [4, 3] (列復制了3次)
    • [1,3] => [4, 3] (行復制了4次)
  • broadcast = unsqueze(插入新維度) + expand(將1dim變成相同維度)

例子:

  • 有這樣的數據 [class, students, scores],具體是4個班,每個班32人,每人8門課程[4, 32, 8] 。
  • 考試不理想,對於這組數據我們需要為每一位同學的成績加5分
  • 要求: [4, 32, 8] + [4, 32, 8]
  • 實際上:[4, 32, 8] + [5.0]
  • 操作上:[1] =>(unsqueeze) [1, 1, 1] =>(expand_as) [4, 32, 8],這樣需要寫3個接口。
  • 所以才會有 broadcast!!

內存分析:

  • [4, 32, 8] => 1024
  • [5.0] => 1 如果是手動復制的話,內存消耗將變為原來的1024倍

使用條件?

A [ 大維度 —> 小維度 ]

從最后一位(最小維度)開始匹配,如果維度上的size是0,1或相同,則滿足條件,看下一個維度,直到都滿足條件為止。

  • 如果當前維度是1,擴張到相同維度
  • 如果沒有維度,插入一個維度並擴張到相同維度
  • 當最小維度不匹配的時候是沒法使用broadcastiong,如共有8門課程,但只給了4門課程的變化,這樣就會產生歧義。

note:小維度指定,大維度隨意

小維度指定:假如英語考難了,只加英語成績 [0 0 5 0 0 0 0 0]

案例

情況一

A[4, 32, 14, 14]

B[1, 32, 1, 1] => [4,,32, 14, 14]

情況二

A[4, 32, 14, 14]

B[14, 14] => [1, 1, 14, 14] => [4, 32, 14, 14]

情況三

不符合條件

A[4, 32, 14, 14]

B[2, 32, 14, 14]

理解這種行為

  • 小維度指定,大維度隨意。小維度設定規則(加5分),大維度默認按照這個規則(通用)。
  • 維度為1才滿足條件,是為了保證公平(統一的規則)

常見使用情景

  • A [4, 3, 32, 32] b,c,h,w
  • +[32, 32] 疊加一個相同的feature map,做一些平移變換。相當於一個base(基底),
  • +[3, 1, 1] 針對 RGB 進行不同的補充,如R 0.5 、G 0 、B 0.3
  • +[1, 1, 1, 1] 對於所有的都加一個數值,抬高一下,如加0.5.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM