Pytorch取最小或最大的張量索引


Pytorch中根據索引取張量有很多方法,比如index_select和masked_select,和gt,ge等配合食用,但如果需要取出最小幾個或最大幾個張量的索引,則需要動手寫一下

a = torch.tensor([2,3,1,5])
y,_ = torch.sort(a)
mask = a.gt(y[0])
index = []
mask_list = (mask== False).nonzero()
index = [int(i) for i in mask_list]
index

>>>[2]

先對張量做排序,然后求出原張量大於和小於等 最小值y[0]的掩碼,大於為True,小於等於為False,然后用nonezero()方法就可以求出掩碼中False的索引,done


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM