參考資料:
https://pytorch.org/docs/stable/index.html
深度學習里,很多時候我們只想取輸出中的一部分值,此時便用上了Pytorch中的高級索引函數。我們見過最多的可能就是torch.gather這個函數了。這個隨筆講解一下Pytorch中的高級選擇函數。
一、torch.index_select
torch.index_select(input, dim, index, *, out=None)-> Tensor """ :param input(Tensor) - the input tensor :param dim(int) - the dimension in which we index :param index(IntTensor or LongTensor) - the 1-D tensor containing the indices to index :output out(Tensor,optional) - the output tensor """
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> indices = torch.tensor([0, 2]) >>> torch.index_select(x, 0, indices) tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> torch.index_select(x, 1, indices) tensor([[ 0.1427, -0.5414], [-0.4664, -0.1228], [-1.1734, 0.7230]])
從這個官方示例可以看出,torch.index是針對某一個維度進行選擇的。在示例代碼中選擇了dim=0(行)。然后使用一個可變長度的indices,選取indices對應的行數。
二、torch.masked_select
這個函數就更為簡單粗暴了,從名字就可以看出來它是使用一個蒙版來選擇Tensor中的值。也容易想到這個mask tensor需要和input tensor的形狀保持一致。但是需要注意的是這個函數的輸出是一個一維的Tensor,保存了從原始的Tensor中選擇出來的所有值。
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], [-1.2035, 1.2252, 0.5002, 0.6248], [ 0.1307, -2.0608, 0.1244, 2.0139]]) >>> mask = x.ge(0.5) >>> mask tensor([[False, False, False, False], [False, True, True, True], [False, False, False, True]]) >>> torch.masked_select(x, mask) tensor([ 1.2252, 0.5002, 0.6248, 2.0139])
小知識點是這里使用了一個ge函數。該函數會逐元素比較array中的值和給定值的大小。然后返回布爾類型的tensor。總之要想用好masked_select,和各種能判斷並生成布爾類型tensor的函數搭配起來才是正道。
三、torch.gather
這個函數有點繞,建議直接去看官方的文檔,說明的比較清楚。
https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
我對於這個函數的理解是,首先有一個src_tensor和一個index_tensor,index_tensor和src_tensro有相同的維度數,但是在每一個維度上,應該有d(index_tensor) <= d(src_tensor)。
首先看最簡單的情況即index_tensor和src_tensor有完全相同的形狀。按照文檔上的說明,白話式的解釋就是:對於輸出矩陣的每一個位置,數據源還是來自src_tensor。只是在指定的那一個維度上,只取響應index的值。拿這個范例來說,[0, 0 ]這個位置,我們的數據源還是src_tensor,只不過數據源變廣了,變成了src_tensor[0,:]即第一行中所有列的元素。然后再看index_tensor,這個位置是0。於是這個位置就被賦成餓了src_tensor[0, 0]。
關於index_tensor比src_tensor小的情況。我最開始的理解就是使用了廣播機制首先變換到一樣的形狀上,但是這是錯誤的,下面做一個和官方示例相似的小實驗測試一下。
>>>t = torch.tensor([[1, 2], [3, 4]]) >>>torch.gather(t, 1, torch.tensor([[0, 0]])) tensor([[1, 1]])
我們發現輸出的結果其實是和index_tensor的形狀一致的。也就是舍棄掉index_tensor沒有覆蓋的位置。