torch.max()函數


一、_, predicted = torch.max(outputs.data, 1)

那么,這里的 下划線_ 表示什么意思?

首先,torch.max()這個函數返回的是兩個值,第一個值是具體的value(我們用下划線_表示),第二個值是value所在的index(也就是predicted)。

那么,這個 下划線_ 表示的就是具體的value,也就是輸出的最大值。那么為什么用 下划線_,可不可以用其他的變量名稱來代替,比如x?答案自然是可以的。

那么為什么這里選擇用這么特殊的下划線?有沒有特殊含義?這是因為我們不關心最大值是什么,而關心最大值對應的index是什么,所以選用下划線代表不需要用到的變量。比如在圖像分類任務中,值所對應的index就對應着相應的類別class,當我們只關心網絡預測的類別是什么,而不關心該類別的預測概率是多少時,就選擇使用下划線_。

二、這里的數字1表示什么意思?

數字1其實可以寫為dim=1,這里簡寫為1,python也可以自動識別,dim=1表示輸出所在行的最大值,若改寫成dim=0則輸出所在列的最大值。比如說測試集有10個數據,那么訓練好的網絡將會預測這10個數據,得到一個10×2的矩陣(假設是二分類問題,二分類只輸出兩個類別,所以是兩列),比如說預測結果是下面這個矩陣。這里的數字就是,網絡預測為對應類別的概率,而行代表樣本、列代表類別,所以這里應該用dim=1,因為你需要輸出的每個樣本的預測類別。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM