(predicted == labels).sum().item()作用


 ⚠️(predicted == labels).sum().item()作用,舉個小例子介紹:

# -*- coding: utf-8 -*-
import torch import numpy as np data1 = np.array([ [1,2,3], [2,3,4] ]) data1_torch = torch.from_numpy(data1) data2 = np.array([ [1,2,3], [2,3,4] ]) data2_torch = torch.from_numpy(data2) p = (data1_torch == data2_torch) #對比后相同的值會為1,不同則會為0 print p print type(p) d1 = p.sum() #將所有的值相加,得到的仍是tensor類別的int值 print d1 print type(d1) d2 = d1.item() #轉成python數字 print d2 print type(d2)

返回:

(deeplearning2) userdeMBP:pytorch user$ python test.py
tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.uint8) <class 'torch.Tensor'> tensor(6) <class 'torch.Tensor'> 6 <type 'int'>

 

即如果有不同的話,會變成:

# -*- coding: utf-8 -*-
import torch import numpy
as np data1 = np.array([ [1,2,3], [2,3,4] ]) data1_torch = torch.from_numpy(data1) data2 = np.array([ [1,2,3], [4,5,6] ]) data2_torch = torch.from_numpy(data2) p = (data1_torch == data2_torch) print p print type(p) d1 = p.sum() print d1 print type(d1) d2 = d1.item() print d2 print type(d2)

返回:

(deeplearning2) userdeMBP:pytorch user$ python test.py
tensor([[1, 1, 1],
        [0, 0, 0]], dtype=torch.uint8)
<class 'torch.Tensor'>
tensor(3)
<class 'torch.Tensor'>
3
<type 'int'>

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM