對於計算正確率時 logits.argmax(dim=1),torch.eq(pre_label,label)


額  好像是一句非常簡單的代碼 ,但是作為新手 ,我是完全看不懂哎 前十眼。

  首先 這里的logits是一個 (a,b)維的張量。其中a是你的全連接輸出維度,b是一個batch中的樣本數量。  

  我們經過一個argmax的操作,dim=1 意味着找到張量中各自的最大值所在索引。也就是找到每個樣本的全連接輸出中最大的那一個。 最有可能的預測值。

   torch.eq會返回一個 batch維的bool值。 。sum統計真值的個數。。float將真值個數化為浮點,。item將得到的浮點數從列表中取出來。

   

 

代碼舉例:

A = torch.tensor([[0,1], [0,2], [0,3], [0,5], [0,42], [0,5], [0,0], [1,19]])
# A = torch.tensor([1,2,3,4,5,6,7,8])
B = A.argmax(dim=1)
C=torch.tensor([1,1,1,1,1,1,1,1])
print(B)
print(C)
d=torch.eq(B,C)
print(d)
d=d.sum()
print(d)
d=d.float()
print(d)
d=d.item()
d=torch.eq(B,C).sum().float().item()
print(d)

  

結果:

tensor([1, 1, 1, 1, 1, 1, 0, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([ True,  True,  True,  True,  True,  True, False,  True])
tensor(7)
tensor(7.)
7.0

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM