PyTorch中scatter和gather的用法
閑扯
許久沒有更新博客了,2019年總體上看是荒廢的,沒有做出什么東西,明年春天就要開始准備實習了,雖然不找算法崗的工作,但是還是准備在2019年的最后一個半月認真整理一下自己學習的機器學習和深度學習的知識。
scatter的用法
scatter中文翻譯為散射,首先看一個例子來直觀感受一下這個API的功能,使用pytorch官網提供的例子。
import torch
import torch.nn as nn
x = torch.rand(2,5)
x
tensor([[0.2656, 0.5364, 0.8568, 0.5845, 0.2289],
[0.0010, 0.8101, 0.5491, 0.6514, 0.7295]])
y = torch.zeros(3,5)
index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
index
tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
y.scatter_(dim=0,index=index,src=x)
y
tensor([[0.2656, 0.8101, 0.5491, 0.5845, 0.2289],
[0.0000, 0.5364, 0.0000, 0.6514, 0.0000],
[0.0010, 0.0000, 0.8568, 0.0000, 0.7295]])
首先我們可以看到,x的所有值都在y中出現了,且被索引的軸為dim=0,任意一個來自x中的元素,將按照以下公式完成映射。 y[index[i,j],j] = x[i,j],對於x[0,1] = 0.5364,index[0,1] = 1指出這個值將出現在y的第dim=0維,下標為1的位置,因此,y[index[0,1],1] = y[1,1] = x[0,1] = 0.5364.
到這里我們已經對scatter,即散射這個函數有了直觀的認識,可用於將一個矩陣映射到一個矩陣,dim指明要映射的軸,index指明要映射的軸的下標,因此對於3D張量,若調用y.scatter_(dim,index,src),那么有:
y[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
y[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
y[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
最后看一個官方文檔的關於scatter的英文說明:
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
意思和直觀感受幾乎相同,函數可將src映射到目標張量self,在dim維度上,由索引index給出下標,在非dim維度上,直接使用src值所在位置的下標。
self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) ⇐ src.size(d) for all dimensions d, and that index.size(d) ⇐ self.size(d) for all dimensions d != dim.
顯然self,index,src的ndim應該相同了,否則下標越界了,從公式上看index.size(d) > src.size(d),index.size(d) > self.size(d)沒什么問題,index數組可以比src更大,猜測這里是工程上的考慮,因為超出src大小的index數組在這里是沒用的,閑置的空間不會被訪問。
Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.
index所有的值需要在[0,self.size(dim) - 1]區間內,這是必須滿足的,否則越界了。第二句說沿着dim維的index的所有值需要是唯一的,我測試的結果發現可以重復,看下面的代碼:
x = torch.rand(2,5)
x
tensor([[0.6542, 0.6071, 0.7546, 0.4880, 0.1077],
[0.9535, 0.0992, 0.0594, 0.0641, 0.7563]])
index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
y = torch.zeros(3,5)
y.scatter_(dim=0,index=index,src=x)
tensor([[0.6542, 0.0992, 0.0594, 0.4880, 0.1077],
[0.0000, 0.6071, 0.0000, 0.0641, 0.0000],
[0.9535, 0.0000, 0.7546, 0.0000, 0.7563]])
index = torch.tensor([[0,1,2,0,0],[0,1,2,0,0]])
y = torch.zeros(3,5)
y.scatter_(dim=0,index=index,src=x)
tensor([[0.9535, 0.0000, 0.0000, 0.0641, 0.7563],
[0.0000, 0.0992, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0594, 0.0000, 0.0000]])
可以看到沿着dim=0軸上重復了5次,分別是(0,0),(1,1),(2,2),(0,0),(0,0),代碼無報錯和警告,只是覆蓋掉了原來的值,可能是文檔沒有更新,但是API更新了。
params:
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
- src (Tensor) – the source element(s) to scatter, incase value is not specified
- value (float) – the source element(s) to scatter, incase src is not specified
值得注意的是value參數,當沒有指明src時,可以指定一個浮點value變量,利用這一點我們實現一個scatter版本的onehot函數。
x = torch.tensor([[1,1,1,1,1]],dtype=torch.float32)
index = torch.tensor([[0,1,2,3,4]],dtype=torch.int64)
y = torch.zeros(5,5,dtype=torch.float32)
x
tensor([[1., 1., 1., 1., 1.]])
y.scatter_(0,index,x)
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])
y = torch.zeros(5,5,dtype=torch.float32)
y.scatter_(0,index,1)
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])
可以看到指定value=1,和src=[[1,1,1,1,1]]等價。到這里scatter就結束了。
gather的用法
gather是scatter的逆過程,從一個張量收集數據,到另一個張量,看一個例子有個直觀感受。
x = torch.tensor([[1,2],[3,4]])
torch.gather(input=x,dim=1,index=torch.tensor([[0,0],[1,0]]))
tensor([[1, 1],
[4, 3]])
可以猜測到收集過程,根據index和dim將x中的數據挑選出來,放置到y中,滿足下面的公式: y[i,j] = x[i,index[i,j]],因此有y[0,0] = x[0,index[0,0]] = x[0,0] = 1, y[1,0] = x[1,index[1,0]] = x[1,1] = 4,對於3D數據,滿足以下公式:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
到這里gather的用法介紹就結束了,因為gather畢竟是scatter的逆過程,理解了scatter,gather並不需要太多說明。
小結
- scatter可以將一個張量映射到另一個張量,其中一個應用是onehot函數.
- gather和scatter是兩個互逆的過程,gather可用於壓縮稀疏張量,收集稀疏張量中非0的元素。
- 別再荒廢時光了,做不出成果也不能全怪自己的。