PyTorch中.view()與.reshape()方法以及.resize_()方法的對比


前言

本文對PyTorch的.view()方法和.reshape()方法還有.resize_()方法進行了分析說明,關於本文出現的view和copy的語義可以看一下我之前寫的文章,傳送門:



torch.Tensor.reshape() vs. torch.Tensor.view()

  • 相同點:從功能上來看,它們的作用是相同的,都是將原張量元素(按順序)重組為新的shape
  • 區別在於:
    • .view()方法只能改變連續的(contiguous)張量,否則需要先調用.contiguous()方法,而.reshape()方法不受此限制
    • .view()方法返回的張量與原張量共享基礎數據(存儲器,注意不是共享內存地址,詳見代碼 ),而.reshape()方法返回的是原張量的copy還是view(即是否跟原張量共享存儲),事先是不知道的,如果可以返回view,那么.reshape()方法返回的就是原張量的view,否則返回的就是copy

–> 因此,為避免語義沖突:

  1. 如果需要原張量的拷貝(copy),就使用.clone()方法
  2. 而如果需要原張量的視圖(view),就使用.view()方法
  3. 如果想要原張量的視圖(view),但是原張量不連續(contiguous),不過原張量擁有兼容的步長(strides),此時可以考慮使用.reshape()函數
a = torch.randint(0, 10, (3, 4))
"""
Out:
tensor([[3, 7, 1, 3],
        [6, 4, 1, 3],
        [8, 8, 5, 7]])
"""

b = a.view(2, 6)
"""
Out:
tensor([[3, 7, 1, 3, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

c = a.reshape(2, 6)
"""
Out:
tensor([[3, 7, 1, 3, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

# 非嚴格意義上講,id可以認為是對象的內存地址
print(id(a)==id(b), id(a)==id(c), id(b)==id(c))
"""
前提:python的變量和數據是保存在不同的內存空間中的,PyTorch中的Tensor的存儲也是類似的機制,tensor相當於python變量,保存了tensor的形狀(size)、步長(stride)、數據類型(type)等信息(或其引用),當然也保存了對其對應的存儲器Storage的引用,存儲器Storage就是對數據data的封裝。
viewed對象和reshaped對象都存儲在與原始對象不同的地址內存中,但是它們共享存儲器Storage,也就意味着它們共享基礎數據。
"""
print(id(a.storage())==id(b.storage()), 
	  id(a.storage())==id(c.storage()),
	  id(b.storage())==id(c.storage()))
"""
Out:
False False False
True True True
"""

a[0]=0
print(a, b, c)
"""
Out:
tensor([[0, 0, 0, 0],
        [6, 4, 1, 3],
        [8, 8, 5, 7]])
tensor([[0, 0, 0, 0, 6, 4],
        [1, 3, 8, 8, 5, 7]])
tensor([[0, 0, 0, 0, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

c[0]=1
print(a, b, c)
"""
Out:
tensor([[1, 1, 1, 1],
        [1, 1, 1, 3],
        [8, 8, 5, 7]])
tensor([[1, 1, 1, 1, 1, 1],
        [1, 3, 8, 8, 5, 7]])
tensor([[1, 1, 1, 1, 1, 1],
        [1, 3, 8, 8, 5, 7]])
"""

  

 


torch.Tensor.resize_()

torch.Tensor.resize_() 方法的功能跟.reshape() / .view()方法的功能一樣,也是將原張量元素(按順序)重組為新的shape。

當resize前后的shape兼容時,返回原張量的視圖(view);當目標大小(resize后的總元素數)大於當前大小(resize前的總元素數)時,基礎存儲器的大小將改變(即增大),以適應新的元素數,任何新的內存(新元素值)都是未初始化的;當目標大小(resize后的總元素數)小於當前大小(resize前的總元素數)時,基礎存儲器的大小保持不變,返回目標大小的元素重組后的張量,未使用的元素仍然保存在存儲器中,如果再次resize回原來的大小,這些元素將會被重新使用。

(這里說的shape兼容的意思是:resize前后的shape包含的總元素數是一致的,即resize前后的shape的所有維度的乘積是相同的。如resize前,shape為(1, 2 ,3),那resize之后的張量的總元素數需要是1*2*3,故目標shape可以是(2, 3), 可以是(3, 2, 1),可以是(2, 1, 3)等尺寸。)

–> 文字說明有點干燥,看點例子感受一下:

a = torch.arange(24).view(4, 6)
"""
Out:
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])
"""

a.resize_(6, 4)
"""
Out:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]])
"""

a.resize_(3, 3)
"""
Out:
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
"""

a.resize_(7, 4)
"""
Out:
tensor([[              0,               1,               2,               3],
        [              4,               5,               6,               7],
        [              8,               9,              10,              11],
        [             12,              13,              14,              15],
        [             16,              17,              18,              19],
        [             20,              21,              22,              23],
        [140720147688480, 140720141167152,               1,               0]])
"""

  

 

ps(官方解釋,不是很能理解): 這是一個底層方法。存儲被重新解釋為c連續的,忽略當前的步長(除非目標大小等於當前大小,在這種情況下張量保持不變)

更多時候應該使用.view() / .reshape() / .set_()方法來替代此方法



參考文獻:

What’s the difference between reshape and view in pytorch?

 

原文鏈接:https://blog.csdn.net/weixin_43002433/article/details/104299896


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM