Tensorflow BatchNormalization詳解:2_使用tf.layers高級函數來構建神經網絡


Batch Normalization: 使用tf.layers高級函數來構建神經網絡

覺得有用的話,歡迎一起討論相互學習~

我的微博我的github我的B站

參考文獻
吳恩達deeplearningai課程
課程筆記
Udacity課程

# Batch Normalization – Solutions
# Batch Normalization 解決方案
"""
批量標准化在構建深度神經網絡時最為有用。為了證明這一點,我們將創建一個具有20個卷積層的卷積神經網絡,然后是一個完全連接的層。
我們將使用它來對MNIST數據集中的手寫數字進行分類,現在您應該熟悉這一點。這不是划分MNIST數字的最好網絡。您可以創建更簡單的網絡並獲得更好的結果。
但是,為了給您批量標准化的實踐經驗,我們將使用這個作為一個例子:
1:這個網絡足夠復雜,可以保證體現BN算法對深層神經網絡進行訓練時的優勢
2:這個例子比較簡單,你可以很快獲得訓練的結果,這個簡短的練習只是為了給你一次向深度神經玩過中添加BN算法的機會
3:足夠簡單,無需額外資源即可輕松理解架構。
"""
# 這個教程中有兩種你可以自行編輯的在CNN中實現Batch Normalization的方法,
# 第一個是使用高級函數'tf.layers.batch_normalization',
# 第二個使用低級函數'tf.nn.batch_normalization'

# 下載MNIST手寫數字識別數據集
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, reshape=False)

# Batch Normalization using tf.layers.batch_normalization
# 使用tf.layers.batch_normalization實現Batch Normalization
"""
這個版本的神經網絡代碼使用tf.layers包來編寫,也推薦你使用tf.layers包函數來實現CNN和Batch Normalization算法。
我們將使用以下函數在我們的網絡中創建完全連接的層。我們將用指定數量的神經元和ReLU激活函數來創建它們。
PS:這個版本的函數不包括批量標准化。
"""


def fully_connected(prev_layer, num_units):
    """
    num_units參數傳遞該層神經元的數量,根據prev_layer參數傳入值作為該層輸入創建全連接神經網絡。
    :param prev_layer: Tensor
        該層神經元輸入
    :param num_units: int
        該層神經元結點個數
    :returns Tensor
        一個新的全連接神經網絡層
    """
    layer = tf.layers.dense(prev_layer, num_units, activation=tf.nn.relu)
    return layer


"""
我們會運用以下方法來構建神經網絡的卷積層,這個卷積層很基本,我們總是使用3x3內核,ReLU激活函數,
在具有奇數深度的圖層上步長為1x1,在具有偶數深度的圖層上步長為2x2。在這個網絡中,我們並不打算使用池化層。
PS:該版本的函數不包括批量標准化操作。
"""


def conv_layer(prev_layer, layer_depth):
    """
    Create a convolutional layer with the given layer as input.
    使用給定的參數作為輸入創建卷積層
    :param prev_layer: Tensor
        傳入該層神經元作為輸入
    :param layer_depth: int
        我們將根據網絡中圖層的深度設置特征圖的步長和數量。
        這不是實踐CNN的好方法,但它可以幫助我們用很少的代碼創建這個示例。
    :returns Tensor
        一個新的卷積層
    """
    strides = 2 if layer_depth%3 == 0 else 1
    conv_layer = tf.layers.conv2d(prev_layer, layer_depth*4, 3, strides, 'same', activation=tf.nn.relu)
    return conv_layer


# 建立沒有批量標准化的網絡,然后在MNIST數據集上進行訓練。它在訓練期間定期顯示Loss值和准確性數據


