gather函數


gather(input, dim, index):根據  index,在  dim  維度上選取數據,輸出的  size  與  index  一致

 

# input (Tensor) – 源張量

# dim (int) – 索引的軸

# index (LongTensor) – 聚合元素的下標(index需要是torch.longTensor類型)

# out (Tensor, optional) – 目標張量

for 3D tensor:

                     out[i][j][k] = tensor[index[i][j][k]][j][k]   # dim=0

                     out[i][j][k] = tensor[i][index[i][j][k]][k]   # dim=1

                     out[i][j][k] = tensor[i][j][index[i][j][k]]   # dim=2

 

for 2D tensor:

 

                        out[i][j] = input[index[i][j]][j]  # dim = 0

                        out[i][j] = input[i][index[i][j]]  # dim = 1

 

import torch as t  # 導入torch模塊
c = t.arange(0, 60).view(3, 4, 5) # 定義tensor
print(c)
index = torch.LongTensor([[[0,1,2,0,2],
                [0,0,0,0,0],
                [1,1,1,1,1]],
                [[1,2,2,2,2],
                 [0,0,0,0,0],
                [2,2,2,2,2]]])
b = t.gather(c, 0, index)
print(b)

輸出:

tensor([[[ 0, 1, 2, 3, 4],
             [ 5, 6, 7, 8, 9],
             [10, 11, 12, 13, 14],
             [15, 16, 17, 18, 19]],

             [[20, 21, 22, 23, 24],
             [25, 26, 27, 28, 29],
             [30, 31, 32, 33, 34],
             [35, 36, 37, 38, 39]],

             [[40, 41, 42, 43, 44],
             [45, 46, 47, 48, 49],
             [50, 51, 52, 53, 54],
             [55, 56, 57, 58, 59]]])

報錯:

Traceback (most recent call last):
File "E:/Release02/my_torch.py", line 14, in <module>
b = t.gather(c, 0, index)
RuntimeError: Size does not match at dimension 1 get 4 vs 3

(第1維尺寸不匹配)

將index調整為:

index = t.LongTensor([[[0, 1, 2, 0, 2], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]],
[[1, 2, 2, 2, 2], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [1, 1, 1, 1, 1]],
[[1, 2, 2, 2, 2], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [1, 1, 1, 1, 1]]])
則上文輸出為:

tensor([[[ 0, 21, 42, 3, 44],
             [ 5, 6, 7, 8, 9],
             [30, 31, 32, 33, 34],
             [35, 36, 37, 38, 39]],

            [[20, 41, 42, 43, 44],
             [ 5, 6, 7, 8, 9],
             [50, 51, 52, 53, 54],
             [35, 36, 37, 38, 39]],

            [[20, 41, 42, 43, 44],
             [ 5, 6, 7, 8, 9],
             [50, 51, 52, 53, 54],
             [35, 36, 37, 38, 39]]])

對於2D tensor 則無“index與tensor 的size一致”之要求,

這個要求在官方文檔和其他博文、日志中均無提到

可能是個坑吧丨可能是個坑吧丨可能是個坑吧

eg:

代碼(此部分來自https://www.yzlfxy.com/jiaocheng/python/337618.html):

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)

 

輸出:

 1 2 3
 4 5 6
[torch.FloatTensor of size 2x3]

 1 2
 6 4
[torch.FloatTensor of size 2x2]

 1 5 6
 1 2 3
[torch.FloatTensor of size 2x3]



官方文檔:
torch.gather(input, dim, index, out=None) → Tensor

 Gathers values along an axis specified by dim.

 For a 3-D tensor the output is specified by:

 out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2

 Parameters: 

  input (Tensor)-The source tensor
  dim (int)-The axis along which to index
  index (LongTensor)-The indices of elements to gather
  out (Tensor, optional)-Destination tensor

 Example:

 >>> t = torch.Tensor([[1,2],[3,4]])
 >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
  1 1
  4 3
 [torch.FloatTensor of size 2x2]

以上,學習中遇到的問題,記錄方便回顧,亦示他人以之勉坑








免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM