函數定義
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
對於給定的輸入 張量input,沿着給定的維度,返回k個最大元素。
一個命名元組(values,indices)將會被返回,這里的indices是返回的元素在原始的input張量中的indices。
函數參數
-
input (Tensor) – the input tensor.
-
k (int) – the k in “top-k”
-
dim (int, optional) – the dimension to sort along,If
dim
is not given, the last dimension of the input is chosen. -
largest (bool, optional) – controls whether to return largest or smallest elements
-
sorted (bool, optional) – controls whether to return the elements in sorted order
- out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
例子
>>> x = torch.arange(1., 6.) >>> x tensor([ 1., 2., 3., 4., 5.]) >>> torch.topk(x, 3) torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))