首先,給出官方文檔的鏈接:
https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
然后,我用白話翻譯一下官方文檔。
gather,顧名思義,聚集、集合。有點像軍訓的時候,排隊一樣,把隊伍按照教官想要的順序進行排列。
還有一個更恰當的比喻:gather的作用是根據索引查找,然后講查找結果以張量矩陣的形式返回。
1. 拿到一個張量:
-
import torch
-
a = torch.arange( 15).view( 3, 5)
a = tensor([
[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
2. 生成一個查找規則:
(張量b的元素都是對應張量a的索引)
-
b = torch.zeros_like(a)
-
b[ 1][ 2] = 1
-
b[ 0][ 0] = 1
b = tensor(
[[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0]])
3. 根據維度dim開始查找:
-
c = a.gather( 0, b) # dim=0
-
d = a.gather( 1, b) # dim=1
c= tensor([
[5, 1, 2, 3, 4],
[0, 1, 7, 3, 4],
[0, 1, 2, 3, 4]])
d=tensor([
[ 1, 0, 0, 0, 0],
[ 5, 5, 6, 5, 5],
[10, 10, 10, 10, 10]])
ok, 看到這兒應該有點費勁兒了。
如果dim=0,則b相對於a,它存放的都是第0維度的索引;
如果dim=1,則b相對於a,它存放的都是第1維度的索引;
我舉個栗子,當dim=0時,b[0][0]的元素是1,那么它想要查找a[0][1]中的元素;
當dim=1時,b[0][0]的元素是1,那么它想查找的a[1][0]中的元素;
最后的輸出都可以看作是對a的查詢,即元素都是a中的元素,查詢索引都存在b中。輸出大小與b一致。
找一張網圖來描述,這里的index對應b,src對應a,格子里的數值都減1,左圖對應dim=0,右圖對應dim=1。

原文鏈接:https://blog.csdn.net/leviopku/article/details/108735704
