額 好像是一句非常簡單的代碼 ,但是作為新手 ,我是完全看不懂哎 前十眼。
首先 這里的logits是一個 (a,b)維的張量。其中a是你的全連接輸出維度,b是一個batch中的樣本數量。
我們經過一個argmax的操作,dim=1 意味着找到張量中各自的最大值所在索引。也就是找到每個樣本的全連接輸出中最大的那一個。 最有可能的預測值。
torch.eq會返回一個 batch維的bool值。 。sum統計真值的個數。。float將真值個數化為浮點,。item將得到的浮點數從列表中取出來。
代碼舉例:
結果: