【學習筆記】Devils in BatchNorm


Devils in BatchNorm

Facebook人工智能實驗室研究工程師吳育昕

該視頻主要討論Batch Normalization的一些坑。
Batch Norm后還有一個channel-wise仿射,是一個退化的卷積層,這里不討論。


Batch Norm的訓練和測試具有不一致性,測試時vanilla BN方法是更新一個exponential moving average,也就是圖中的\(u_{EMA}\)

  • 為什么可以訓練和測試不一致?
    DropOut和Data Augmentation也是這樣——可以理解為訓練是在測試的基礎上加噪聲,測試是訓練的平均。

不過噪聲本身也是一種正則化。

  • BN什么時候會失敗?
    \(\mu_{EMA}\)\(\sigma_{EMA}\)不接近\(\mu_{B}\)\(\sigma_{B}\)
  1. 當EMA計算不合理
  2. \(\mu_{B}\)\(\sigma_{B}\)不穩定時 - 不能很好地近似
    a)數據不穩定
    b)不穩定的模型
  • EMA計算不合理的情況

\[\mu_{E M A} \leftarrow \lambda \mu_{E M A}+(1-\lambda) \mu_{B}, \sigma_{E M A}^{2} \leftarrow \cdots \]

  1. \(\lambda\)過小,EMA
  2. \(\lambda\)過大,需要很多次迭代
  3. 不穩定的模型或最后N次迭代中不穩定的數據

常見的錯誤是——"false overfitting",在可能出現overfitting時但是迭代次數又很少時需仔細甄別

  • EMA不合理之處
  1. 總是有偏置的
  2. 數據的分布總是在變化
  3. 並不是真的平均
  • 解決方案:Precise BatchNorm
    最早來源於ResNet
    實現:
    · Cheap Precise BN:繼續使用EMA但是使用大的\(\lambda\),把模型固定,forward很多(比如1000次)迭代
    · 先算前一層的PreciseBN,用這個再算下一層PreciseBN

BN在訓練/微調上的坑

Normalization batch size

  1. Norm batch size不一定等於SGD batch size,受顯卡顯存的限制
  2. 一個batch中,均值和方差是有噪聲的——上面提到訓練的均值和方差可看作在測試的基礎上加噪聲,若一個batch中有一個異常sample就帶來噪聲
  • 如何增大Normalization batch size?
  1. Sync BatchNorm/Cross-GPU BN
    其實現是采用all-reduce \(2 \times C\) elements。
    overhead也很小。在各個框架上都有實現。
  2. Virtual BatchNorm
    使用很多只為了前向的圖片,不會顯著增加顯存,但是會增加時間。
    唯一好處是可控,適用於reasearch和analysis。
  • 如何減小Normalization batch size?
    Ghost BN
    其實現是在一個batch中分離

  • 如何在改變SGD的batch size同時控制NBS不變?
    使用Accumulate Gradients。
    其實現是積累幾次迭代的梯度后將gradients平均再去更新模型。

  • NBS特別小時的解決方案
    Batch Renormalization。
    訓練: \(\hat{x}=\frac{x-\mu_{B}}{\sigma_{B}} \times\) stop gradient \((r)+\) stop gradient \((d)\)
    測試: \(\hat{x}=\frac{x-\mu_{E M A}}{\sigma_{E M A}}\)
    \(r, d\) pushes \(\mu_{B}, \sigma_{B}\) similar to \(\mu_{E M A}, \sigma_{E M A}\)
    Reduce noise \(\&\) inconsistency
    Need to tune the limit on \(r, d\)


BN在數據分布的分布

數據非獨立同分布時容易出現BN會學習到一些捷徑
一般發生在:

  1. 多域學習
  2. 對抗訓練
  3. fine-tuning
    一些解決的tricks:
  4. 訓練時——為各個domain做Seperate BN
  5. 訓練/微調時——Frozen BN(Sync BN沒出現前使用,一般不全部用於train from scratch,用於fine-tune或是train時模型的末端)
  6. 測試時——Adaptive BN

GAN中遇到的real/fake分布

在判別器中,會有兩個分布,希望只有一個去更新EMA:

  1. decoder(real_batch,training=True)
  2. decoder(fake_batch,training=True,update_ema=False)# don't update EMA或decoder(fake_batch,training=False)# use EMA during training

batch本來的設計就來源於相關源

  1. two-stage目標檢測器中batch本身就有來自同一張圖片的patch組成 -> 解決:Group Norm
  2. 視頻理解

強化學習

數據就來自於模型,解決方法是:
DQN中提出的target network或是Precise BN


BN在融合上的坑


BN在實現上的坑

PyTorch中momentum的0.1是別人的0.9,而且及其需要注意track_running_stats的使用
TensorFlow中EMA的更新不是在層計算的同時發生,新手容易忘記更新EMA更新的操作加入到訓練中,解決方法是使用tensorpack.models.BatchNorm;


TensorFlow實現BN

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
    with tf.variable_scope(scope):
        # beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
        # gamma = tf.get_variable(name='gamma', shape=[n_out],
        #                         initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
        batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)
  
        def mean_var_with_update():
            ema_apply_op = ema.apply([batch_mean,batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean),tf.identity(batch_var)
                # identity之后會把Variable轉換為Tensor並入圖中,
                # 否則由於Variable是獨立於Session的,不會被圖控制control_dependencies限制
  
        mean,var = tf.cond(phase_train,
                           mean_var_with_update,
                           lambda: (ema.average(batch_mean),ema.average(batch_var)))
       normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
    return normed

總結

  1. 使用哪個\(\mu, \sigma ?\)
    \(\mu_{B}\), \(\sigma_{B}\) ; \(\mu_{E M A}\), \(\sigma_{E M A}\) ; Batch ReNorm

  2. 如何計算\(\mu_{B}\), \(\sigma_{B}\):
    Per-GPU BN,Sync BN,Ghost BN,Virtual BN

  3. 是否更新\(\mu_{E M A}\), \(\sigma_{E M A}\)With \(\mu_{B}\), \(\sigma_{B}\):
    YES,NO,Separate BN

  4. 測試/微調時用什么:
    EMA,Precise BN,Adaptive BN,Frozen BN


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM