Pytorch——Tensor的儲存機制以及view()、reshape()、reszie_()三者的關系和區別


  本文主要介紹Pytorch中Tensor的儲存機制,在搞懂了Tensor在計算機中是如何存儲之后我們會進一步來探究tensor.view()、tensor.reshape()、tensor.reszie_(),她們都是改變了一個tensor的“形狀”,但是他們之間又有着些許的不同,這些不同常常會導致我們程序之中出現很多的BUG。

一、Tensor的儲存機制

  tensor在電腦的儲存,分為兩個部分(也就是說一個tensor占用了兩個內存位置),一個內存儲存了這個tensor的形狀size、步長stride、數據的索引等信息,我們把這一部分稱之為頭信息區(Tensor);另一個內存儲的就是真正的數據,我們稱為存儲區 (Storage)。換句話說,一旦定義了一個tensor,那這個tensor將會占據兩個內存位置,用於存儲。

  要注意,如果我們把一個tensorA進行切片,截取,修改之后通過"="賦值給B,那么這個時候tensorB其實是和tensorA是共享存儲區 (Storage),唯一不同的是頭信息區(Tensor)不同。下面我們直接看代碼來理解。其中tensor.storage().data_ptr()是用於獲取tensor儲存區的首元素內存地址的。

A = torch.arange(5)  # tensor([0, 1, 2, 3, 4])
B = A[2:]            # 對A進行截取獲得:tensor([2, 3, 4])
print(A)
print(B)
tensor([0, 1, 2, 3, 4]) tensor([2, 3, 4])
print(A.storage().data_ptr())
print(B.storage().data_ptr())
2076006947200
2076006947200

  我們可以很直觀的看到,A和B的儲存區的內存地址是一樣的,因此她們是共享數據的,下面這個例子更加直觀。

import torch
A = torch.arange(5)  # tensor([0, 1, 2, 3, 4])
B = A[2:]            # 對A進行截取獲得:tensor([2, 3, 4])

B[1] = 100     # 修改B的第2位置元素為100
print(A)
print(B)
tensor([  0,   1,   2, 100,   4]) tensor([ 2, 100,   4])

  因此我們可以得出結論,通過=直接賦值的操作其實就是“淺拷貝”(這里注意和list的切片區分,list使用A[2:],是可以得到新的一個list的)

二、tensor的stride()屬性、storage_offset()屬性

  為了更好的解釋tensor的reshape(),以及view()的操作,我們還需要了解下tensor的stride屬性。剛才上面我們提到了,tensor為了節約內存,很多操作其實都是在更改tensor的頭信息區(Tensor),因為頭信息區里面包含了如何組織數據,以及從哪里開始組織。其中stride()和storage_offset()屬性分別代表的就是步長以及初始偏移量。

storage_offset()屬性

  表示tensor的第一個元素與真實存儲區(storage)的第一個元素的偏移量。例如下面的例子:

import torch
A = torch.arange(5) 
B = A[2:]
C = A[1:]
print(A)
print(B)
print(C)
tensor([0, 1, 2, 3, 4]) tensor([2, 3, 4]) tensor([1, 2, 3, 4])
print(B.storage_offset())
print(C.storage_offset())
2
1

  我們可以看到tensorB和tensorC都是從A切片而來的,她們倆的存儲區 (Storage)是和A共享的,只不過B的第一個元素,與存儲區 (Storage)的首元素相差了2個位置(也就是儲存區的index=2開始),C的第一個元素與存儲區 (Storage)的首元素相差了1個位置。

stride()屬性

  這個屬性比較難理解,直接翻譯官方文檔就是:stride是在指定維度dim中從一個元素跳到下一個元素所必需的步長。直接上例子:

import torch
A = torch.rand(2, 3)  # 生成2*3的隨機數
print(A)
print(A.storage())    # 打印A的儲存區真實的數據
打印A: tensor([[0.8438, 0.2782, 0.9584], [0.2089, 0.0259, 0.3666]]) 0.8437800407409668
 0.2781521677970886
 0.9583932757377625
 0.2088671326637268
 0.025857746601104736
 0.366576611995697 [torch.FloatStorage of size 6]
