pytorch 中的 gather() 函數詳解


首先,給出官方文檔的鏈接:

https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather


然后,我用白話翻譯一下官方文檔。

gather,顧名思義,聚集、集合。有點像軍訓的時候,排隊一樣,把隊伍按照教官想要的順序進行排列

還有一個更恰當的比喻:gather的作用是根據索引查找,然后講查找結果以張量矩陣的形式返回

1. 拿到一個張量:

  1.  
    import torch
  2.  
    a = torch.arange( 15).view( 3, 5)

a = tensor([

        [ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])

2. 生成一個查找規則:

(張量b的元素都是對應張量a的索引)

  1.  
    b = torch.zeros_like(a)
  2.  
    b[ 1][ 2] = 1
  3.  
    b[ 0][ 0] = 1

b = tensor(

[[1, 0, 0, 0, 0],
 [0, 0, 1, 0, 0],
 [0, 0, 0, 0, 0]])

3. 根據維度dim開始查找:

  1.  
    c = a.gather( 0, b) # dim=0
  2.  
    d = a.gather( 1, b) # dim=1

c= tensor([

        [5, 1, 2, 3, 4],
        [0, 1, 7, 3, 4],
        [0, 1, 2, 3, 4]])

d=tensor([

        [ 1,  0,  0,  0,  0],
        [ 5,  5,  6,  5,  5],
        [10, 10, 10, 10, 10]])

ok, 看到這兒應該有點費勁兒了。

如果dim=0,則b相對於a,它存放的都是第0維度的索引;

如果dim=1,則b相對於a,它存放的都是第1維度的索引;

我舉個栗子,當dim=0時,b[0][0]的元素是1,那么它想要查找a[0][1]中的元素;

dim=1時,b[0][0]的元素是1,那么它想查找的a[1][0]中的元素;

最后的輸出都可以看作是對a的查詢,即元素都是a中的元素,查詢索引都存在b中。輸出大小與b一致。

找一張網圖來描述,這里的index對應b,src對應a,格子里的數值都減1,左圖對應dim=0,右圖對應dim=1。

 

原文鏈接:https://blog.csdn.net/leviopku/article/details/108735704


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM