TensorFlow自編碼器(AutoEncoder)之MNIST實踐


 

自編碼器可以用於降維；若在輸入中加入噪音進行訓練，還可以獲得去噪的效果。

以下使用單隱層自編碼器訓練MNIST數據集，並且解碼器與編碼器共享對稱的權重參數（解碼器直接使用編碼器權重矩陣的轉置）。

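模型本身不難，調試過程中有幾個需要注意的地方：

  • 模型對權重參數的初始值很敏感，所以這裡把權重w的初始值限制在均勻分佈 $\left[-\sqrt{6/(n_{in}+n_{out})},\ \sqrt{6/(n_{in}+n_{out})}\right]$ 之內（即Xavier/Glorot均勻初始化）
  • 需要對數據進行標準化
  • 學習率要設置合理（Adam，0.001）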

1,建立模型

import numpy as np
import tensorflow as tf

class AutoEncoder(object):
    '''
    Single-hidden-layer autoencoder with tied (symmetric) weights.

    The decoder reuses the encoder weight matrix, transposed, so the trainable
    parameters are W1 (input_shape x h1_size) and the two bias vectors.
    Built on the TensorFlow 1.x graph/session API.

    Args:
        input_shape: number of input features per sample (784 in this post).
        h1_size: number of hidden units.
        lr: Adam learning rate.
    '''
    def __init__(self, input_shape, h1_size, lr):
        # Reset the global default graph so that leftovers from an earlier
        # (possibly failed) construction do not pollute this model.
        # NOTE: this invalidates the graph of any previously built AutoEncoder.
        tf.reset_default_graph()
        with tf.variable_scope('auto_encoder', reuse=tf.AUTO_REUSE):
            self.W1 = self.weights(shape=(input_shape, h1_size), name='h1')
            self.b1 = self.bias(h1_size)
            # Tied weights: the decoder uses the transpose of the encoder
            # matrix directly instead of re-looking the variable up by name.
            self.W2 = tf.transpose(self.W1)
            self.b2 = self.bias(input_shape)
        self.lr = lr
        self.input = tf.placeholder(shape=(None, input_shape),
                                    dtype=tf.float32)
        # softplus is a smooth, ReLU-like activation.
        self.h1_out = tf.nn.softplus(tf.matmul(self.input, self.W1) + self.b1)
        # No output activation: reconstructions are unbounded, which suits the
        # standardized inputs used in this post.
        self.out = tf.matmul(self.h1_out, self.W2) + self.b2
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
        # Scaled *sum* (not mean) of squared errors over the batch, so its
        # magnitude grows with the batch size.
        self.loss = 0.1 * tf.reduce_sum(
            tf.pow(tf.subtract(self.input, self.out), 2))
        self.train_op = self.optimizer.minimize(self.loss)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    @staticmethod
    def _batch_bounds(n_samples, batch_size):
        '''
        Split range(n_samples) into consecutive (start, stop) row ranges.

        Every row is covered exactly once; the last range may be shorter
        than batch_size.

        Raises:
            ValueError: if batch_size is not positive.
        '''
        if batch_size <= 0:
            raise ValueError('batch_size must be positive')
        return [(start, min(start + batch_size, n_samples))
                for start in range(0, n_samples, batch_size)]

    def fit(self, X, epoches=100, batch_size=128, epoches_to_display=10):
        '''
        Train the model with Adam on X in sequential (unshuffled) mini-batches.

        Every row of X is used in each epoch. The final mini-batch may be
        smaller than batch_size, so X may also have fewer rows than batch_size.

        Args:
            X: 2-D array of shape (n_samples, input_shape).
            epoches: number of passes over X.
            batch_size: rows per mini-batch; must be positive.
            epoches_to_display: print the mean per-batch loss every this many
                epochs (starting at epoch 0); a value <= 0 disables printing.

        Raises:
            ValueError: if X has no rows or batch_size is not positive.
        '''
        if X.shape[0] == 0:
            raise ValueError('X must contain at least one sample')
        bounds = self._batch_bounds(X.shape[0], batch_size)
        for i in range(epoches):
            epoch_loss = []
            for start, stop in bounds:
                loss, _ = self.sess.run([self.loss, self.train_op],
                                        feed_dict={self.input: X[start:stop]})
                epoch_loss.append(loss)
            # Guard against modulo-by-zero when display is disabled.
            if epoches_to_display > 0 and i % epoches_to_display == 0:
                print('avg_loss at epoch %d :%f' % (i, np.mean(epoch_loss)))

    def weights(self, shape, name, constant=1):
        '''
        Create (or reuse, under the enclosing variable scope) a weight matrix.

        Values are drawn uniformly from +/- constant * sqrt(6 / (fan_in +
        fan_out)), i.e. Xavier/Glorot uniform for constant=1. The original
        author found this initialization essential: a hand-rolled truncated
        normal init did not train.

        Args:
            shape: (fan_in, fan_out) tuple.
            name: variable name passed to tf.get_variable.
            constant: multiplier on the uniform bound.
        '''
        fan_in = shape[0]
        fan_out = shape[1]
        low = -constant * np.sqrt(6.0 / (fan_in + fan_out))
        high = constant * np.sqrt(6.0 / (fan_in + fan_out))
        init = tf.random_uniform_initializer(minval=low, maxval=high)
        return tf.get_variable(name=name,
                               shape=shape,
                               initializer=init,
                               dtype=tf.float32)

    def bias(self, size):
        '''
        Create a zero-initialized float32 bias vector of length size.

        Uses tf.Variable (not tf.get_variable), so every call creates a new
        variable regardless of the scope's reuse setting.
        '''
        return tf.Variable(tf.constant(0, dtype=tf.float32, shape=[size]))

    def encode(self, X):
        '''Return the hidden-layer activations for the rows of X.'''
        return self.sess.run(self.h1_out, feed_dict={self.input: X})

    def decode(self, h):
        '''
        Map hidden activations h back to input space.

        Works by feeding the intermediate h1_out tensor directly, so h must
        have shape (n, h1_size).
        '''
        return self.sess.run(self.out, feed_dict={self.h1_out: h})

    def reconstruct(self, X):
        '''Encode then decode X, returning its reconstruction.'''
        return self.sess.run(self.out, feed_dict={self.input: X})

    def close(self):
        '''Release the TensorFlow session; the model is unusable afterwards.'''
        self.sess.close()

2,加載數據及預處理

# Dependencies for this cell, grouped up front.
import random

import sklearn.preprocessing as prep
from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Flatten every 28x28 training image into a 784-long row.
X_train = X_train.reshape(-1, 784)

# Keep 10 randomly chosen test images for the visual comparison later.
test_idxs = random.sample(range(X_test.shape[0]), 10)
data_test = X_test[test_idxs].reshape(-1, 784)

# Standardize. The scaler is fitted on the FULL training set rather than on
# the training subset; the author found this matters a lot.
processer = prep.StandardScaler()
processer.fit(X_train)
X_train = processer.transform(X_train)
# NOTE: X_test is rebound here to only the 10 standardized samples above.
X_test = processer.transform(data_test)

# Train on a random subset of 5000 images.
idxs = random.sample(range(X_train.shape[0]), 5000)
data_train = X_train[idxs]

3,訓練

# 784 -> 200 -> 784 tied-weight autoencoder; the learning rate noticeably
# affects the loss, 0.001 with Adam works well here.
model = AutoEncoder(input_shape=784, h1_size=200, lr=0.001)
# 200 epochs are enough for this data.
model.fit(data_train, epoches=200, batch_size=128)

4,測試,可視化對比圖

decoded_test = model.reconstruct(X_test)

# Visual comparison: original images on the top row, reconstructions below.
# The IPython magic `%matplotlib inline` was removed because it is a
# SyntaxError outside a notebook; in a notebook, put it back on its own line
# if your setup does not already render figures inline.
import matplotlib.pyplot as plt

shape = (28, 28)
# Show at most 10 images, fewer if X_test has fewer rows (avoids IndexError).
n_show = min(10, len(X_test))
fig, axes = plt.subplots(2, n_show,
                         figsize=(n_show, 2),
                         squeeze=False,  # keep `axes` 2-D even when n_show == 1
                         subplot_kw={
                             'xticks': [],
                             'yticks': []
                         },
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(n_show):
    axes[0][i].imshow(np.reshape(X_test[i], shape))
    axes[1][i].imshow(np.reshape(decoded_test[i], shape))
plt.show()

結果如下（圖中第一行為隨機選取的10張測試原圖，第二行為自編碼器的重構結果）：

 

 此外，還可以在輸入中加入少量高斯噪音進行訓練（即去噪自編碼器），以增加模型的魯棒性。

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用，本站對版權不負任何法律責任。如果侵犯了您的隱私權益，請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM