莫煩大大keras的Mnist手寫識別(5)----自編碼


一、步驟:

  1. 導入包和讀取數據

  2. 數據預處理

  3. 編碼層和解碼層的建立 + 構建模型

  4. 編譯模型

  5. 訓練模型

  6. 測試模型【只用編碼層來畫圖】

二、代碼:

1、導入包和讀取數據

#導入相關的包
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.datasets import mnist
from keras.models import Model #采用通用模型
from keras.layers import Dense, Input #只用到全連接層
import matplotlib.pyplot as plt

#讀取數據
(X_train, _), (X_test, y_test) = mnist.load_data()

2、數據預處理:將28*28維度的數據拉成一個向量784,原數據X_train的shape為(60000,28,28),轉成x_train(60000,784)。

x_train = X_train.astype('float32') / 255. - 0.5       # minmax_normalized

x_test = X_test.astype('float32') / 255. - 0.5         # minmax_normalized

x_train = X_train.reshape((x_train.shape[0], -1))

x_test = X_test.reshape((x_test.shape[0], -1))

print(x_train.shape) #(60000, 784)
print(x_test.shape) #(10000, 784)
print(X_train.shape)  # (60000, 28, 28)

3、編碼層和解碼層的建立+構建模型

# in order to plot in a 2D figure
encoding_dim = 2

# this is our input placeholder
input_img = Input(shape=(784,))


# encoder layers編碼層
encoded = Dense(128, activation='relu')(input_img)
encoded = Dense(64, activation='relu')(encoded)
encoded = Dense(10, activation='relu')(encoded)
encoder_output = Dense(encoding_dim)(encoded)

# decoder layers解碼層
decoded = Dense(10, activation='relu')(encoder_output)
decoded = Dense(64, activation='relu')(decoded)
decoded = Dense(128, activation='relu')(decoded)
decoded = Dense(784, activation='tanh')(decoded)

#構建模型
#包括編碼層也包括解碼層
autoencoder = Model(input = input_img,output = decoded)
#只包括編碼層
encoder = Model(input = input_img,output = encoder_output)

4、編譯模型

#編譯模型
autoencoder.compile(optimizer='adam', loss='mse')

5、訓練模型【編碼和解碼一起訓練】

autoencoder.fit(x_train, x_train,
                epochs=20,
                batch_size=256,
                shuffle=True)

6、測試模型並畫圖顯示【僅用編碼來預測2維的特征空間】

encoded_imgs = encoder.predict(x_test)
plt.scatter(encoded_imgs[:, 0], encoded_imgs[:, 1], c=y_test) #c表示顏色維度
plt.colorbar()
plt.show()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM