Pytorch取最小或最大的张量索引


Pytorch中根据索引取张量有很多方法,比如index_select和masked_select,和gt,ge等配合食用,但如果需要取出最小几个或最大几个张量的索引,则需要动手写一下

a = torch.tensor([2,3,1,5])
y,_ = torch.sort(a)
mask = a.gt(y[0])
index = []
mask_list = (mask== False).nonzero()
index = [int(i) for i in mask_list]
index

>>>[2]

先对张量做排序,然后求出原张量大于和小于等 最小值y[0]的掩码,大于为True,小于等于为False,然后用nonezero()方法就可以求出掩码中False的索引,done


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM