關於Pytorch的二維tensor的gather和scatter_操作用法分析


看得不明不白(我在下一篇中寫了如何理解gather的用法)

gather是一個比較復雜的操作,對一個2維tensor,輸出的每個元素如下:

out[i][j] = input[index[i][j]][j]  # dim=0
out[i][j] = input[i][index[i][j]]  # dim=1

二維tensor的gather操作

針對0軸

注意index此時的值

輸入

index = t.LongTensor([[0,1,2,3]])
print("index = \n", index)      #index是2維
print("index的形狀: ",index.shape)  #index形狀是(1,4)  

輸出

index = 
 tensor([[0, 1, 2, 3]])
index的形狀:  torch.Size([1, 4])

分割線============

針對1軸

注意index此時的值

輸入

index = t.LongTensor([[0,1,2,3]]).t()  #index是2維
print("index = \n", index)    #index形狀是(4,1)
print("index的形狀: ",index.shape)

輸出

index = 
 tensor([[0],
        [1],
        [2],
        [3]])
index的形狀:  torch.Size([4, 1])

分割線===========

再來看看幾個例子

注意index在以0軸和1軸為標准時的表達式是不一樣的。
b.gather()中取0維時,輸出的結果是行形式,取1維時,輸出的結果是列形式。

  • b是一個 $ 3\times4 $ 型的
>>> import torch as t
>>> b = t.arange(0,12).view(3,4)
>>> b
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>> index = t.LongTensor([[0,1,2]])

>>> index
tensor([[0, 1, 2]])

>>> b.gather(0,index)     #運行失敗了
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 3], src [3 x 4] and index [1 x 3] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620    

>>> index2 = t.LongTensor([[0,1,2]]).t()

>>> b.gather(1,index2)  #運行成功了
tensor([[ 0],
        [ 5],
        [10]])

>>> index3 = t.LongTensor([[0,1,2,3]]).t()

>>> b.gather(1,index3)  #運行失敗了
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [4 x 1], src [3 x 4] and index [4 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
  • b是一個 $ 6\times6 $ 型的
>>> import torch as t
>>> b = t.arange(0,36).view(6,6)
>>> b
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])

>>> index = t.LongTensor([[0,1,2,3,4,5,6]])
>>> b.gather(0,index)     #運行失敗了
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 7], src [6 x 6] and index [1 x 7] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620  

>>> index = t.LongTensor([[0,1,2,3,4,5]])
>>> b.gather(0,index)    #運行成功了
tensor([[ 0,  7, 14, 21, 28, 35]])
>>> b.gather(1,index)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 6], src [6 x 6] and index [1 x 6] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

>>> index2 = t.LongTensor([[0,1,2,3,4,5]]).t()  
>>> b.gather(1,index2)     #運行成功了
tensor([[ 0],     
        [ 7],
        [14],
        [21],
        [28],
        [35]])

>>> index3 = t.LongTensor([[0,1,2,3,4]]).t()
>>> b.gather(1,index3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [5 x 1], src [6 x 6] and index [5 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620  

>>> index4 = t.LongTensor([[0,1,2,3,4]])
>>> b.gather(0,index4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 5], src [6 x 6] and index [1 x 5] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

與gather相對應的逆操作是scatter_,gather把數據從input中按index取出,而scatter_是把取出的數據再放回去。注意scatter_函數是inplace操作。



與gather相對應的逆操作是scatter_,gather把數據從input中按index取出,而scatter_是把取出的數據再放回去。注意scatter_函數是inplace操作。

out = input.gather(dim, index)
-->近似逆操作
out = Tensor()
out.scatter_(dim, index)

根據StackOverflow上的問題修改代碼如下:
輸入

# 把兩個對角線元素放回去到指定位置
c = t.zeros(4,4)
c.scatter_(1, index, b.float())

輸出

tensor([[ 0.,  0.,  0.,  3.],
        [ 0.,  5.,  6.,  0.],
        [ 0.,  9., 10.,  0.],
        [12.,  0.,  0., 15.]])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM