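I found this hard to follow at first (my next post explains how to understand gather).
gather is a fairly involved operation. For a 2-D tensor, each element of the output is: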
out[i][j] = input[index[i][j]][j] # dim=0
out[i][j] = input[i][index[i][j]] # dim=1
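To make the two formulas concrete, here is a minimal pure-Python sketch of 2-D gather on nested lists. The function name gather_2d is hypothetical and not part of PyTorch. It loops over index and applies the formula for the chosen dim, so the output has the same shape as index. It does not perform the shape checks that PyTorch does.

```python
def gather_2d(inp, dim, index):
    """Reference 2-D gather on nested lists (hypothetical helper, not a PyTorch API)."""
    out = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            if dim == 0:
                out_row.append(inp[k][j])   # out[i][j] = input[index[i][j]][j]
            elif dim == 1:
                out_row.append(inp[i][k])   # out[i][j] = input[i][index[i][j]]
            else:
                raise ValueError("dim must be 0 or 1 for a 2-D input")
        out.append(out_row)
    return out


b = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11]]
print(gather_2d(b, 0, [[0, 1, 2]]))        # [[0, 5, 10]]
print(gather_2d(b, 1, [[0], [1], [2]]))    # [[0], [5], [10]]
```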
gather on a 2-D tensor
Along dim 0
Note the value of index here
Input
import torch as t
index = t.LongTensor([[0,1,2,3]])   # index is 2-D, shape (1, 4)
print("index = \n", index)
print("index shape: ", index.shape)
Output
index =
tensor([[0, 1, 2, 3]])
index shape: torch.Size([1, 4])
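Here is what this index does along dim 0, using a 4×4 input a = t.arange(0,16).view(4,4) chosen for illustration. Because out[0][j] = a[index[0][j]][j], column j contributes the element in row j. The result is the main diagonal, returned as a 1×4 row.

```python
import torch as t

a = t.arange(0, 16).view(4, 4)          # illustrative 4x4 input
index = t.LongTensor([[0, 1, 2, 3]])    # shape (1, 4)
out = a.gather(0, index)                # out[0][j] = a[index[0][j]][j]
print(out.tolist())                     # [[0, 5, 10, 15]]
```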
============
Along dim 1
Note the value of index here
Input
index = t.LongTensor([[0,1,2,3]]).t()   # transposed: index is 2-D, shape (4, 1)
print("index = \n", index)
print("index shape: ", index.shape)
Output
index =
tensor([[0],
        [1],
        [2],
        [3]])
index shape: torch.Size([4, 1])
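With the transposed index, gathering along dim 1 gives the same diagonal as a 4×1 column. This again uses the illustrative input a = t.arange(0,16).view(4,4). By out[i][0] = a[i][index[i][0]], row i contributes its element in column i.

```python
import torch as t

a = t.arange(0, 16).view(4, 4)            # same illustrative 4x4 input
index = t.LongTensor([[0, 1, 2, 3]]).t()  # shape (4, 1)
out = a.gather(1, index)                  # out[i][0] = a[i][index[i][0]]
print(out.tolist())                       # [[0], [5], [10], [15]]
```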
============
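Let's look at a few more examples.
Note that index has to be shaped differently for dim 0 and for dim 1.
In the examples below, b.gather(0, ...) returns a row and b.gather(1, ...) returns a column. More precisely, the output always has the same shape as index.
The tracebacks come from an older PyTorch build (note the legacy THTensorMath.cpp in the messages). That build required index to match input in every dimension except dim. The current torch.gather documentation only requires index.size(d) <= input.size(d) for every d != dim, so some calls marked as failing below succeed on recent releases.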
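Whether the result is a row or a column depends only on the shape of index. The sketch below uses indices that are neither rows nor columns. idx0 and idx1 are arbitrary illustrative values, sized to match b in the non-gathered dimension so they should also work on older builds. Repeated indices are allowed: row 1 of idx1 reads column 2 twice.

```python
import torch as t

b = t.arange(0, 12).view(3, 4)

# dim=0: out[i][j] = b[idx0[i][j]][j]; idx0 is 2x4, so the output is 2x4
idx0 = t.LongTensor([[2, 0, 1, 2],
                     [0, 0, 0, 1]])
out0 = b.gather(0, idx0)
print(out0.tolist())   # [[8, 1, 6, 11], [0, 1, 2, 7]]

# dim=1: out[i][j] = b[i][idx1[i][j]]; idx1 is 3x2, so the output is 3x2
idx1 = t.LongTensor([[0, 3],
                     [2, 2],
                     [1, 0]])
out1 = b.gather(1, idx1)
print(out1.tolist())   # [[0, 3], [6, 6], [9, 8]]
```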
- b is a $ 3\times4 $ tensor
>>> import torch as t
>>> b = t.arange(0,12).view(3,4)
>>> b
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> index = t.LongTensor([[0,1,2]])
>>> index
tensor([[0, 1, 2]])
>>> b.gather(0,index) # fails
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 3], src [3 x 4] and index [1 x 3] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
>>> index2 = t.LongTensor([[0,1,2]]).t()
>>> b.gather(1,index2) # succeeds
tensor([[ 0],
[ 5],
[10]])
>>> index3 = t.LongTensor([[0,1,2,3]]).t()
>>> b.gather(1,index3) # fails
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [4 x 1], src [3 x 4] and index [4 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
- b is a $ 6\times6 $ tensor
>>> import torch as t
>>> b = t.arange(0,36).view(6,6)
>>> b
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]])
>>> index = t.LongTensor([[0,1,2,3,4,5,6]])
>>> b.gather(0,index) # fails
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 7], src [6 x 6] and index [1 x 7] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
>>> index = t.LongTensor([[0,1,2,3,4,5]])
>>> b.gather(0,index) # succeeds
tensor([[ 0, 7, 14, 21, 28, 35]])
>>> b.gather(1,index)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 6], src [6 x 6] and index [1 x 6] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
>>> index2 = t.LongTensor([[0,1,2,3,4,5]]).t()
>>> b.gather(1,index2) # succeeds
tensor([[ 0],
[ 7],
[14],
[21],
[28],
[35]])
>>> index3 = t.LongTensor([[0,1,2,3,4]]).t()
>>> b.gather(1,index3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [5 x 1], src [6 x 6] and index [5 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
>>> index4 = t.LongTensor([[0,1,2,3,4]])
>>> b.gather(0,index4)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 5], src [6 x 6] and index [1 x 5] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
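On a recent PyTorch release, several of these calls behave differently, because torch.gather only requires index.size(d) <= input.size(d) for d != dim. The sketch below replays calls from both sessions and assumes such a release. gather_ok is a hypothetical helper that returns the gathered values, or None if the call raises.

```python
import torch as t


def gather_ok(src, dim, index):
    """Hypothetical helper: gathered values as a list, or None if gather raises."""
    try:
        return src.gather(dim, index).tolist()
    except RuntimeError:
        return None


b34 = t.arange(0, 12).view(3, 4)
row3 = t.LongTensor([[0, 1, 2]])
print(gather_ok(b34, 0, row3))                              # [[0, 5, 10]]
print(gather_ok(b34, 1, t.LongTensor([[0, 1, 2, 3]]).t()))  # None: 4 rows > 3 rows

b66 = t.arange(0, 36).view(6, 6)
print(gather_ok(b66, 0, t.LongTensor([[0, 1, 2, 3, 4, 5, 6]])))  # None: 7 columns > 6
print(gather_ok(b66, 1, t.LongTensor([[0, 1, 2, 3, 4, 5]])))     # [[0, 1, 2, 3, 4, 5]]
print(gather_ok(b66, 1, t.LongTensor([[0, 1, 2, 3, 4]]).t()))    # [[0], [7], [14], [21], [28]]
print(gather_ok(b66, 0, t.LongTensor([[0, 1, 2, 3, 4]])))        # [[0, 7, 14, 21, 28]]
```

Only the calls whose index is larger than the input in a non-gathered dimension still fail on a recent release.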
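The operation that reverses gather is scatter_. gather pulls values out of input according to index, and scatter_ puts those values back according to the same index. Note that scatter_ is an in-place operation.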
out = input.gather(dim, index)
--> approximate inverse:
restored = t.zeros_like(input)
restored.scatter_(dim, index, out)   # writes each value of out back to where gather read it
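Below is a runnable version of that round trip. It uses the 3×4 b from the examples above and the column index index2 = [[0],[1],[2]]; the name restored is illustrative. scatter_ writes each gathered value back to its original position. Every other position keeps the initial value of the target tensor, here zero, which is why this is only an approximate inverse.

```python
import torch as t

b = t.arange(0, 12).view(3, 4)
index2 = t.LongTensor([[0, 1, 2]]).t()   # shape (3, 1)

out = b.gather(1, index2)                # [[0], [5], [10]]
restored = t.zeros_like(b)               # same shape and dtype as b
restored.scatter_(1, index2, out)        # restored[i][index2[i][0]] = out[i][0]
print(restored.tolist())
# [[0, 0, 0, 0], [0, 5, 0, 0], [0, 0, 10, 0]]
```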
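Following a StackOverflow question, I changed the code as below. The key change is b.float(): c is a float tensor, and scatter_ requires src to have the same dtype. The lines defining a, index and b are an assumed setup that selects both diagonals of a 4×4 tensor and reproduces the output shown.
Input
a = t.arange(0,16).view(4,4)                      # assumed setup
index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()   # column indices of the main and anti-diagonal
b = a.gather(1, index)                            # b holds both diagonals, shape (4, 2)
# put the elements of the two diagonals back at their original positions
c = t.zeros(4,4)
c.scatter_(1, index, b.float())                   # .float() matches c's dtype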
Output
tensor([[ 0., 0., 0., 3.],
[ 0., 5., 6., 0.],
[ 0., 9., 10., 0.],
[12., 0., 0., 15.]])
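To check that scatter_ really undid the gather, gather from c again with the same index. The block re-creates the assumed a, index and b so it runs on its own. This round trip works because no row of index repeats a column. With duplicate indices, the PyTorch docs warn that scatter_ keeps only one of the written values.

```python
import torch as t

a = t.arange(0, 16).view(4, 4)
index = t.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]]).t()  # column indices of both diagonals
b = a.gather(1, index)                                    # both diagonals, shape (4, 2)
print(b.tolist())   # [[0, 3], [5, 6], [10, 9], [15, 12]]

c = t.zeros(4, 4)
c.scatter_(1, index, b.float())

# Gathering from c with the same index returns exactly what was scattered in
print(t.equal(c.gather(1, index), b.float()))   # True
```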