一、預處理部分
1.拿到數據首先對數據進行分析
對數據的分布有一個大致的了解,可以用畫圖函數查看所有類的分布情況。可以采取刪除不合理類的方法來提高准確率;
對圖像進行分析,在自定義的圖像增強的多種方式中,嘗試對圖像進行變換,看是否存在主觀上的特征增強,具體的增強
方法在aug.py文件中,可以在線下對數據進行測試,看是否在增強后對結果有好的影響。
2.模型的選取
依據新模型效果較好的原則,盡量選取已存在的最新模型,可以選取進幾年再imagenet比賽上取得最好的效果的幾種模型
分別進行測試,目前效果最好的模型是resnet(深度殘差網絡),是卷積神經網絡的最新發展;
但僅僅單模型的效果肯定是不如多模型綜合的效果好的,所以可以選取效果較好的幾種模型,最后按其權重進行加權平均
來獲取最終的預測結果;
始終要注意的一點是,模型是次要的,最主要最核心的問題還是在於對於數據的處理。
3.處理數據
對數據圖像進行增強,不管是使用pytorch自帶的transform模塊,還是自定義的數據增強處理方式,都要對數據進行合理的
改變,最基本的改變是對圖像進行簡單的隨機翻轉、切割、旋轉等,還有要注意的一點是需要改變圖像的尺寸,以適應模型
的輸入要求。
本次比賽數據進行的增強方式有:
- RandomRotation(30)
- RandomHorizontalFlip()
- RandomVerticalFlip()
- RandomAffine(45)
4.超參數的設置
對於整體代碼中所需要的超參數進行單獨處理,設置在一個文件中,使用時候直接調用即可。
二、輸入數據進入模型進行訓練
1.划分數據集
首先根據所給文件把每個類的圖像都分類到各自的文件夾中去,模型的輸入要求類型基本都是這樣,然后對於數據集划分為
訓練集、測試集、驗證集,分別在模型的訓練、測試階段使用。
2.模型訓練
根據pytorch的模型訓練過程,輸入訓練集,對模型進行訓練,每個epoch后對模型進行評價,在整個epoch結束后,得到最好
的模型。
3.測試階段
把測試集輸入保存的最好模型中去,得到輸出結果,進行分析。
三、pytorch中的訓練模塊化
1.加載模型
2.優化器和loss函數的設置
3.訓練集加載入pytorch的數據加載類Dataloader中,以便於調用
4.開始每個epoch的訓練,輸入,目標,loss,歸零,反向傳播,開始
5.評估模型,得出最優模型
參考大神chaojiezhu的github。
https://github.com/spytensor/plants_disease_detection