一、_, predicted = torch.max(outputs.data, 1)
那么,這里的 下划線_ 表示什么意思?
首先,torch.max()這個函數返回的是兩個值,第一個值是具體的value(我們用下划線_表示),第二個值是value所在的index(也就是predicted)。
那么,這個 下划線_ 表示的就是具體的value,也就是輸出的最大值。那么為什么用 下划線_,可不可以用其他的變量名稱來代替,比如x?答案自然是可以的。
那么為什么這里選擇用這么特殊的下划線?有沒有特殊含義?這是因為我們不關心最大值是什么,而關心最大值對應的index是什么,所以選用下划線代表不需要用到的變量。比如在圖像分類任務中,值所對應的index就對應着相應的類別class,當我們只關心網絡預測的類別是什么,而不關心該類別的預測概率是多少時,就選擇使用下划線_。
二、這里的數字1表示什么意思?
數字1其實可以寫為dim=1,這里簡寫為1,python也可以自動識別,dim=1表示輸出所在行的最大值,若改寫成dim=0則輸出所在列的最大值。比如說測試集有10個數據,那么訓練好的網絡將會預測這10個數據,得到一個10×2的矩陣(假設是二分類問題,二分類只輸出兩個類別,所以是兩列),比如說預測結果是下面這個矩陣。這里的數字就是,網絡預測為對應類別的概率,而行代表樣本、列代表類別,所以這里應該用dim=1
,因為你需要輸出的每個樣本的預測類別。