对于计算正确率时 logits.argmax(dim=1),torch.eq(pre_label,label)


额  好像是一句非常简单的代码 ,但是作为新手 ,我是完全看不懂哎 前十眼。

  首先 这里的logits是一个 (a,b)维的张量。其中a是你的全连接输出维度,b是一个batch中的样本数量。  

  我们经过一个argmax的操作,dim=1 意味着找到张量中各自的最大值所在索引。也就是找到每个样本的全连接输出中最大的那一个。 最有可能的预测值。

   torch.eq会返回一个 batch维的bool值。 。sum统计真值的个数。。float将真值个数化为浮点,。item将得到的浮点数从列表中取出来。

   

 

代码举例:

A = torch.tensor([[0,1], [0,2], [0,3], [0,5], [0,42], [0,5], [0,0], [1,19]])
# A = torch.tensor([1,2,3,4,5,6,7,8])
B = A.argmax(dim=1)
C=torch.tensor([1,1,1,1,1,1,1,1])
print(B)
print(C)
d=torch.eq(B,C)
print(d)
d=d.sum()
print(d)
d=d.float()
print(d)
d=d.item()
d=torch.eq(B,C).sum().float().item()
print(d)

  

结果:

tensor([1, 1, 1, 1, 1, 1, 0, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([ True,  True,  True,  True,  True,  True, False,  True])
tensor(7)
tensor(7.)
7.0

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM