Pytorch的gather用法理解


先放一張表,可以看成是二維數組

行(列)索引 索引0 索引1 索引2 索引3
索引0 0 1 2 3
索引1 4 5 6 7
索引2 8 9 10 11
索引3 12 13 14 15

看一下下面例子代碼:

針對0維(輸出為行形式)

>>> import torch as t
>>> a = t.arange(0,16).view(4,4)
>>> a
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
 
#選取對角線的元素
>>> index = t.LongTensor([[0,1,2,3]])
>>> a.gather(0,index)
tensor([[ 0,  5, 10, 15]])

如何理解結果呢?其實很簡單,就是a.gather(0,index)中第一個0已經表明輸出結果是行形式(0維),如果第一個是1說明輸出結果是列形式(1維),然后按照index = tensor([[0, 1, 2, 3]])順序作用在行上索引依次為0,1,2,3

  • a[0][0] = 0
  • a[1][1] = 5
  • a[2][2] = 10
  • a[3][3] = 15

針對0維

# 選取反對角線上的元素,注意與上面的不同
>>> index2 = t.LongTensor([[3,2,1,0]])
>>> a.gather(0,index2)
tensor([[12,  9,  6,  3]])

如何理解結果呢?同理,按照index = tensor([[3, 2, 1, 0]])順序作用在行上索引依次為3,2,1,0:

  • a[3][0] = 12
  • a[2][1] = 9
  • a[1][2] = 6
  • a[0][3] = 3

針對1維(輸出為列形式)

選取對角線的元素

>>> index3 = t.LongTensor([[0,1,2,3]]).t()
>>> a.gather(1,index3)
tensor([[ 0],
        [ 5],
        [10],
        [15]])

如何理解結果呢?同理,按照index = tensor([[0, 1, 2, 3]])順序作用在列上索引依次為0,1,2,3:

  • a[0][0] = 0
  • a[1][1] = 5
  • a[2][2] = 10
  • a[3][3] = 15

針對1維

選取反對角線上的元素

>>> index4 = t.LongTensor([[3,2,1,0]]).t()
>>> a.gather(1,index4)
tensor([[ 3],
        [ 6],
        [ 9],
        [12]])

如何理解結果呢?同理,按照index = tensor([[3, 2, 1, 0]])順序作用在列上索引依次為3,2,1,0:

  • a[0][3] = 3
  • a[1][2] = 6
  • a[2][1] = 9
  • a[3][0] = 12


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM