Implementing batch normalization in TensorFlow


  TensorFlow provides two main functions for implementing batch normalization:

    1) tf.nn.moments

    2) tf.nn.batch_normalization

  tf.nn.moments computes the mean and variance; these two values are then fed into tf.nn.batch_normalization.

  tf.nn.moments(x, axes, ...)

  It takes two main arguments: the input batch x, and axes, the dimensions over which to compute the mean and variance. axes is a list, so multiple dimensions can be passed.

  Returns: mean and variance
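To make what tf.nn.moments returns concrete, here is a small NumPy sketch (my own illustration, not from the original post): for axes=[0] it is simply the per-column mean and the biased (divide-by-N) variance over the batch dimension.

```python
import numpy as np

# What tf.nn.moments(x, axes=[0]) computes, sketched in NumPy:
# per-column mean and biased variance over the batch dimension.
x = np.array([[1.0, 2.0, 3.0],
              [3.0, 4.0, 5.0],
              [5.0, 6.0, 7.0]])

mean = x.mean(axis=0)      # -> [3. 4. 5.]
variance = x.var(axis=0)   # biased variance (ddof=0): each column is 8/3
```

Note that the variance is the population variance (divided by N, not N-1), which matches tf.nn.moments.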

  tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon)

  Main arguments: the input batch x; mean; variance; offset and scale, which are learnable parameters, so only initial values need to be given, typically offset = 0 and scale = 1; and variance_epsilon, a small constant that keeps the division well defined when the variance is 0.

  Output: the batch-normalized data
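The transformation applied by these parameters can be written out by hand. The following is a sketch in NumPy (my own illustration; the helper name batch_norm is made up) of the formula y = scale * (x - mean) / sqrt(variance + epsilon) + offset:

```python
import numpy as np

def batch_norm(x, mean, variance, offset, scale, epsilon):
    # The formula behind tf.nn.batch_normalization.
    return scale * (x - mean) / np.sqrt(variance + epsilon) + offset

x = np.array([[1.0, 2.0],
              [3.0, 4.0]])
mean = x.mean(axis=0)
variance = x.var(axis=0)

# With offset=0 and scale=1, each column of y ends up with
# mean ~0 and variance ~1 (slightly below 1 because of epsilon).
y = batch_norm(x, mean, variance, offset=0.0, scale=1.0, epsilon=1e-3)
```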

  The full code is as follows:

import tensorflow as tf
import numpy as np


# A random 3x3 input batch.
X = tf.constant(np.random.uniform(1, 10, size=(3, 3)), dtype=tf.float32)

# Normalize over every dimension except the last (here axes = [0]).
axes = list(range(len(X.get_shape()) - 1))
mean, variance = tf.nn.moments(X, axes)
print(axes)

# offset = 0, scale = 1, variance_epsilon = 0.001
X_batch = tf.nn.batch_normalization(X, mean, variance, 0, 1, 0.001)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    mean, variance, X_batch = sess.run([mean, variance, X_batch])
    print(mean)
    print(variance)
    print(X_batch)

Output:

axes: [0]
mean: [5.124098  3.0998185 4.723417 ]
variance: [3.7908943 1.7062012 3.8243492]
X_batch: [[-0.32879925 -1.3645337   0.39226937]
          [-1.0266179   0.36186576 -1.3726556 ]
          [ 1.355417    1.0026684   0.98038626]]