print(A.stride())
(3, 1)

  主要是理解這個(3,1)指的是什么意思。這里的3指的是A[i][j]到A[i+1][j]這兩個數字在存儲區真實數據排列中是相差3的(例如A[0][0]=0.8438與A[1][0]=0.2089這兩個數字在儲存區中位次相差了3);這里的1是指A[i][j]與A[i][j+1]這兩個數字在儲存區的真實數據排列中相差1(例如A[0][0]=0.8438與A[0][1]=0.2781這兩個數字在儲存區中位次相差1)。如果還沒有理解,加下來我們試一下對於3維數據看看他們的stride()屬性。

import torch
A = torch.rand(2, 3, 4)  # 生成2*3*4的隨機數
print("打印A:",A)
print(A.storage())    # 打印A的儲存區真實的數據
打印A: tensor([[[0.4303, 0.7474, 0.8649, 0.5006], [0.2716, 0.9966, 0.7765, 0.6737], [0.5515, 0.2274, 0.9791, 0.1940]], [[0.6401, 0.7746, 0.5124, 0.0258], [0.8576, 0.9118, 0.9504, 0.4675], [0.9359, 0.0687, 0.2457, 0.3604]]]) 0.4302864074707031
 0.747403085231781
 0.8648527264595032
 0.500649631023407
 0.2716004252433777
 0.9965775609016418
 0.7765441536903381
 0.6737198233604431
 0.5515168905258179
 0.2273930311203003
 0.9791405200958252
 0.19399094581604004
 0.6401097774505615
 0.7746065855026245
 0.512383759021759
 0.02578103542327881
 0.8575518727302551
 0.911821186542511
 0.9503545165061951
 0.4674733877182007
 0.9358749389648438
 0.06866037845611572
 0.24573636054992676
 0.3603515625
print(A.stride())
(12, 4, 1)

  輸出有點長,大家對照着看,由於我們A的size是3維度的,因此我們A.stride()也是個三元組,那如果A是4維呢?(A.stride()一定就是4元組了)。這里的12表示就是A[i][j][k]與A[i+1][j][k]這兩個數字在真實儲存區的數據排布中相差12,大家可以對照的找幾個數字試試。同樣的道理這里的4表示A[i][j][k]與A[i][j+1][k]這兩個數字在真實儲存區的數據排布中相差4。最后1表示什么我就不說啦。

  好了終於說完這個很難的知識點了,接下來就進入正題,view()、reshape()、reszie_()三者的關系和區別。

三、view()、reshape()、reszie_()三者的關系和區別

  其中view()和reshape()是官方比較推薦使用的方式,而resize_()官方在文檔中說到不太推薦使用,具體原因一會說到。這三個方法都是可以完成對以一個tensor重新排列,沒錯是重新排列,其實她們本質上都沒有改變tensor的存儲區 (Storage)的真實數據的排列(除了一些特殊情況下會使得存儲區發生改變,這就是她們間的區別)。

view()

  從字面上來說就是"視圖"的意思,就是把存儲區 (Storage)的真實數據,根據某種排列方式”展示“給你看罷了,也就是僅僅改變了頭信息區(Tensor),真實數據的儲存地址是沒有改變的。直接上例子。

import torch
A = torch.arange(6) 
B = A.view(2,3)
print(A)
print(B)
tensor([0, 1, 2, 3, 4, 5]) tensor([[0, 1, 2], [3, 4, 5]])
print(A.storage().data_ptr())
print(B.storage().data_ptr())
1881582170752
1881582170752

  可以看到,A和B的真實數據的內存地址都是一樣的,下面我們進一步打印一下A,B兩個tensor真實數據的排列。

print(A.storage())
print(B.storage())
 0 1
 2
 3
 4
 5 [torch.LongStorage of size 6] 0 1
 2
 3
 4
 5 [torch.LongStorage of size 6]

  可以看到,是完全一樣的。更進一步打印一下A,B的stride()屬性

print(A.stride())
print(B.stride())
(1,) (3, 1)

  沒問題和前面說的是一樣的。

  總結一下,view()函數主要就是更改了tensor中的stride()屬性,這樣從而影響了tensor的顯示,但是從本質上來說A,B還是共用真實數據的存儲區 (Storag)的。

reshape()

  為了解釋view()和reshape()的區別,我們還需要知道一個知識:tensor的連續性。tensor又不是函數哪里來什么連續性?其實tensor的連續性說的就是stride()屬性和size()屬性(tensor維度)之間的關系。

  前一小結已經說了對於一個高維的tensor,stride()指的是:指定維度dim中從一個元素跳到下一個元素所必需的步長。一般來說我們最后一個維度步長應該是1(其實我們前面的例子我們應該也能發現,例子中所有tensor.stride()返回的元組最后一個元素都是1),對吧,因為是按順序排列的嘛。但是當一個tensor涉及到轉置(tensor.t(),tensor.transpose(),tensor.permute())這些操作都會使得tensor失去連續性這個性質。我們直接來看看例子吧。

a = torch.arange(6).view(2, 3)
b = a.t()
c = a.transpose(1,0)
d = a.permute(1,0)

print('b是:',b)
print('c是:',c)
print('d是:',d)
b是: tensor([[0, 3], [1, 4], [2, 5]]) c是: tensor([[0, 3], [1, 4], [2, 5]]) d是: tensor([[0, 3], [1, 4], [2, 5]])
print(a.stride())
print(b.stride())
print(c.stride())
print(d.stride())
(3, 1) (1, 3) (1, 3) (1, 3)

  這里我就不驗證她們是不是同一個存儲區 (Storage)了,大家下來可以驗證下(其實就是同一個)。我們可以看到b,c,d三個tensor的stride()屬性和a是不一樣的,根據stride()的定義大家應該是很容易知道b,c,d返回的stride()是什么意思吧。那為什么說b,c,d的tensor就不連續了呢?是因為她們不滿足張量的連續性條件了。連續性條件如下:

 

   這是什么意思呢?拿b舉例就是,b的stride=(3,1),b的size=(3,2),那么stride[0] != stride[1] * size[1]的,因此b是不滿足連續性條件的。如果從直觀上來感覺來"連續"的意思就是,“我”旁邊的數字就應該是“我”真實儲存區旁邊的數據,例如b[0][0]=0,但是b[0][1]=3,0和3這兩個數字在真實的存儲區 (Storage)不是挨着的啊,所以叫做不連續。

  那不滿足連續性有什么后果呢?后果就是不滿足連續性的tensor是無法使用view()方法的。換句話說,上面例子中的b,c,d都無法再使用view()方法了。

e = b.view(1,6)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

  c,d大家自己下來試一試。所以對於一個tensor是不是連續就意味着他能不能使用view()方法。

  那有什么辦法讓b使用view()方法呢?那就是把b連續化(使用tensor.contiguous()方法)。上例子。

a = torch.arange(6).view(2, 3)
b = a.t()
c = b.contiguous()
print(a.storage())
print(c.storage())
 0 1
 2
 3
 4
 5 [torch.LongStorage of size 6] 0 3
 1
 4
 2
 5 [torch.LongStorage of size 6]
print(a.storage().data_ptr())
print(c.storage().data_ptr())
1881582182144
1881582172928

  其實tensor.contiguous()方法是創造了一個新的tensor(全新的,連存儲區都不共用的tensor),這里的c就是從b得到的連續的tensor了,大家可以打印下c.stride(),會得到(2,1),這樣再根據c的size就能發現,c是滿足上面提到的連續性公式的。

  了解以上知識之后,reshape()和view()的差別就來了,view()是沒法對非連續性的tensor使用的(會報錯),但是reshape()是可以對非連續性tensor使用的。換句話說

  • 當tensor滿足連續性要求時,reshape() = view(),和原來tensor共用內存
  • 當tensor不滿足連續性要求時,reshape() = contiguous() + view(),會產生新的存儲區的tensor,與原來tensor不共用內存

  這就是view()和reshape()的差別了。

reszie_()

  那這一個又和前面那倆有啥關系的呢?從官方文檔上來說,它是不希望我們使用這個resize_()的,如圖。

 

   前面說到的reshape和view都必須要用到全部的原始數據,比如你的原始數據只有12個,無論你怎么變形都必須要用到12個數字們,不能多不能少。因此你就不能把只有12個數字的tensor強行reshape成2*5的維度的tensor。但是resize_()可以做到,無論你存儲區原始有多少個數字,我都能變成你想要的維度,數字不夠怎么辦?隨機產生湊!數字多了怎么辦?就取我需要的部分!上例子。

  多說一句a.resize_()是會改變a的喲,換句話說,a.resize_(2,3)之后,a就不再是1*7的維度了,而是2*3的維度了。但是a的儲存區還是原來的儲存區

