gather函數的的官方文檔:
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
Example:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
例子:
a=t.arange(0,16).view(4,4)
print(a)
index_1=t.LongTensor([[3,2,1,0]])
b=a.gather(0,index_1)
print(b)
index_2=t.LongTensor([[0,1,2,3]]).t()#tensor轉置操作:(a)T=a.t()
c=a.gather(1,index_2)
print(c)
輸出如下:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[12, 9, 6, 3]])
tensor([[ 0],
[ 5],
[10],
[15]])
在上面的例子中,a是一個4×4矩陣:
1)當維度dim=0,索引index_1為[3,2,1,0]時,此時可將a看成1×4的矩陣,通過index_1對a每列進行行索引:第一列第四行元素為12,第二列第三行元素為9,第三列第二行元素為6,第四列第一行元素為3,即b=[12,9,6,3];
2)當維度dim=1,索引index_2為[0,1,2,3]T時,此時可將a看成4×1的矩陣,通過index_1對a每行進行列索引:第一行第一列元素為0,第二行第二列元素為5,第三行第三列元素為10,第四行第四列元素為15,即c=[0,5,10,15]T。
例子二:
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))
輸出:
tensor([[0.1000],
[0.5000]])
總結:
gather函數在提取數據時主要靠dim和index這兩個參數:
dim=1時將input看為n×1階矩陣,index看為k×1階矩陣,取index每行元素對input中每行進行列索引(如:index某行為[1,3,0],對應的input行元素為[9,8,7,6],提取后的結果為[8,6,9]);
同理,dim=0時將input看為1×n階矩陣,index看為1×k階矩陣,取index每列元素對input中每列進行行索引。gather函數提取后的矩陣階數和對應的index階數相同。
參考:https://blog.csdn.net/weixin_44318872/article/details/103183763?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.add_param_isCf&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.add_param_isCf