tensorflow 計算均值和方差


我們在處理矩陣數據時,需要用到數據的均值和方差,比如在batch normalization的時候。

那么,tensorflow中計算均值和方差的函數是:tf.nn.moments(x, axes)

x: 我們待處理的數據

axes: 在哪一個維度上求解,是一個list,如axes=[0, 1, 2]

舉例:

 1 def calc_mean_variance():
 2     """
 3         計算均值和方差
 4     :return:
 5     """
 6     img = tf.Variable(tf.random_normal([2, 3]))
 7     t = len(img.get_shape())
 8     axis = list(range(len(img.get_shape()) - 1))
 9     mean, variance = tf.nn.moments(img, axes=0)
10     with tf.Session() as sess:
11         sess.run(tf.global_variables_initializer())
12         print(sess.run(img))
13         print(sess.run([mean, variance]))

輸出:

 

注意,以下是統計軸的個數:

axis = list(range(len(img.get_shape()) - 1))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM