Tensorflow BatchNormalization詳解:4_使用tf.nn.batch_normalization函數實現Batch Normalization操作


使用tf.nn.batch_normalization函數實現Batch Normalization操作

覺得有用的話,歡迎一起討論相互學習~

我的微博我的github我的B站

參考文獻
吳恩達deeplearningai課程
課程筆記
Udacity課程


"""

大多數情況下,您將能夠使用高級功能,但有時您可能想要在較低的級別工作。例如,如果您想要實現一個新特性—一些新的內容,那么TensorFlow還沒有包括它的高級實現,
比如LSTM中的批處理規范化——那么您可能需要知道一些事情。

這個版本的網絡的幾乎所有函數都使用tf.nn包進行編寫,並且使用tf.nn.batch_normalization函數進行標准化操作

'fully_connected'函數的實現比使用tf.layers包進行編寫的要復雜得多。然而,如果你瀏覽了Batch_Normalization_Lesson筆記本,事情看起來應該很熟悉。
為了增加批量標准化,我們做了如下工作:
Added the is_training parameter to the function signature so we can pass that information to the batch normalization layer.
1.在函數聲明中添加'is_training'參數,以確保可以向Batch Normalization層中傳遞信息
2.去除函數中bias偏置屬性和激活函數
3.添加gamma, beta, pop_mean, and pop_variance等變量
4.使用tf.cond函數來解決訓練和預測時的使用方法的差異
5.訓練時,我們使用tf.nn.moments函數來計算批數據的均值和方差,然后在迭代過程中更新均值和方差的分布,並且使用tf.nn.batch_normalization做標准化
  注意:一定要使用with tf.control_dependencies...語句結構塊來強迫Tensorflow先更新均值和方差的分布,再使用執行批標准化操作
6.在前向傳播推導時(特指只進行預測,而不對訓練參數進行更新時),我們使用tf.nn.batch_normalization批標准化時其中的均值和方差分布來自於訓練時我們
  使用滑動平均算法估計的值。
7.將標准化后的值通過RelU激活函數求得輸出
8.不懂請參見https://github.com/udacity/deep-learning/blob/master/batch-norm/Batch_Normalization_Lesson.ipynb
  中關於使用tf.nn.batch_normalization實現'fully_connected'函數的操作
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, reshape=False)


def fully_connected(prev_layer, num_units, is_training):
    """
    num_units參數傳遞該層神經元的數量,根據prev_layer參數傳入值作為該層輸入創建全連接神經網絡。

   :param prev_layer: Tensor
        該層神經元輸入
    :param num_units: int
        該層神經元結點個數
    :param is_training: bool or Tensor
        表示該網絡當前是否正在訓練,告知Batch Normalization層是否應該更新或者使用均值或方差的分布信息
    :returns Tensor
        一個新的全連接神經網絡層
    """

    layer = tf.layers.dense(prev_layer, num_units, use_bias=False, activation=None)

    gamma = tf.Variable(tf.ones([num_units]))
    beta = tf.Variable(tf.zeros([num_units]))

    pop_mean = tf.Variable(tf.zeros([num_units]), trainable=False)
    pop_variance = tf.Variable(tf.ones([num_units]), trainable=False)

    epsilon = 1e-3

    def batch_norm_training():
        batch_mean, batch_variance = tf.nn.moments(layer, [0])

        decay = 0.99
        train_mean = tf.assign(pop_mean, pop_mean*decay + batch_mean*(1 - decay))
        train_variance = tf.assign(pop_variance, pop_variance*decay + batch_variance*(1 - decay))

        with tf.control_dependencies([train_mean, train_variance]):
            return tf.nn.batch_normalization(layer, batch_mean, batch_variance, beta, gamma, epsilon)

    def batch_norm_inference():
        return tf.nn.batch_normalization(layer, pop_mean, pop_variance, beta, gamma, epsilon)

    batch_normalized_output = tf.cond(is_training, batch_norm_training, batch_norm_inference)
    return tf.nn.relu(batch_normalized_output)


"""
我們對conv_layer卷積層的改變和我們對fully_connected全連接層的改變幾乎差不多。
然而也有很大的區別,卷積層有多個特征圖並且每個特征圖在輸入圖層上共享權重
所以我們需要確保應該針對每個特征圖而不是卷積層上的每個節點進行Batch Normalization操作

為了實現這一點,我們做了與fully_connected相同的事情,有兩個例外:

1.將gamma、beta、pop_mean和pop_方差的大小設置為feature map(輸出通道)的數量,而不是輸出節點的數量。
2.我們改變傳遞給tf.nn的參數。時刻確保它計算正確維度的均值和方差。
"""


def conv_layer(prev_layer, layer_depth, is_training):
    """
       使用給定的參數作為輸入創建卷積層
        :param prev_layer: Tensor
            傳入該層神經元作為輸入
        :param layer_depth: int
            我們將根據網絡中圖層的深度設置特征圖的步長和數量。
            這不是實踐CNN的好方法,但它可以幫助我們用很少的代碼創建這個示例。
        :param is_training: bool or Tensor
            表示該網絡當前是否正在訓練,告知Batch Normalization層是否應該更新或者使用均值或方差的分布信息
        :returns Tensor
            一個新的卷積層
        """
    strides = 2 if layer_depth%3 == 0 else 1

    in_channels = prev_layer.get_shape().as_list()[3]
    out_channels = layer_depth*4

    weights = tf.Variable(
        tf.truncated_normal([3, 3, in_channels, out_channels], stddev=0.05))

    layer = tf.nn.conv2d(prev_layer, weights, strides=[1, strides, strides, 1], padding='SAME')

    gamma = tf.Variable(tf.ones([out_channels]))
    beta = tf.Variable(tf.zeros([out_channels]))

    pop_mean = tf.Variable(tf.zeros([out_channels]), trainable=False)
    pop_variance = tf.Variable(tf.ones([out_channels]), trainable=False)

    epsilon = 1e-3

    def batch_norm_training():
        # 一定要使用正確的維度確保計算的是每個特征圖上的平均值和方差而不是整個網絡節點上的統計分布值
        batch_mean, batch_variance = tf.nn.moments(layer, [0, 1, 2], keep_dims=False)

        decay = 0.99
        train_mean = tf.assign(pop_mean, pop_mean*decay + batch_mean*(1 - decay))
        train_variance = tf.assign(pop_variance, pop_variance*decay + batch_variance*(1 - decay))

        with tf.control_dependencies([train_mean, train_variance]):
            return tf.nn.batch_normalization(layer, batch_mean, batch_variance, beta, gamma, epsilon)

    def batch_norm_inference():
        return tf.nn.batch_normalization(layer, pop_mean, pop_variance, beta, gamma, epsilon)

    batch_normalized_output = tf.cond(is_training, batch_norm_training, batch_norm_inference)
    return tf.nn.relu(batch_normalized_output)


"""
為了修改訓練函數,我們需要做以下工作:
1.Added is_training, a placeholder to store a boolean value indicating whether or not the network is training.
添加is_training,一個用於存儲布爾值的占位符,該值指示網絡是否正在訓練
2.Each time we call run on the session, we added to feed_dict the appropriate value for is_training.
每次調用sess.run函數時,我們都添加到feed_dict中is_training的適當值用以表示當前是正在訓練還是預測
3.We did not need to add the with tf.control_dependencies... statement that we added in the network that used tf.layers.batch_normalization
because we handled updating the population statistics ourselves in conv_layer and fully_connected.
我們不需要將train_opt訓練函數放進with tf.control_dependencies... 的函數結構體中,這是只有在使用tf.layers.batch_normalization才做的更新均值和方差的操作

"""


def train(num_batches, batch_size, learning_rate):
    # Build placeholders for the input samples and labels
    # 創建輸入樣本和標簽的占位符
    inputs = tf.placeholder(tf.float32, [None, 28, 28, 1])
    labels = tf.placeholder(tf.float32, [None, 10])

    # Add placeholder to indicate whether or not we're training the model
    # 創建占位符表明當前是否正在訓練模型
    is_training = tf.placeholder(tf.bool)

    # Feed the inputs into a series of 20 convolutional layers
    # 把輸入數據填充到一系列20個卷積層的神經網絡中
    layer = inputs
    for layer_i in range(1, 20):
        layer = conv_layer(layer, layer_i, is_training)

    # Flatten the output from the convolutional layers
    # 將卷積層輸出扁平化處理
    orig_shape = layer.get_shape().as_list()
    layer = tf.reshape(layer, shape=[-1, orig_shape[1]*orig_shape[2]*orig_shape[3]])

    # Add one fully connected layer
    # 添加一個具有100個神經元的全連接層
    layer = fully_connected(layer, 100, is_training)

    # Create the output layer with 1 node for each
    # 為每一個類別添加一個輸出節點
    logits = tf.layers.dense(layer, 10)

    # Define loss and training operations
    # 定義loss 函數和訓練操作
    model_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
    train_opt = tf.train.AdamOptimizer(learning_rate).minimize(model_loss)

    # Create operations to test accuracy
    # 創建計算准確度的操作
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Train and test the network
    # 訓練並測試網絡模型
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for batch_i in range(num_batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)

            # train this batch
            # 訓練樣本批次
            sess.run(train_opt, {inputs: batch_xs, labels: batch_ys, is_training: True})

            # Periodically check the validation or training loss and accuracy
            # 定期檢查訓練或驗證集上的loss和精確度
            if batch_i%100 == 0:
                loss, acc = sess.run([model_loss, accuracy], {inputs: mnist.validation.images,
                                                              labels: mnist.validation.labels,
                                                              is_training: False})
                print(
                    'Batch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'.format(batch_i, loss, acc))
            elif batch_i%25 == 0:
                loss, acc = sess.run([model_loss, accuracy], {inputs: batch_xs, labels: batch_ys, is_training: False})
                print('Batch: {:>2}: Training loss: {:>3.5f}, Training accuracy: {:>3.5f}'.format(batch_i, loss, acc))

        # At the end, score the final accuracy for both the validation and test sets
        # 最后在驗證集和測試集上對模型准確率進行評分
        acc = sess.run(accuracy, {inputs: mnist.validation.images,
                                  labels: mnist.validation.labels,
                                  is_training: False})
        print('Final validation accuracy: {:>3.5f}'.format(acc))
        acc = sess.run(accuracy, {inputs: mnist.test.images,
                                  labels: mnist.test.labels,
                                  is_training: False})
        print('Final test accuracy: {:>3.5f}'.format(acc))

        # Score the first 100 test images individually, just to make sure batch normalization really worked
        # 對100個獨立的測試圖片進行評分,對比驗證Batch Normalization的效果
        correct = 0
        for i in range(100):
            correct += sess.run(accuracy, feed_dict={inputs: [mnist.test.images[i]],
                                                     labels: [mnist.test.labels[i]],
                                                     is_training: False})

        print("Accuracy on 100 samples:", correct/100)


num_batches = 800  # 迭代次數
batch_size = 64  # 批處理數量
learning_rate = 0.002  # 學習率

tf.reset_default_graph()
with tf.Graph().as_default():
    train(num_batches, batch_size, learning_rate)


"""
再一次,批量標准化的模型很快達到了很高的精度。
但是在我們的運行中,注意到它似乎並沒有學習到前250個批次的任何東西,然后精度開始上升。
這只是顯示——即使是批處理標准化,給您的網絡一些時間來學習是很重要的。

PS:再100個單個數據的預測上達到了較高的精度,而這才是BN算法真正關注的!!
"""
# Extracting MNIST_data/train-images-idx3-ubyte.gz
# Extracting MNIST_data/train-labels-idx1-ubyte.gz
# Extracting MNIST_data/t10k-images-idx3-ubyte.gz
# Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
# 2018-03-18 19:35:28.568404: I D:\Build\tensorflow\tensorflow-r1.4\tensorflow\core\platform\cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX
# Batch:  0: Validation loss: 0.69113, Validation accuracy: 0.10020
# Batch: 25: Training loss: 0.57341, Training accuracy: 0.07812
# Batch: 50: Training loss: 0.45526, Training accuracy: 0.04688
# Batch: 75: Training loss: 0.37936, Training accuracy: 0.12500
# Batch: 100: Validation loss: 0.34601, Validation accuracy: 0.10700
# Batch: 125: Training loss: 0.34113, Training accuracy: 0.12500
# Batch: 150: Training loss: 0.33075, Training accuracy: 0.12500
# Batch: 175: Training loss: 0.34333, Training accuracy: 0.15625
# Batch: 200: Validation loss: 0.37085, Validation accuracy: 0.09860
# Batch: 225: Training loss: 0.40175, Training accuracy: 0.09375
# Batch: 250: Training loss: 0.48562, Training accuracy: 0.06250
# Batch: 275: Training loss: 0.67897, Training accuracy: 0.09375
# Batch: 300: Validation loss: 0.48383, Validation accuracy: 0.09880
# Batch: 325: Training loss: 0.43822, Training accuracy: 0.14062
# Batch: 350: Training loss: 0.43227, Training accuracy: 0.18750
# Batch: 375: Training loss: 0.39464, Training accuracy: 0.37500
# Batch: 400: Validation loss: 0.50557, Validation accuracy: 0.25940
# Batch: 425: Training loss: 0.32337, Training accuracy: 0.59375
# Batch: 450: Training loss: 0.14016, Training accuracy: 0.75000
# Batch: 475: Training loss: 0.11652, Training accuracy: 0.78125
# Batch: 500: Validation loss: 0.06241, Validation accuracy: 0.91280
# Batch: 525: Training loss: 0.01880, Training accuracy: 0.96875
# Batch: 550: Training loss: 0.03640, Training accuracy: 0.93750
# Batch: 575: Training loss: 0.07202, Training accuracy: 0.90625
# Batch: 600: Validation loss: 0.03984, Validation accuracy: 0.93960
# Batch: 625: Training loss: 0.00692, Training accuracy: 0.98438
# Batch: 650: Training loss: 0.01251, Training accuracy: 0.96875
# Batch: 675: Training loss: 0.01823, Training accuracy: 0.96875
# Batch: 700: Validation loss: 0.03951, Validation accuracy: 0.94080
# Batch: 725: Training loss: 0.02886, Training accuracy: 0.95312
# Batch: 750: Training loss: 0.06396, Training accuracy: 0.87500
# Batch: 775: Training loss: 0.02013, Training accuracy: 0.98438
# Final validation accuracy: 0.95820
# Final test accuracy: 0.95780
# Accuracy on 100 samples: 0.98


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM