pytorch torch.Storage學習


 

tensor分為頭信息區(Tensor)和存儲區(Storage)

信息區主要保存着tensor的形狀(size)、步長(stride)、數據類型(type)等信息,而真正的數據則保存成連續數組,存儲在存儲區

因為數據動輒成千上萬,因此信息區元素占用內存較少,主要內存占用取決於tensor中元素的數目,即存儲區的大小

 

一般來說,一個tensor有着與之相對應的storage,storage是在data之上封裝的接口,便於使用

不同的tensor的頭信息一般不同,但是可能使用相同的storage

生成a:

a = t.arange(0,6)
a.storage()

⚠️將這里改成a = t.arange(0,6).float(),用來保證得到的值的類型為FloatTensor

這跟下面遇見的一個問題相關,可以看到下面了解一下,然后再跟着操作

所以你的下面內容的值的類型應該為FloatTensor類型,我的仍是LongTensor,因為我沒有改過來

返回:

0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]

生成b:

b = a.view(2,3)
b.storage()

返回:

 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]

對比兩者內存地址:

#一個對象的id值可以看作她的內存空間
#a,b storage的內存地址一樣,即是同一個storage
id(a.storage()) == id(b.storage())

返回:

True

改變某個值查看是否共享內存:

#a改變,b也隨之改變,因為他們共享storage,即內存
a[1] = 100 b

返回:

tensor([[  0, 100,   2],
        [  3,   4,   5]])

生成c:

#c從a的后兩個元素取起
c = a[2:] c.storage()#指向相同

返回:

 0
 100
 2
 3
 4
 5
[torch.LongStorage of size 6]

查看其首元素內存地址:

c.data_ptr(), a.data_ptr() #data_ptr返回tensor首元素的內存地址
#從結果可以看出兩者的地址相差16
#因為c是從a第二個元素選起的,每個元素占8個字節,因為a的值的類型是int64

返回:

(140707162378192, 140707162378176)

因為查看后a的類型為int64:

a.dtype

返回:

torch.int64

更改c:

c[0] = -100 #a,c也共享內存空間,c[0]的內存地址對應的是a[2]的內存地址
a

返回:

tensor([   0,  100, -100,    3,    4,    5])

 

使用storage來生成新tensor:

d = t.Tensor(c.storage())#這樣a,b,c,d共享同樣的內存空間
d[0] = 6666 b

⚠️報錯:

RuntimeError: Expected object of data type 6 but got data type 4 for argument #2 'source'

這是因為Tensor期待得到的值的類型是FloatTensor(類型6),而不是其他類型LongTensor(data type 4)

因為如果生成:

dtypea = t.FloatTensor([[1, 2, 3], [4, 5, 6]])
dtypea.storage()

返回:

 1.0
 2.0
 3.0
 4.0
 5.0
 6.0
[torch.FloatStorage of size 6]

再運行就成功了:

d = t.Tensor(dtypea.storage())#這樣a,b,c,d共享同樣的內存空間
d[0] = 6666 dtypea

返回:

tensor([[6.6660e+03, 2.0000e+00, 3.0000e+00],
        [4.0000e+00, 5.0000e+00, 6.0000e+00]])

如果使用的是IntTensor(data type 3),也會報錯:

RuntimeError: Expected object of data type 6 but got data type 3 for argument #2 'source'

ShortTensor(data type 2),CharTensor(data type 1),ByteTensor(data type 0),DoubleTensor(data type 7)

 

下面的操作會在將上面的值改成FloatTensor的基礎上進行,即在a = t.arange(0,6)后面添加.float(),然后從頭執行了一遍

d = t.Tensor(c.storage())#這樣a,b,c,d共享同樣的內存空間
d[0] = 6666 b

返回:

tensor([[ 6.6660e+03,  1.0000e+02, -1.0000e+02],
        [ 3.0000e+00,  4.0000e+00,  5.0000e+00]])

判斷是否共享內存:

#因此a,b,c,d這4個tensor共享storage
id(a.storage()) ==id(b.storage()) ==id(c.storage()) ==id(d.storage())#返回True

偏移量:

#獲取首元素相對於storage地址的偏移量
a.storage_offset(), c.storage_offset(), d.storage_offset()

返回:

(0, 2, 0)

即使使用索引只獲得一部分值,指向仍是storage:

#隔兩行/列取元素來生成e
e = b[::2,::2] print(e) print(e.storage_offset()) id(e.storage()) ==id(a.storage()) #雖然值不相同,但是得到的storage是相同的

返回:

tensor([[6666., -100.]])
0 Out[44]: True

步長信息:是有層次結構的步長

#獲得步長信息
b.stride(), e.stride()

返回:

((3, 1), (6, 2))

查看空間是否連續:

#查看其值的內存空間是否連續
#因為e只取得了storage中的部分值,因此其是不連續的
b.is_contiguous(), e.is_contiguous()

返回:

(True, False)

從上面的操作中我們可以看出絕大多數的操作並不修改tensor的數據,即存儲區的內容,只是修改了頭信息區的內容

這種做法更節省內存,同時提升了處理速度

但是我們可以看見e的操作導致其不連續,這時候可以調用tensor.contiguous()方法將他們變成連續的數據。該方法是復制數據到新的內存中,不再與原來的數據共享storage,如:

e.contiguous().is_contiguous() #返回True

生成f:

print(e.data_ptr())
f = e.contiguous() print(f.data_ptr()) #可見為f新分配了內存空間 print(f) print(f.storage())#內存空間中只有兩個值 print(f.size()) print(e.data_ptr()) #e指向的內存沒有改變 f.is_contiguous() #這里的f的內存空間是連續的

返回:

140707203003760
140707160267104
tensor([[6666., -100.]]) 6666.0 -100.0 [torch.FloatStorage of size 2] torch.Size([1, 2]) 140707203003760 Out[56]: True

是否為連續內存空間有什么影響?
比如當你想要使用.view()轉換tensor的形狀時,如果該tensor的內存空間不是連續的則會報錯:

k = t.arange(0,6).view(2,3).float().t()#進行轉置,轉置后的k內存是不連續的
k.is_contiguous()
k.view(-1)

報錯:

RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at /Users/soumith/mc3build/conda-bld/pytorch_1549593514549/work/aten/src/TH/generic/THTensor.cpp:213

報錯的意思也是要求在.view()之前調用.contiguous(),改后為:

k = t.arange(0,6).view(2,3).float().t()#進行轉置,轉置后的k內存是不連續的
k.is_contiguous()
k.contiguous().view(-1)

成功返回:

tensor([0., 3., 1., 4., 2., 5.])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM