該案例主要目的是為了熟悉Keras基本用法,以及了解DNN基本流程。
示例代碼:
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.datasets import mnist
from keras.layers import Dense
from keras.utils.np_utils import to_categorical
#加載數據,訓練60000條,測試10000條,X_train.shape=(60000,28,28)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
#特征扁平化,縮放,標簽獨熱
X_train_flat = X_train.reshape(60000, 28*28)
X_test_flat = X_test.reshape(10000, 28*28)
X_train_norm = X_train_flat / 255
X_test_norm = X_test_flat / 255
y_train_onehot = to_categorical(y_train, 10) #shape為(60000,10)
y_test_onehot = to_categorical(y_test, 10) #shape為(10000,10)
#構建模型
model = Sequential()
model.add(Dense(100, activation='relu', input_shape=(28*28,)))
model.add(Dense(50, activation='relu'))
model.add(Dense(10, activation='softmax'))
#模型配置和訓練
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train_norm, y_train_onehot, epochs=5, batch_size=32, verbose=1)
print("訓練完畢!")
訓練結果為:
繼續在測試集上評估模型。
#測試集上評估表現
score = model.evaluate(X_test_norm, y_test_onehot)
print("在測試集上評估完畢!")
print("在測試集上表現:Loss={:.4f}, Accuracy={:.4f}".format(score[0], score[1]))
#在測試集上預測
y_pred_class = model.predict_classes(X_test_norm) #shape=(10000,)
print("預測完畢!")
#查看預測效果,隨機查看多張圖片
idx = 22 #隨機設置
count = 0
fig1 = plt.figure(figsize = (10,7))
for i in range(3):
for j in range(5):
count += 1
ax = plt.subplot(3,5,count)
plt.imshow(X_test[idx+count])
ax.set_title("predict:{} label:{}".format(y_pred_class[idx+count],
y_test[idx+count]))
fig1.savefig('images/look.jpg')
運行結果為:
為了了解模型預測錯誤原因,可查看預測錯誤的圖片。
#找出錯誤所在
X_test_err = X_test[y_test!=y_pred_class] #(num_errors, 28, 28)
y_test_err = y_test[y_test!=y_pred_class] #(num_errors,)
y_pred_class_err = y_pred_class[y_test!=y_pred_class]
#連續查看多張錯誤圖片
idx = -1
count = 0
fig2 = plt.figure(figsize = (10,7))
for i in range(3):
for j in range(5):
count += 1
ax = plt.subplot(3,5,count)
plt.imshow(X_test_err[idx+count])
ax.set_title("predict:{} label:{}".format(y_pred_class_err[idx+count],
y_test_err[idx+count]))
fig2.savefig('images/errors.jpg')
運行結果為: