torch topk函數


這個函數是用來求tensor中某個dim的前k大或者前k小的值以及對應的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一個tensor數據
k:指明是得到前k個數據以及其index
dim: 指定在哪個維度上排序, 默認是最后一個維度
largest:如果為True,按照大到小排序; 如果為False,按照小到大排序
sorted:返回的結果按照順序返回
out:可缺省,不要

 

比如,三行兩列,3個樣本,2個類別。

 

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的結果,設置keepdim為True,避免降維。因為topk函數返回的index不降維,shape和輸入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364,  0.7912, -0.3263],
        [-0.8013, -0.9083,  0.7973,  0.1458, -0.9156],
        [-0.2334, -0.0142, -0.5493,  0.0673,  0.8185],
        [-0.4075, -0.1097,  0.8193, -0.2352, -0.9273]])
# indices, shape為 【4,1】,
tensor([[3],   #【0,0】代表 第一個樣本最可能屬於第一類別
        [2],   # 【1, 0】代表第二個樣本最可能屬於第二類別
        [4],
        [2]])
# indices_max等於indices
tensor([[True],
        [True],
        [True],
        [True]])

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM