pytorch.topk()用於返回Tensor中的前k個元素以及元素對應的索引值。例:
import torch item=torch.IntTensor([1,2,4,7,3,2]) value,indices=torch.topk(item,3) print("value:",value) print("indices:",indices)
輸出結果為:
其中:value中存儲的是對應的top3的元素,並按照從大到小的取值方式進行存儲
indices中存儲的是value中top3元素在原Tensor中的索引值