實踐:用keras訓練一個MLP用於MNIST數據集
總體思路:
- 數據預處理
- 構建神經網絡模型
- 訓練模型
- 驗證模型的泛化能力,並調節超參數
1.數據預處理
- keras自帶了MNIST數據集的例子,因此使用mnist.load_data讀取數據集。
- x代表數據,y代表標簽,數據的大小為(數據量,28,28)。由於我們采用多層感知機,因此輸入數據應該是1維,即28*28=784,因此首先將數據維度變換為784
- 將數據標准化,每個像素均除以255
- 原數據中標簽為0-9,需要轉化為1-hot標簽。keras提供了一個轉化標簽的函數to_categorical,將標簽轉化為1-hot標簽
- 分配訓練集與驗證集
import keras import numpy as np from keras.datasets import mnist from keras.utils import to_categorical (x_train,y_train), (x_test, y_test) = mnist.load_data() x_train = x_train.reshape(-1, 784) x_train = x_train / 255 y_train = to_categorical(y_train) x_test = x_test.reshape(-1, 784) x_test = x_test / 255 y_test = to_categorical(y_test)
2.構建神經網絡模型
Keras提供了兩種建立神經網絡的模型:
- Sequential序貫式:適用於沒有分支結構的神經網絡
- Functional函數式:適用於結構復雜的模型
2.1Sequential方法構建模型
- 導入模塊keras.models.Sequential
- 建立空模型 model=Sequential()
- 逐層加入神經網絡模塊。第一層必須指定輸入的尺寸。
from keras.models import Sequential from keras.layers import Dense model = Sequential() model.add(Dense(300, activation='sigmoid', input_shape=(784,))) model.add(Dense(100, activation='sigmoid')) model.add(Dense(10,activation='softmax'))
2.2 函數式方法
- 導入keras.models.Model模塊
- 定義輸入層並指定大小(Input層)
- 逐個添加層,每層之后加入(x)代表與x指代的層相連。
- 使用Model函數指定模型的輸入和輸出
from keras.models import Model from keras.layers import Input, Dense input_layer = Input((784,)) x = Dense(300, activation='sigmoid')(input_layer) x = Dense(100, activation='sigmoid')(x) x = Dense(10, activation='softmax')(x) model = Model(input_layer, x)
3.模型的訓練
在開始訓練之前,要指定模型訓練的損失函數和參數的更新方法,利用complie方法將其整合到模型中。
- loss:指定損失函數。keras內置了很多損失函數,可以使用字符串形式指定。這里我們是多分類模型,因此使用交叉熵作為損失函數。
- optimizer:指定梯度優化算法。這里我們使用隨機梯度下降法sgd
- metrics:可選參數,與訓練過程無關,僅僅用於訓練中觀察訓練情況。這里我們使用准確率
- 使用fit方法進行訓練。Fit方法需要給定數據、標簽、每批數據大小、訓練輪數等參數
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['acc']) model.fit(x_train, y_train, batch_size=32, epochs=10)
4.模型驗證
- 當我們在測試集中評價我們模型的泛化能力時,使用evaluate,指定數據與真實標簽,輸出在測試集上的損失函數loss值和metric值(准確率)。
- 當我們沒有真實標簽,僅僅用於預測時,使用predict,指定數據。
- 可見,僅僅經過10輪訓練,我們的准確率就達到了90.35%.
- 存儲模型:save(文件名) model.save('my_first_dnn.h5')
- 讀取模型:load_model(文件名)
在keras框架下運行的第一個小程序完成~~~