在SSD的代碼中經常有見到如下的操作:
_, idx = flt[:, :, 0].sort(1, descending=True)#大小為[batch size, num_classes*top_k] _, rank = idx.sort(1)#再對索引升序排列,得到其索引作為排名rank
其作用是什么呢?舉個例子:
import torch a = torch.randn(3,4) print(a) print() i, idx = a.sort(dim=1, descending=True) print(i) print(idx) print() j, rank = idx.sort(dim=1) print(rank)
返回:
tensor([[ 2.3326, 0.0275, -0.0799, 0.4156], [-2.2066, 1.7997, -2.2767, 0.4704], [-0.6980, 0.2285, 1.0018, -0.8874]]) tensor([[ 2.3326, 0.4156, 0.0275, -0.0799], [ 1.7997, 0.4704, -2.2066, -2.2767], [ 1.0018, 0.2285, -0.6980, -0.8874]]) tensor([[0, 3, 1, 2], [1, 3, 0, 2], [2, 1, 0, 3]]) tensor([[0, 2, 3, 1], [2, 0, 3, 1], [2, 1, 0, 3]])
其實就是可以通過最后的rank看出對應位置的值的排序位置,這里是降序,所以索引0表示最高
以rank的第一行第一列的值0為例,其表示對應的a中第一行第一列的值2.3326是第一行中最大的,因為設置了dim=1;第一行第三列的值3為例,其表示對應的a中第一行第三列的值-0.0799是第一行中最小的
起到了對應位置排名的作用