a = torch.arange(7)
print("變換前a的儲存區地址:",a.storage().data_ptr())
b = a.resize_(2,3)
print('這是新的a:',a)
變換前a的儲存區地址: 1881579251648 這是新的a: tensor([[0, 1, 2], [3, 4, 5]])
print(a.storage())
print(b.storage())
 0 1
 2
 3
 4
 5
 6 [torch.LongStorage of size 7] 0 1
 2
 3
 4
 5
 6 [torch.LongStorage of size 7]
print('變換后a的儲存區地址',a.storage().data_ptr())
print(b.storage().data_ptr())
變換后a的儲存區地址 1881579251648
1881579251648

  你會發現盡管a的”長相“(數字個數也從7個變成了6個)被改變了,但是存儲區依舊是沒變的(要注意到真實存儲區的個數也沒變喲還是7個),因此我們可以說resize_()再進行變換的時候如果數字多余了,會截取我們需要的數據量,多余的數據量並沒有被舍棄。

  再來看看,當我reszie_多於原來的數據的時候發生什么。

a = torch.arange(7)
print("變換前a的儲存區地址:",a.storage().data_ptr())
b = a.resize_(3,4)
print(a.storage())
print(b.storage())
變換前a的儲存區地址: 1881579250944
0 1 2 3 4 5 6 7667809 6815836 [torch.LongStorage of size 9]
0
1 2 3 4 5 6 7667809 6815836 [torch.LongStorage of size 9]
print('變換后a的儲存區地址',a.storage().data_ptr())
print(b.storage().data_ptr())
變換后a的儲存區地址 1881582026048
1881582026048

  這個時候resize_()前后a的儲存區地址是發生了變化的喲。

  下一個問題:resize_()可不可以對不連續的tensor使用呢?

  答案是可以,並且並不會改變原來tensor的內存。當tensor是不連續的時候,采用reshape()會生成個新的存儲區的,采用resize_()則不會改變存儲區。那這兩者有啥區別呢?其實很好解釋,reshape是尊重tensor,把存儲區改了來將就tensor的reshape的長相,並使得連續。而resize_是:不改存儲區,但是“用戶”又想要看到想看到的長相,行,那我就把存儲區的數按照你想看到的長相排列吧。直接上例子。

import torch
a = torch.arange(6).view(2, 3)
b = a.t()  
#b是這個樣子的:tensor([[0, 3],
#                     [1, 4],
#                     [2, 5]])
c = b.reshape(1,6)
e = b.resize_(1,6)
print("c的存儲區:",c.storage().data_ptr())
print('e的存儲區:',e.storage().data_ptr())
c的存儲區: 2237602017664
e的存儲區: 2237602025472
print("c的存儲區真實數據排布:",c.storage())
print("e的存儲區真實數據排布:",e.storage())
c的存儲區真實數據排布:  
 0
 3
 1
 4
 2
 5
[torch.LongStorage of size 6]
e的存儲區真實數據排布:  
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]
print('我是c:',c)
print('我是e:',e)
我是c: tensor([[0, 3, 1, 4, 2, 5]])
我是e: tensor([[0, 1, 2, 3, 4, 5]])

  可以很直觀的看出來,如果tensor是不連續的時候,reshape和resize_的差別了吧。

四、總結

  最后總結一下view()、reshape()、reszie_()三者的關系和區別。

  • view()只能對滿足連續性要求的tensor使用。
  • 當tensor滿足連續性要求時,reshape() = view(),和原來tensor共用內存。
  • 當tensor不滿足連續性要求時,reshape() = contiguous() + view(),會產生新的存儲區的tensor,與原來tensor不共用內存。
  • resize_()可以隨意的獲取任意維度的tensor,不用在意真實數據的個數限制,但是不推薦使用。

  

 

 

參考博客:

PyTorch:view() 與 reshape() 區別詳解_Flag_ing的博客-CSDN博客_reshape和view

pytorch筆記(一)——tensor的storage()、stride()、storage_offset()_Zoran的博客-CSDN博客_pytorch stride()


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM