這個函數是用來求tensor中某個dim的前k大或者前k小的值以及對應的index。
用法
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
input:一個tensor數據
k:指明是得到前k個數據以及其index
dim: 指定在哪個維度上排序, 默認是最后一個維度
largest:如果為True,按照大到小排序; 如果為False,按照小到大排序
sorted:返回的結果按照順序返回
out:可缺省,不要
比如,三行兩列,3個樣本,2個類別。
import torch pred = torch.randn((4, 5)) print(pred) values, indices = pred.topk(1, dim=1, largest=True, sorted=True) print(indices) # 用max得到的結果,設置keepdim為True,避免降維。因為topk函數返回的index不降維,shape和輸入一致。 _, indices_max = pred.max(dim=1, keepdim=True) print(indices_max == indices) # pred tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263], [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156], [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185], [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]]) # indices, shape為 【4,1】, tensor([[3], #【0,0】代表 第一個樣本最可能屬於第一類別 [2], # 【1, 0】代表第二個樣本最可能屬於第二類別 [4], [2]]) # indices_max等於indices tensor([[True], [True], [True], [True]])