def train(num_batches, batch_size, learning_rate):
    # 為輸入的樣本和標簽創建占位符
    inputs = tf.placeholder(tf.float32, [None, 28, 28, 1])
    labels = tf.placeholder(tf.float32, [None, 10])

    # Feed the inputs into a series of 20 convolutional layers
    # 將輸入數據填充到20個卷積層
    layer = inputs
    for layer_i in range(1, 20):
        layer = conv_layer(layer, layer_i)

    # Flatten the output from the convolutional layers
    # 將卷積層輸出扁平化處理
    orig_shape = layer.get_shape().as_list()
    layer = tf.reshape(layer, shape=[-1, orig_shape[1]*orig_shape[2]*orig_shape[3]])

    # Add one fully connected layer
    # 添加一個具有100個神經元的全連接層
    layer = fully_connected(layer, 100)

    # Create the output layer with 1 node for each
    # 為每一個類別添加一個輸出節點
    logits = tf.layers.dense(layer, 10)

    # 定義
    model_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

    train_opt = tf.train.AdamOptimizer(learning_rate).minimize(model_loss)

    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Train and test the network
    # 訓練和測試神經網絡
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for batch_i in range(num_batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)

            # train this batch
            # 訓練批數據
            sess.run(train_opt, {inputs: batch_xs,
                                 labels: batch_ys})

            # Periodically check the validation or training loss and accuracy
            # 定期檢查訓練或驗證集上的loss和精確度
            if batch_i%100 == 0:
                loss, acc = sess.run([model_loss, accuracy], {inputs: mnist.validation.images,
                                                              labels: mnist.validation.labels})
                print(
                    'Batch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'.format(batch_i, loss, acc))
            elif batch_i%25 == 0:
                loss, acc = sess.run([model_loss, accuracy], {inputs: batch_xs, labels: batch_ys})
                print('Batch: {:>2}: Training loss: {:>3.5f}, Training accuracy: {:>3.5f}'.format(batch_i, loss, acc))

        # At the end, score the final accuracy for both the validation and test sets
        # 最后在驗證集和測試集上對模型准確率進行評分
        acc = sess.run(accuracy, {inputs: mnist.validation.images,
                                  labels: mnist.validation.labels})
        print('Final validation accuracy: {:>3.5f}'.format(acc))
        acc = sess.run(accuracy, {inputs: mnist.test.images,
                                  labels: mnist.test.labels})
        print('Final test accuracy: {:>3.5f}'.format(acc))

        # Score the first 100 test images individually, just to make sure batch normalization really worked
        # 對100個獨立的測試圖片進行評分,對比驗證Batch Normalization的效果

        correct = 0
        for i in range(100):
            correct += sess.run(accuracy, feed_dict={inputs: [mnist.test.images[i]],
                                                     labels: [mnist.test.labels[i]]})

        print("Accuracy on 100 samples:", correct/100)


num_batches = 800  # 迭代次數
batch_size = 64  # 批處理數量
learning_rate = 0.002  # 學習率

tf.reset_default_graph()
with tf.Graph().as_default():
    train(num_batches, batch_size, learning_rate)

"""
有了這么多的層次,這個網絡需要大量的迭代來學習。在您完成800個批次的培訓時,您的最終測試和驗證准確度可能不會超過10%。
(每次都會有所不同,但很可能會低於15%)使用批量標准化,您可以在相同數量的批次中訓練同一網絡達到90%以上
使用tf.layers包構建帶有BN層的卷積神經網絡。
"""

# Extracting MNIST_data/train-images-idx3-ubyte.gz
# Extracting MNIST_data/train-labels-idx1-ubyte.gz
# Extracting MNIST_data/t10k-images-idx3-ubyte.gz
# Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
# Batch:  0: Validation loss: 0.69079, Validation accuracy: 0.10700
# Batch: 25: Training loss: 0.33298, Training accuracy: 0.10938
# Batch: 50: Training loss: 0.32532, Training accuracy: 0.07812
# Batch: 75: Training loss: 0.32597, Training accuracy: 0.09375
# Batch: 100: Validation loss: 0.32531, Validation accuracy: 0.11260
# Batch: 125: Training loss: 0.32369, Training accuracy: 0.15625
# Batch: 150: Training loss: 0.32454, Training accuracy: 0.12500
# Batch: 175: Training loss: 0.32519, Training accuracy: 0.14062
# Batch: 200: Validation loss: 0.32540, Validation accuracy: 0.10700
# Batch: 225: Training loss: 0.32509, Training accuracy: 0.06250
# Batch: 250: Training loss: 0.32508, Training accuracy: 0.10938
# Batch: 275: Training loss: 0.32465, Training accuracy: 0.14062
# Batch: 300: Validation loss: 0.32541, Validation accuracy: 0.11260
# Batch: 325: Training loss: 0.32266, Training accuracy: 0.15625
# Batch: 350: Training loss: 0.32408, Training accuracy: 0.06250
# Batch: 375: Training loss: 0.32685, Training accuracy: 0.10938
# Batch: 400: Validation loss: 0.32567, Validation accuracy: 0.10020
# Batch: 425: Training loss: 0.32492, Training accuracy: 0.12500
# Batch: 450: Training loss: 0.32439, Training accuracy: 0.12500
# Batch: 475: Training loss: 0.32574, Training accuracy: 0.12500
# Batch: 500: Validation loss: 0.32554, Validation accuracy: 0.09860
# Batch: 525: Training loss: 0.32668, Training accuracy: 0.03125
# Batch: 550: Training loss: 0.32549, Training accuracy: 0.03125
# Batch: 575: Training loss: 0.32473, Training accuracy: 0.12500
# Batch: 600: Validation loss: 0.32628, Validation accuracy: 0.11260
# Batch: 625: Training loss: 0.32547, Training accuracy: 0.09375
# Batch: 650: Training loss: 0.32518, Training accuracy: 0.17188
# Batch: 675: Training loss: 0.32284, Training accuracy: 0.15625
# Batch: 700: Validation loss: 0.32541, Validation accuracy: 0.10700
# Batch: 725: Training loss: 0.32801, Training accuracy: 0.06250
# Batch: 750: Training loss: 0.32847, Training accuracy: 0.06250
# Batch: 775: Training loss: 0.32251, Training accuracy: 0.20312
# Final validation accuracy: 0.11260
# Final test accuracy: 0.11350
# Accuracy on 100 samples: 0.14


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM