超簡單!pytorch入門教程(五):訓練和測試CNN


我們按照超簡單!pytorch入門教程(四):准備圖片數據集准備好了圖片數據以后,就來訓練一下識別這10類圖片的cnn神經網絡吧。

按照超簡單!pytorch入門教程(三):構造一個小型CNN構建好一個神經網絡,唯一不同的地方就是我們這次訓練的是彩色圖片,所以第一層卷積層的輸入應為3個channel。修改完畢如下:


我們准備了訓練集和測試集,並構造了一個CNN。與之前LeNet不同在於conv1的第一個參數1改成了3

現在咱們開始訓練

我們訓練這個網絡必須經過4步:

第一步:將輸入input向前傳播,進行運算后得到輸出output

第二步:將output再輸入loss函數,計算loss值(是個標量)

第三步:將梯度反向傳播到每個參數

第四步:利用下面公式進行權重更新

新權重w =  舊權重w  +  學習速率𝜂  x 梯度向量g

非常幸運,pytorch幫我們寫好了計算loss的函數和優化的函數。

我們先初始化loss和優化函數:

criterion = nn.CrossEntropyLoss() #叉熵損失函數

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  #使用SGD(隨機梯度下降)優化,學習率為0.001,動量為0.9

待會我們就要用到這兩個函數

假設我們需要對訓練數據完全遍歷兩次,人話就是:我們把所有訓練集的數據扔進去進行訓練,但是扔一次怎么夠呢,扔一次並不能保證我的網絡的參數就訓練的很完美了,那么我們就會反復將訓練集的數據扔進去訓練,每次扔的時候,數據的順序是不一樣的。

這里我們就先扔兩次練練。


訓練網絡

先不管running_loss,它是我們待會用來統計loss的平均值的。

我們先看data,data是從trainloader中枚舉出來的,數據的結構看上面注釋。

我們在訓練前,會將網絡中每個參數的grad值清空為0,這樣做是因為grad值是累加的具體參考pytorch學習筆記(二),設置為0后,每次bp后的grad更新后的值才是正確的。

我們將inputs輸入net之后,得到outputs,將outputs和labels輸入之前定義的叉熵函數計算loss值。除了叉熵方式計算外還有其他計算loss的方法

loss算完后,我們就使用backward向后傳播啦!我們稍微想一下傳播會怎么進行,傳播應該會讓每一個網絡參數的grad值進行更新,我們網絡中的每一個參數都是Variable類型,並且均是葉子節點,grad值必然會進行更新。

接下來,每個參數利用自身的grad值進行梯度下降法的更新就好了,我們利用先前定義好的optimizer使用step()函數進行更新。其他優化方法

好了!講了這么久,我們將代碼下載下來溜溜,看看是什么情況!下載cnn.py

如果沒錯的話,跑完你應該會看到下圖(loss平均值每次跑都會有變化的,因為咱們的loader設置了shuffle=True):


如圖,我們的訓練數據被我們扔進去了兩遍,而且每2000批數據我們打印一次平均loss值,請注意不斷減小的loss值,證明我們的網絡正在被優化啊!!!!

好了,訓練完之后,我們當然我測試一下我們的網絡的分類的正確率到底是多少

上代碼:


測試部分

關於total值我們可以設為10000,因為我們知道訓練集中的圖片數量就是10000,但是為了泛化,我們還是老老實實的點人頭。一開始我們設置correct和total都為0。

我們要計算正確率,就用

正確數/全部數量

我解釋一下第92行代碼,outputs.data是一個4x10張量,max函數會將每一行的最大的那一列的值和序號各自組成一個一維張量返回,第一個是值的張量,第二個是序號的張量。我想還是舉個例子吧:


隨機生成了4x10的tensor,然后max函數會幫我們挑出每一行最大的那個值,比如第一行第10個,第二行是第9個,第三行是第5個,第四行是第10個。而[9,8,5,9]正是表示這些數的位置(從0開始算)

那么為啥輸出的outputs是個4x10的張量呢,我們試着想一下,假設我們現在輸入的是一張圖片,那么出來的是一個10維的特征向量,因為我們同時輸入了4張,所以就是4x10啦!

第93行,我們的labels是4維的向量,size(0)就是4,即沒次total都加4。

第94行,兩個4維向量逐行對比,相同的行記為1,不同的行記為0,再利用sum(),求各元素總和,得到相同的個數。這個不懂也可以自己命令行上試試:


只有第1和第3個元素相同,使用sum的話則會等於2。

完整代碼在這,我們下載后跑一下:


正確率為52%

童鞋們可以看到,訓練2遍結果是不夠好的,才52%!大家可以回去把循環次數2改成10試試。

pytorch入門教程就到這里吧😄,Zen君繼續操練去了。



作者:Zen_君
鏈接:http://www.jianshu.com/p/e4c7b3eb8f3d
來源:簡書
著作權歸作者所有。商業轉載請聯系作者獲得授權,非商業轉載請注明出處。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM