15.手寫數字識別-小數據集(load_digits)


1.手寫數字數據集

  • from sklearn.datasets import load_digits
  • digits = load_digits()

 

 

2.圖片數據預處理

  • x:歸一化MinMaxScaler()
  • y:獨熱編碼OneHotEncoder()或to_categorical

 

  • 訓練集測試集划分
  • 張量結構

3.設計卷積神經網絡結構

  • 繪制模型結構圖,並說明設計依據。

先導入相關的包

然后設計模型結構,因為圖片是(8,8)的像素規模,在池化方面,就最多能池化3次

在卷積上,可以多次卷積,我選擇的(3,3)的卷積核,然后卷積的方式是same,即卷積完成后,得到和卷積前同樣大小的結果。

 構建模型結果:

4.模型訓練

  • model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
  • train_history = model.fit(x=X_train,y=y_train,validation_split=0.2, batch_size=300,epochs=10,verbose=2)

5.模型評價

  • model.evaluate()
  • 交叉表與交叉矩陣
  • pandas.crosstab
  • seaborn.heatmap

 觀察預測值和實際值

 通過交叉矩陣觀察預測值與實際值的符合情況

  通過熱力圖觀察預測值與實際值的符合程度

 

小結:模型構建完成后,得到的精確率在96%左右,說明模型的構建還是有一定的不足,應該更加優化卷積和池化的安排,嘗試不同的卷積核得到的模型

從熱力圖看得出來,在預測數字8的時候容易和數字1混淆,如果是實際應用中則需要加多數字8的樣本,讓模型能夠更好的學習到數字8的特征。也能進行更好的預測。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM