pytorch中的scatter_()函數


最近在學習pytorch函數時需要做獨熱碼,然后遇到了scatter_()函數,不太明白意思,現在懂了記錄一下以免以后忘記。

這個函數是用一個src的源張量或者標量以及索引來修改另一個張量。這個函數主要有三個參數scatter_(dim,index,src)

dim:沿着哪個維度來進行索引(一會兒舉個例子就明白了)

index:用來進行索引的張量

src:源張量或者標量

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

這個是官網給出的例子,但是一般在做獨熱碼的時候通常是采用二維張量所以應該是這樣

#dim=0
self[index[x][y]][y]=src[x][y]  

#dim=1
self[x][index[x][y]]=src[x][y]

這個是什么意思呢。首先請看下面的程序,程序是我瞎編的,想試試的話可以自己編數據哈

import torch
x=torch.rand(3,5)
print(x)
print('-------------------')
y=torch.zeros(3,5)
print(y)
print('-------------------')
inx=torch.tensor([[0,4,3,1,2],[3,2,1,4,3]])
output_y=y.scatter_(dim=1,index=inx,src=x)
print(output_y)

下面是運行的結果

tensor([[0.1380, 0.6030, 0.2396, 0.0066, 0.7116],
        [0.5755, 0.2856, 0.4862, 0.2132, 0.2475],
        [0.5145, 0.4753, 0.2736, 0.2623, 0.8532]])
-------------------
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
-------------------
tensor([[0.1380, 0.0066, 0.7116, 0.2396, 0.6030],
        [0.0000, 0.4862, 0.2856, 0.2475, 0.2132],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

Process finished with exit code 0

那么是什么意思呢,舉個例子,這里我要強調一下,index即這個程序中的inx里面的每個數值,不能超過該dim的張量的最大下標,不然的話就會越界,找不到src中的源數據。因為是在dim=1上進行索引,所以采用第二個式子。

我們在索引表中找到index[1][3]=4,那么此時x=1,y=3,即output_y[1][index[1][3]]=src[1][3],即output_y[1][4]=src[1][3]。即x[1][3]。以此類推就可以得到其他的值。

src不僅僅可以是張量,也可以是標量,下面這段代碼是模仿怎么生成獨熱碼

import torch
x=torch.zeros(4,8)
label=torch.tensor([[1],[5],[7],[6]])
one_hot=x.scatter_(1,label,1)
print(one_hot)

其中x的第一個參數代表的是batch_size,第二個參數代表的是classnum,而label有batch_size行只有一列,就是將x每一行的label值指向的位置置成1,這就是獨熱碼。當然其他位置都是0啦,下面看一下結果吧。

tensor([[0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.]])

Process finished with exit code 0

好啦,這就是scatter_()函數的用法。

ps:本來堅持不下去了快,但是把scatter弄清楚了發現還有一點動力學下去,加油吧。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM