1.導入MNIST數據集
直接使用fetch_mldata會報錯,錯誤信息是python3.7把fetch_mldata方法移除了,所以需要單獨下載數據集
從這個網站上下載數據集:
https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat
使用如下方法獲取路徑:
from sklearn.datasets.base import get_data_home print (get_data_home()) # 如我的電腦上的目錄為: C:\Users\Mr.Wmn\scikit_learn_data\mldata #下載好mnist-original.mat數據集放到獲取的路徑里,在輸入如下內容便不會報錯了 #導入數據集拆分工具 from sklearn.model_selection import train_test_split #導入數據集獲取工具 from sklearn.datasets import fetch_mldata #導入MLP神經網絡 from sklearn.neural_network import MLPClassifier #導入numpy import numpy as np #加載MNIST手寫數字數據集 mnist = fetch_mldata('MNIST original') mnist
{'DESCR': 'mldata.org dataset: mnist-original', 'COL_NAMES': ['label', 'data'], 'target': array([0., 0., 0., ..., 9., 9., 9.]), 'data': array([[0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], ..., [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}
print('\n\n\n') print('代碼運行結果') print('====================================\n') #打印樣本數量和樣本特征數 print('樣本數量:{} 樣本特征數:{}'.format(mnist.data.shape[0],mnist.data.shape[1])) print('\n====================================') print('\n\n\n')
代碼運行結果 ==================================== 樣本數量:70000 樣本特征數:784 ====================================
#建立訓練數據集和測試數據集 X = mnist.data/255 y = mnist.target X_train,X_test,y_train,y_test = train_test_split(X,y,train_size=1000,random_state=62) print(X_train,X_test,y_train,y_test)
[[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]] [[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]] [5. 7. 7. 0. 0. 4. 4. 5. 1. 2. 8. 7. 4. 8. 2. 3. 9. 7. 2. 5. 9. 7. 9. 6. 7. 1. 1. 3. 2. 6. 9. 4. 3. 1. 8. 3. 2. 0. 8. 0. 9. 7. 2. 9. 7. 9. 5. 1. 1. 0. 8. 2. 5. 6. 5. 2. 2. 8. 1. 6. 0. 5. 9. 9. 6. 5. 4. 3. 1. 7. 5. 9. 2. 4. 2. 3. 2. 2. 8. 1. 8. 9. 1. 3. 3. 4. 7. 7. 9. 8. 5. 0. 3. 0. 1. 0. 7. 5. 8. 2. 5. 8. 3. 7. 0. 1. 9. 4. 6. 7. 0. 7. 4. 0. 6. 0. 3. 4. 0. 6. 9. 0. 6. 1. 7. 0. 3. 9. 7. 3. 8. 0. 8. 8. 5. 6. 1. 2. 3. 9. 1. 9. 2. 1. 3. 4. 9. 1. 0. 8. 6. 3. 1. 8. 5. 0. 2. 0. 6. 5. 6. 3. 1. 3. 1. 1. 1. 5. 6. 1. 9. 2. 8. 5. 7. 8. 5. 9. 1. 2. 8. 1. 0. 9. 4. 2. 4. 2. 1. 8. 8. 8. 1. 2. 9. 7. 8. 4. 1. 6. 8. 6. 7. 8. 7. 2. 5. 2. 5. 8. 1. 4. 3. 0. 6. 3. 5. 8. 6. 9. 4. 9. 1. 6. 6. 6. 3. 0. 9. 2. 4. 2. 2. 4. 3. 9. 2. 5. 8. 4. 1. 0. 8. 1. 8. 5. 3. 1. 8. 7. 7. 9. 6. 9. 3. 7. 8. 9. 1. 1. 5. 8. 7. 7. 0. 0. 5. 7. 6. 2. 5. 4. 4. 4. 3. 0. 5. 0. 8. 1. 8. 5. 2. 6. 2. 9. 9. 2. 8. 3. 8. 0. 9. 4. 8. 5. 7. 7. 0. 1. 7. 2. 4. 2. 5. 3. 7. 0. 4. 4. 9. 1. 9. 0. 4. 3. 7. 4. 5. 8. 0. 8. 1. 2. 9. 2. 1. 3. 3. 0. 5. 3. 2. 8. 7. 6. 6. 6. 4. 3. 1. 8. 4. 0. 3. 9. 2. 1. 7. 0. 5. 2. 5. 4. 3. 5. 6. 9. 4. 7. 4. 4. 2. 4. 1. 1. 3. 1. 3. 8. 5. 7. 5. 0. 1. 8. 2. 8. 2. 2. 8. 7. 1. 6. 7. 7. 1. 7. 4. 0. 5. 1. 8. 5. 1. 9. 5. 6. 8. 6. 1. 7. 9. 0. 9. 2. 3. 6. 2. 4. 8. 2. 6. 8. 1. 9. 5. 0. 7. 5. 8. 2. 0. 5. 4. 3. 1. 8. 8. 7. 8. 9. 6. 6. 0. 9. 3. 9. 8. 9. 0. 5. 0. 6. 0. 1. 9. 3. 0. 3. 9. 8. 0. 6. 5. 3. 4. 8. 5. 3. 9. 5. 8. 4. 3. 7. 1. 4. 8. 9. 6. 4. 9. 1. 0. 3. 2. 8. 4. 1. 4. 9. 7. 5. 8. 2. 6. 0. 2. 8. 1. 2. 6. 6. 0. 8. 4. 7. 5. 7. 8. 9. 0. 0. 9. 2. 6. 4. 0. 3. 6. 9. 1. 1. 1. 9. 8. 4. 1. 6. 5. 4. 1. 3. 0. 0. 4. 7. 7. 3. 5. 3. 6. 0. 1. 9. 6. 3. 2. 2. 5. 9. 2. 7. 5. 1. 0. 1. 3. 9. 0. 4. 3. 6. 7. 5. 7. 5. 9. 3. 3. 4. 8. 1. 8. 0. 1. 2. 9. 8. 3. 6. 3. 0. 7. 1. 3. 2. 2. 9. 7. 8. 0. 6. 5. 6. 1. 5. 3. 4. 5. 1. 9. 4. 9. 6. 6. 7. 4. 2. 4. 5. 8. 9. 1. 9. 6. 7. 7. 0. 0. 6. 9. 0. 6. 0. 5. 2. 8. 9. 8. 1. 4. 3. 0. 6. 5. 4. 6. 9. 7. 1. 9. 0. 5. 4. 7. 6. 0. 5. 0. 3. 0. 0. 0. 1. 4. 0. 7. 5. 0. 9. 5. 3. 4. 9. 9. 7. 6. 0. 6. 1. 3. 5. 7. 5. 2. 9. 6. 0. 5. 7. 8. 5. 0. 9. 2. 8. 1. 7. 0. 8. 7. 7. 7. 7. 5. 5. 5. 7. 2. 1. 9. 9. 7. 2. 9. 4. 0. 4. 8. 3. 3. 4. 9. 3. 7. 0. 2. 9. 8. 8. 7. 5. 8. 7. 9. 0. 6. 9. 8. 6. 1. 7. 2. 8. 9. 8. 2. 4. 7. 1. 8. 8. 7. 3. 1. 8. 8. 9. 7. 4. 7. 7. 1. 1. 8. 8. 2. 7. 1. 7. 6. 3. 7. 6. 5. 2. 3. 7. 2. 0. 7. 3. 9. 8. 0. 0. 5. 4. 4. 2. 9. 9. 5. 8. 4. 7. 4. 8. 5. 8. 3. 1. 7. 4. 8. 9. 2. 3. 8. 7. 3. 2. 7. 2. 6. 7. 7. 9. 1. 0. 8. 6. 6. 9. 4. 7. 9. 3. 1. 4. 6. 0. 8. 6. 5. 1. 8. 2. 1. 5. 3. 7. 1. 2. 5. 4. 6. 2. 6. 3. 2. 1. 7. 6. 8. 6. 3. 8. 3. 9. 0. 4. 2. 2. 8. 8. 3. 7. 8. 4. 3. 5. 3. 2. 2. 8. 0. 1. 0. 9. 4. 3. 6. 1. 6. 1. 2. 3. 3. 4. 0. 0. 7. 2. 6. 0. 3. 7. 2. 6. 4. 6. 6. 3. 9. 5. 6. 5. 8. 1. 3. 7. 8. 8. 9. 0. 1. 9. 3. 4. 1. 4. 1. 1. 7. 4. 8. 5. 8. 1. 3. 6. 3. 8. 5. 9. 0. 6. 4. 8. 0. 3. 3. 9. 1. 0. 4. 1. 3. 4. 6. 4. 9. 2. 8. 4. 0. 5. 3. 9. 7. 0. 9. 7. 8. 6. 7. 7. 6. 8. 9. 5. 1. 1. 7. 4. 5. 9. 8. 5. 1. 1. 7. 3. 1. 9. 9. 9. 3. 8. 2. 9. 7. 1. 7. 1. 1. 4. 3. 1. 1. 3. 0. 1. 3. 3. 4. 5. 8. 1. 2. 0. 4. 6. 7. 1. 2. 1.] [8. 2. 5. ... 3. 6. 9.]
2.訓練MLP神經網絡
#設置神經網絡有兩個100個節點的隱藏層 mlp_hw = MLPClassifier(solver='lbfgs',hidden_layer_sizes=[100,100],activation='relu',alpha = 1e-5,random_state=62) #使用數據訓練神經網絡模型 mlp_hw.fit(X_train,y_train) print('\n\n\n') print('代碼運行結果') print('====================================\n') #打印模型分數 print('測試數據集得分:{:.2f}% '.format(mlp_hw.score(X_test,y_test)*100)) print('\n====================================') print('\n\n\n')
代碼運行結果 ==================================== 測試數據集得分:85.79% ====================================
3.使用模型進行數字識別
#導入圖像處理工具 from PIL import Image #打開圖像 image = Image.open("4.png").convert('F') #調整圖像的大小 image = image.resize((28,28)) arr = [] #將圖像中的像素作為預測數據點的特征 for i in range(28): for j in range(28): pixel = 1.0 - float(image.getpixel((j,i)))/255. arr.append(pixel) #由於只有一個樣本,所以需要進行reshape操作 arr1 = np.array(arr).reshape(1,-1) #進行圖像識別 print('圖片中的數字是:{:.0f}'.format(mlp_hw.predict(arr1)[0]))
圖片中的數字是:5
總結 :
scikit-learn中的MLP分類和回歸在易用性表現不錯,但是僅限於處理小數據集,對於更龐大或更復雜的數據集就不怎么友好了.
在計算能力充足並且參數設置合適的情況下,神經網絡可以比其他的機器學習算法表現更加優異
其缺點也很明顯,如模型訓練時間相對更長,對數據預處理的要求較高.
文章引自 : 《深入淺出python機器學習》