本文主要介紹Pytorch中Tensor的儲存機制,在搞懂了Tensor在計算機中是如何存儲之后我們會進一步來探究tensor.view()、tensor.reshape()、tensor.reszie_(),她們都是改變了一個tensor的“形狀”,但是他們之間又有着些許的不同,這些不同常常會導致我們程序之中出現很多的BUG。
一、Tensor的儲存機制
tensor在電腦的儲存,分為兩個部分(也就是說一個tensor占用了兩個內存位置),一個內存儲存了這個tensor的形狀size、步長stride、數據的索引等信息,我們把這一部分稱之為頭信息區(Tensor);另一個內存儲的就是真正的數據,我們稱為存儲區 (Storage)。換句話說,一旦定義了一個tensor,那這個tensor將會占據兩個內存位置,用於存儲。
要注意,如果我們把一個tensorA進行切片,截取,修改之后通過"="賦值給B,那么這個時候tensorB其實是和tensorA是共享存儲區 (Storage),唯一不同的是頭信息區(Tensor)不同。下面我們直接看代碼來理解。其中tensor.storage().data_ptr()是用於獲取tensor儲存區的首元素內存地址的。
A = torch.arange(5) # tensor([0, 1, 2, 3, 4])
B = A[2:] # 對A進行截取獲得:tensor([2, 3, 4])
print(A)
print(B)
tensor([0, 1, 2, 3, 4]) tensor([2, 3, 4])
print(A.storage().data_ptr())
print(B.storage().data_ptr())
2076006947200
2076006947200
我們可以很直觀的看到,A和B的儲存區的內存地址是一樣的,因此她們是共享數據的,下面這個例子更加直觀。
import torch
A = torch.arange(5) # tensor([0, 1, 2, 3, 4])
B = A[2:] # 對A進行截取獲得:tensor([2, 3, 4])
B[1] = 100 # 修改B的第2位置元素為100
print(A)
print(B)
tensor([ 0, 1, 2, 100, 4]) tensor([ 2, 100, 4])
因此我們可以得出結論,通過=直接賦值的操作其實就是“淺拷貝”(這里注意和list的切片區分,list使用A[2:],是可以得到新的一個list的)
二、tensor的stride()屬性、storage_offset()屬性
為了更好的解釋tensor的reshape(),以及view()的操作,我們還需要了解下tensor的stride屬性。剛才上面我們提到了,tensor為了節約內存,很多操作其實都是在更改tensor的頭信息區(Tensor),因為頭信息區里面包含了如何組織數據,以及從哪里開始組織。其中stride()和storage_offset()屬性分別代表的就是步長以及初始偏移量。
storage_offset()屬性
表示tensor的第一個元素與真實存儲區(storage)的第一個元素的偏移量。例如下面的例子:
import torch
A = torch.arange(5)
B = A[2:]
C = A[1:]
print(A)
print(B)
print(C)
tensor([0, 1, 2, 3, 4]) tensor([2, 3, 4]) tensor([1, 2, 3, 4])
print(B.storage_offset())
print(C.storage_offset())
2
1
我們可以看到tensorB和tensorC都是從A切片而來的,她們倆的存儲區 (Storage)是和A共享的,只不過B的第一個元素,與存儲區 (Storage)的首元素相差了2個位置(也就是儲存區的index=2開始),C的第一個元素與存儲區 (Storage)的首元素相差了1個位置。
stride()屬性
這個屬性比較難理解,直接翻譯官方文檔就是:stride是在指定維度dim中從一個元素跳到下一個元素所必需的步長。直接上例子:
import torch
A = torch.rand(2, 3) # 生成2*3的隨機數
print(A)
print(A.storage()) # 打印A的儲存區真實的數據
打印A: tensor([[0.8438, 0.2782, 0.9584], [0.2089, 0.0259, 0.3666]]) 0.8437800407409668
0.2781521677970886
0.9583932757377625
0.2088671326637268
0.025857746601104736
0.366576611995697 [torch.FloatStorage of size 6]
print(A.stride())
(3, 1)
主要是理解這個(3,1)指的是什么意思。這里的3指的是A[i][j]到A[i+1][j]這兩個數字在存儲區真實數據排列中是相差3的(例如A[0][0]=0.8438與A[1][0]=0.2089這兩個數字在儲存區中位次相差了3);這里的1是指A[i][j]與A[i][j+1]這兩個數字在儲存區的真實數據排列中相差1(例如A[0][0]=0.8438與A[0][1]=0.2781這兩個數字在儲存區中位次相差1)。如果還沒有理解,加下來我們試一下對於3維數據看看他們的stride()屬性。
import torch
A = torch.rand(2, 3, 4) # 生成2*3*4的隨機數
print("打印A:",A)
print(A.storage()) # 打印A的儲存區真實的數據
打印A: tensor([[[0.4303, 0.7474, 0.8649, 0.5006], [0.2716, 0.9966, 0.7765, 0.6737], [0.5515, 0.2274, 0.9791, 0.1940]], [[0.6401, 0.7746, 0.5124, 0.0258], [0.8576, 0.9118, 0.9504, 0.4675], [0.9359, 0.0687, 0.2457, 0.3604]]]) 0.4302864074707031
0.747403085231781
0.8648527264595032
0.500649631023407
0.2716004252433777
0.9965775609016418
0.7765441536903381
0.6737198233604431
0.5515168905258179
0.2273930311203003
0.9791405200958252
0.19399094581604004
0.6401097774505615
0.7746065855026245
0.512383759021759
0.02578103542327881
0.8575518727302551
0.911821186542511
0.9503545165061951
0.4674733877182007
0.9358749389648438
0.06866037845611572
0.24573636054992676
0.3603515625
print(A.stride())
(12, 4, 1)
輸出有點長,大家對照着看,由於我們A的size是3維度的,因此我們A.stride()也是個三元組,那如果A是4維呢?(A.stride()一定就是4元組了)。這里的12表示就是A[i][j][k]與A[i+1][j][k]這兩個數字在真實儲存區的數據排布中相差12,大家可以對照的找幾個數字試試。同樣的道理這里的4表示A[i][j][k]與A[i][j+1][k]這兩個數字在真實儲存區的數據排布中相差4。最后1表示什么我就不說啦。
好了終於說完這個很難的知識點了,接下來就進入正題,view()、reshape()、reszie_()三者的關系和區別。
三、view()、reshape()、reszie_()三者的關系和區別
其中view()和reshape()是官方比較推薦使用的方式,而resize_()官方在文檔中說到不太推薦使用,具體原因一會說到。這三個方法都是可以完成對以一個tensor重新排列,沒錯是重新排列,其實她們本質上都沒有改變tensor的存儲區 (Storage)的真實數據的排列(除了一些特殊情況下會使得存儲區發生改變,這就是她們間的區別)。
view()
從字面上來說就是"視圖"的意思,就是把存儲區 (Storage)的真實數據,根據某種排列方式”展示“給你看罷了,也就是僅僅改變了頭信息區(Tensor),真實數據的儲存地址是沒有改變的。直接上例子。
import torch
A = torch.arange(6)
B = A.view(2,3)
print(A)
print(B)
tensor([0, 1, 2, 3, 4, 5]) tensor([[0, 1, 2], [3, 4, 5]])
print(A.storage().data_ptr())
print(B.storage().data_ptr())
1881582170752
1881582170752
可以看到,A和B的真實數據的內存地址都是一樣的,下面我們進一步打印一下A,B兩個tensor真實數據的排列。
print(A.storage())
print(B.storage())
0 1
2
3
4
5 [torch.LongStorage of size 6] 0 1
2
3
4
5 [torch.LongStorage of size 6]
可以看到,是完全一樣的。更進一步打印一下A,B的stride()屬性
print(A.stride())
print(B.stride())
(1,) (3, 1)
沒問題和前面說的是一樣的。
總結一下,view()函數主要就是更改了tensor中的stride()屬性,這樣從而影響了tensor的顯示,但是從本質上來說A,B還是共用真實數據的存儲區 (Storag)的。
reshape()
為了解釋view()和reshape()的區別,我們還需要知道一個知識:tensor的連續性。tensor又不是函數哪里來什么連續性?其實tensor的連續性說的就是stride()屬性和size()屬性(tensor維度)之間的關系。
前一小結已經說了對於一個高維的tensor,stride()指的是:指定維度dim中從一個元素跳到下一個元素所必需的步長。一般來說我們最后一個維度步長應該是1(其實我們前面的例子我們應該也能發現,例子中所有tensor.stride()返回的元組最后一個元素都是1),對吧,因為是按順序排列的嘛。但是當一個tensor涉及到轉置(tensor.t(),tensor.transpose(),tensor.permute())這些操作都會使得tensor失去連續性這個性質。我們直接來看看例子吧。
a = torch.arange(6).view(2, 3)
b = a.t()
c = a.transpose(1,0)
d = a.permute(1,0)
print('b是:',b)
print('c是:',c)
print('d是:',d)
b是: tensor([[0, 3], [1, 4], [2, 5]]) c是: tensor([[0, 3], [1, 4], [2, 5]]) d是: tensor([[0, 3], [1, 4], [2, 5]])
print(a.stride())
print(b.stride())
print(c.stride())
print(d.stride())
(3, 1) (1, 3) (1, 3) (1, 3)
這里我就不驗證她們是不是同一個存儲區 (Storage)了,大家下來可以驗證下(其實就是同一個)。我們可以看到b,c,d三個tensor的stride()屬性和a是不一樣的,根據stride()的定義大家應該是很容易知道b,c,d返回的stride()是什么意思吧。那為什么說b,c,d的tensor就不連續了呢?是因為她們不滿足張量的連續性條件了。連續性條件如下:
這是什么意思呢?拿b舉例就是,b的stride=(3,1),b的size=(3,2),那么stride[0] != stride[1] * size[1]的,因此b是不滿足連續性條件的。如果從直觀上來感覺來"連續"的意思就是,“我”旁邊的數字就應該是“我”真實儲存區旁邊的數據,例如b[0][0]=0,但是b[0][1]=3,0和3這兩個數字在真實的存儲區 (Storage)不是挨着的啊,所以叫做不連續。
那不滿足連續性有什么后果呢?后果就是不滿足連續性的tensor是無法使用view()方法的。換句話說,上面例子中的b,c,d都無法再使用view()方法了。
e = b.view(1,6)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
c,d大家自己下來試一試。所以對於一個tensor是不是連續就意味着他能不能使用view()方法。
那有什么辦法讓b使用view()方法呢?那就是把b連續化(使用tensor.contiguous()方法)。上例子。
a = torch.arange(6).view(2, 3)
b = a.t()
c = b.contiguous()
print(a.storage())
print(c.storage())
0 1
2
3
4
5 [torch.LongStorage of size 6] 0 3
1
4
2
5 [torch.LongStorage of size 6]
print(a.storage().data_ptr())
print(c.storage().data_ptr())
1881582182144
1881582172928
其實tensor.contiguous()方法是創造了一個新的tensor(全新的,連存儲區都不共用的tensor),這里的c就是從b得到的連續的tensor了,大家可以打印下c.stride(),會得到(2,1),這樣再根據c的size就能發現,c是滿足上面提到的連續性公式的。
了解以上知識之后,reshape()和view()的差別就來了,view()是沒法對非連續性的tensor使用的(會報錯),但是reshape()是可以對非連續性tensor使用的。換句話說
- 當tensor滿足連續性要求時,reshape() = view(),和原來tensor共用內存
- 當tensor不滿足連續性要求時,reshape() = contiguous() + view(),會產生新的存儲區的tensor,與原來tensor不共用內存
這就是view()和reshape()的差別了。
reszie_()
那這一個又和前面那倆有啥關系的呢?從官方文檔上來說,它是不希望我們使用這個resize_()的,如圖。
前面說到的reshape和view都必須要用到全部的原始數據,比如你的原始數據只有12個,無論你怎么變形都必須要用到12個數字們,不能多不能少。因此你就不能把只有12個數字的tensor強行reshape成2*5的維度的tensor。但是resize_()可以做到,無論你存儲區原始有多少個數字,我都能變成你想要的維度,數字不夠怎么辦?隨機產生湊!數字多了怎么辦?就取我需要的部分!上例子。
多說一句a.resize_()是會改變a的喲,換句話說,a.resize_(2,3)之后,a就不再是1*7的維度了,而是2*3的維度了。但是a的儲存區還是原來的儲存區
a = torch.arange(7)
print("變換前a的儲存區地址:",a.storage().data_ptr())
b = a.resize_(2,3)
print('這是新的a:',a)
變換前a的儲存區地址: 1881579251648 這是新的a: tensor([[0, 1, 2], [3, 4, 5]])
print(a.storage())
print(b.storage())
0 1
2
3
4
5
6 [torch.LongStorage of size 7] 0 1
2
3
4
5
6 [torch.LongStorage of size 7]
print('變換后a的儲存區地址',a.storage().data_ptr())
print(b.storage().data_ptr())
變換后a的儲存區地址 1881579251648
1881579251648
你會發現盡管a的”長相“(數字個數也從7個變成了6個)被改變了,但是存儲區依舊是沒變的(要注意到真實存儲區的個數也沒變喲還是7個),因此我們可以說resize_()再進行變換的時候如果數字多余了,會截取我們需要的數據量,多余的數據量並沒有被舍棄。
再來看看,當我reszie_多於原來的數據的時候發生什么。
a = torch.arange(7)
print("變換前a的儲存區地址:",a.storage().data_ptr())
b = a.resize_(3,4)
print(a.storage())
print(b.storage())
變換前a的儲存區地址: 1881579250944
0 1
2
3
4
5
6
7667809
6815836 [torch.LongStorage of size 9]
0 1
2
3
4
5
6
7667809
6815836 [torch.LongStorage of size 9]
print('變換后a的儲存區地址',a.storage().data_ptr())
print(b.storage().data_ptr())
變換后a的儲存區地址 1881582026048
1881582026048
這個時候resize_()前后a的儲存區地址是發生了變化的喲。
下一個問題:resize_()可不可以對不連續的tensor使用呢?
答案是可以,並且並不會改變原來tensor的內存。當tensor是不連續的時候,采用reshape()會生成個新的存儲區的,采用resize_()則不會改變存儲區。那這兩者有啥區別呢?其實很好解釋,reshape是尊重tensor,把存儲區改了來將就tensor的reshape的長相,並使得連續。而resize_是:不改存儲區,但是“用戶”又想要看到想看到的長相,行,那我就把存儲區的數按照你想看到的長相排列吧。直接上例子。
import torch a = torch.arange(6).view(2, 3) b = a.t() #b是這個樣子的:tensor([[0, 3], # [1, 4], # [2, 5]]) c = b.reshape(1,6) e = b.resize_(1,6) print("c的存儲區:",c.storage().data_ptr()) print('e的存儲區:',e.storage().data_ptr())
c的存儲區: 2237602017664 e的存儲區: 2237602025472
print("c的存儲區真實數據排布:",c.storage()) print("e的存儲區真實數據排布:",e.storage())
c的存儲區真實數據排布: 0 3 1 4 2 5 [torch.LongStorage of size 6] e的存儲區真實數據排布: 0 1 2 3 4 5 [torch.LongStorage of size 6]
print('我是c:',c) print('我是e:',e)
我是c: tensor([[0, 3, 1, 4, 2, 5]]) 我是e: tensor([[0, 1, 2, 3, 4, 5]])
可以很直觀的看出來,如果tensor是不連續的時候,reshape和resize_的差別了吧。
四、總結
最后總結一下view()、reshape()、reszie_()三者的關系和區別。
- view()只能對滿足連續性要求的tensor使用。
- 當tensor滿足連續性要求時,reshape() = view(),和原來tensor共用內存。
- 當tensor不滿足連續性要求時,reshape() = contiguous() + view(),會產生新的存儲區的tensor,與原來tensor不共用內存。
- resize_()可以隨意的獲取任意維度的tensor,不用在意真實數據的個數限制,但是不推薦使用。
參考博客:
PyTorch:view() 與 reshape() 區別詳解_Flag_ing的博客-CSDN博客_reshape和view
pytorch筆記(一)——tensor的storage()、stride()、storage_offset()_Zoran的博客-CSDN博客_pytorch stride()