pytorch實現了LeNet網絡,應用CIFAR-10數據集圖片分類


項目介紹

本項目采用pytorch實現了LeNet網絡,應用CIFAR-10數據集圖片分類。
代碼:https://github.com/pxlsdz/pytorch-LeNet-CIFAR-10

數據集介紹

CIFAR-10 數據集由 10 個類別的 60000 張 32x32 彩色圖像組成,每個類別有 6000 張圖像。 有 50000 張訓練圖像和 10000 張測試圖像。10個類別分別為:'plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck'。

下載地址: https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz,或直接允許本項目代碼就可以下載。

代碼邏輯

1.使用torchvision加載並預處理CIFAR-10數據集
2.定義LeNet網絡
3.定義損失函數和優化器
4.訓練網絡並更新網絡參數
5.測試網絡

代碼運行

本文的代碼類型為ipynb文件,可在本地的JupyterNotebook或者在雲端Notebook實驗平台運行如google cloabMo,本項目的google cloab地址 https://drive.google.com/file/d/1rHo7MfPAFrdx9L13e0SIttkPV5KL8T5k/view?usp=sharing

Requirement

  • python3.6+
  • numpy
  • pytorch1.8.1+cuda
  • torchvision 0.9.0
  • matplotlib

超參數分析示例

本項目選用學習率為實驗對象,設置不同的學習率:[0.1, 0.05, 0.01](因為資源和時間限制),探究不同學習率對模型的影響。不同學習率的迭代次數增加的損失率如下圖:

下載 (1)

隨着迭代次數的增加,模型的損失一直在減少,說明模型已經學習到一定內容。從圖中可見,當學習率lr==0.01的時候,即圖中的綠色的曲線,收斂的速度越慢,但是損失最小,准確度最高;當學習率lr ==0.1的時候,即圖中的藍色的曲線,收斂的速度最快,但損失最大,准確度最低;驗證了課堂的學習知識,如果學習速率lr太小,梯度下降收斂速度會很慢;如果學習速率lr太大,損失函數的值在每次迭代后不一定能下降,算法最后可能會發散,很難達到最優值。最好的方法是設定自動更新學習率的方法,讓模型自適應地調整學習率,PyTorch已經為我們封裝好了一些在訓練過程中動態調整學習率的方法:optim.lr_scheduler官方文檔


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM