Batch Normalization和Dropout是深度學習模型中常用的結構。但BN和dropout在訓練和測試時使用卻不相同。
Batch Normalization
BN在訓練時是在每個batch上計算均值和方差來進行歸一化,每個batch的樣本量都不大,所以每次計算出來的均值和方差就存在差異。預測時一般傳入一個樣本,所以不存在歸一化,其次哪怕是預測一個batch,但batch計算出來的均值和方差是偏離總體樣本的,所以通常是通過滑動平均結合訓練時所有batch的均值和方差來得到一個總體均值和方差。以tensorflow代碼實現為例:
def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5): # 獲取輸入維度並判斷是否匹配卷積層(4)或者全連接層(2) shape = inputs.shape param_shape = shape[-1] with tf.variable_scope(name): # 聲明BN中唯一需要學習的兩個參數,y=gamma*x+beta gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1)) beta = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0)) # 計算當前整個batch的均值與方差 axes = list(range(len(shape)-1)) batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments') # 采用滑動平均更新均值與方差 ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema") def mean_var_with_update(): ema_apply_op = ema.apply([batch_mean, batch_var]) with tf.control_dependencies([ema_apply_op]): return tf.identity(batch_mean), tf.identity(batch_var) # 訓練時,更新均值與方差,測試時使用之前最后一次保存的均值與方差 mean, var = tf.cond(tf.equal(training,True), mean_var_with_update, lambda:(ema.average(batch_mean), ema.average(batch_var))) # 最后執行batch normalization return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)
training參數可以通過tf.placeholder傳入,這樣就可以控制訓練和預測時training的值。
self.training = tf.placeholder(tf.bool, name="training")
Dropout
Dropout在訓練時會隨機丟棄一些神經元,這樣會導致輸出的結果變小。而預測時往往關閉dropout,保證預測結果的一致性(不關閉dropout可能同一個輸入會得到不同的輸出,不過輸出會服從某一分布。另外有些情況下可以不關閉dropout,比如文本生成下,不關閉會增大輸出的多樣性)。
為了對齊Dropout訓練和預測的結果,通常有兩種做法,假設dropout rate = 0.2。一種是訓練時不做處理,預測時輸出乘以(1 - dropout rate)。另一種是訓練時留下的神經元除以(1 - dropout rate),預測時不做處理。以tensorflow為例。
x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
tf.nn.dropout就是采用了第二種做法,訓練時除以(1 - dropout rate),源碼如下:
binary_tensor = math_ops.floor(random_tensor) ret = math_ops.div(x, keep_prob) * binary_tensor if not context.executing_eagerly(): ret.set_shape(x.get_shape()) return ret
binary_tensor就是一個mask tensor,即里面的值由0或1組成。keep_prob = 1 - dropout rate。