Devils in BatchNorm
Facebook人工智能實驗室研究工程師吳育昕
該視頻主要討論Batch Normalization的一些坑。
Batch Norm后還有一個channel-wise仿射,是一個退化的卷積層,這里不討論。
Batch Norm的訓練和測試具有不一致性,測試時vanilla BN方法是更新一個exponential moving average,也就是圖中的\(u_{EMA}\)
- 為什么可以訓練和測試不一致?
DropOut和Data Augmentation也是這樣——可以理解為訓練是在測試的基礎上加噪聲,測試是訓練的平均。
不過噪聲本身也是一種正則化。
- BN什么時候會失敗?
當\(\mu_{EMA}\),\(\sigma_{EMA}\)不接近\(\mu_{B}\)和\(\sigma_{B}\)時
- 當EMA計算不合理
- 當\(\mu_{B}\),\(\sigma_{B}\)不穩定時 - 不能很好地近似
a)數據不穩定
b)不穩定的模型
- EMA計算不合理的情況
- \(\lambda\)過小,EMA
- \(\lambda\)過大,需要很多次迭代
- 不穩定的模型或最后N次迭代中不穩定的數據
常見的錯誤是——"false overfitting",在可能出現overfitting時但是迭代次數又很少時需仔細甄別
- EMA不合理之處
- 總是有偏置的
- 數據的分布總是在變化
- 並不是真的平均
- 解決方案:Precise BatchNorm
最早來源於ResNet
實現:
· Cheap Precise BN:繼續使用EMA但是使用大的\(\lambda\),把模型固定,forward很多(比如1000次)迭代
· 先算前一層的PreciseBN,用這個再算下一層PreciseBN
BN在訓練/微調上的坑
Normalization batch size
- Norm batch size不一定等於SGD batch size,受顯卡顯存的限制
- 一個batch中,均值和方差是有噪聲的——上面提到訓練的均值和方差可看作在測試的基礎上加噪聲,若一個batch中有一個異常sample就帶來噪聲
- 如何增大Normalization batch size?
- Sync BatchNorm/Cross-GPU BN
其實現是采用all-reduce \(2 \times C\) elements。
overhead也很小。在各個框架上都有實現。 - Virtual BatchNorm
使用很多只為了前向的圖片,不會顯著增加顯存,但是會增加時間。
唯一好處是可控,適用於reasearch和analysis。
-
如何減小Normalization batch size?
Ghost BN
其實現是在一個batch中分離 -
如何在改變SGD的batch size同時控制NBS不變?
使用Accumulate Gradients。
其實現是積累幾次迭代的梯度后將gradients平均再去更新模型。 -
NBS特別小時的解決方案
Batch Renormalization。
訓練: \(\hat{x}=\frac{x-\mu_{B}}{\sigma_{B}} \times\) stop gradient \((r)+\) stop gradient \((d)\)
測試: \(\hat{x}=\frac{x-\mu_{E M A}}{\sigma_{E M A}}\)
\(r, d\) pushes \(\mu_{B}, \sigma_{B}\) similar to \(\mu_{E M A}, \sigma_{E M A}\)
Reduce noise \(\&\) inconsistency
Need to tune the limit on \(r, d\)
BN在數據分布的分布
數據非獨立同分布時容易出現BN會學習到一些捷徑
一般發生在:
- 多域學習
- 對抗訓練
- fine-tuning
一些解決的tricks: - 訓練時——為各個domain做Seperate BN
- 訓練/微調時——Frozen BN(Sync BN沒出現前使用,一般不全部用於train from scratch,用於fine-tune或是train時模型的末端)
- 測試時——Adaptive BN
GAN中遇到的real/fake分布
在判別器中,會有兩個分布,希望只有一個去更新EMA:
- decoder(real_batch,training=True)
- decoder(fake_batch,training=True,update_ema=False)# don't update EMA或decoder(fake_batch,training=False)# use EMA during training
batch本來的設計就來源於相關源
- two-stage目標檢測器中batch本身就有來自同一張圖片的patch組成 -> 解決:Group Norm
- 視頻理解
強化學習
數據就來自於模型,解決方法是:
DQN中提出的target network或是Precise BN
BN在融合上的坑
BN在實現上的坑
PyTorch中momentum的0.1是別人的0.9,而且及其需要注意track_running_stats
的使用;
TensorFlow中EMA的更新不是在層計算的同時發生,新手容易忘記更新EMA更新的操作加入到訓練中,解決方法是使用tensorpack.models.BatchNorm;
TensorFlow實現BN
def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
with tf.variable_scope(scope):
# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
# gamma = tf.get_variable(name='gamma', shape=[n_out],
# initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay)
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean),tf.identity(batch_var)
# identity之后會把Variable轉換為Tensor並入圖中,
# 否則由於Variable是獨立於Session的,不會被圖控制control_dependencies限制
mean,var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean),ema.average(batch_var)))
normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed
總結
-
使用哪個\(\mu, \sigma ?\)
\(\mu_{B}\), \(\sigma_{B}\) ; \(\mu_{E M A}\), \(\sigma_{E M A}\) ; Batch ReNorm -
如何計算\(\mu_{B}\), \(\sigma_{B}\):
Per-GPU BN,Sync BN,Ghost BN,Virtual BN -
是否更新\(\mu_{E M A}\), \(\sigma_{E M A}\)With \(\mu_{B}\), \(\sigma_{B}\):
YES,NO,Separate BN -
測試/微調時用什么:
EMA,Precise BN,Adaptive BN,Frozen BN