論文:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
Internal Covariate Shift
深度神經網絡涉及到很多層的疊加,而每一層的參數更新會導致上層的輸入數據分布發生變化,通過層層疊加,高層的輸入分布變化會非常劇烈,這就使得高層需要不斷去重新適應底層的參數更新。為了訓好模型,我們需要非常謹慎地去設定學習率、初始化權重、以及盡可能細致的參數更新策略。
Google 將這一現象總結為 Internal Covariate Shift,簡稱 ICS.
BatchNorm的基本思想:能不能讓每個隱層節點的激活輸入分布固定下來呢?這樣就避免了“Internal Covariate Shift”問題了。
Mini-Batch SGD
BatchNorm是基於Mini-Batch SGD的。SGD訓練的缺點:超參數調起來很麻煩。
所謂“Mini-Batch”,是指的從訓練數據全集T中隨機選擇的一個訓練數據子集合。假設訓練數據集合T包含N個樣本,而每個Mini-Batch的Batch Size為b,於是整個訓練數據可被分成N/b個Mini-Batch。在模型通過SGD進行訓練時,一般跑完一個Mini-Batch的實例,叫做完成訓練的一步(step),跑完N/b步則整個訓練數據完成一輪訓練,則稱為完成一個Epoch。完成一個Epoch訓練過程后,對訓練數據做隨機Shuffle打亂訓練數據順序,重復上述步驟,然后開始下一個Epoch的訓練,對模型完整充分的訓練由多輪Epoch構成(參考圖1)。
在拿到一個Mini-Batch進行參數更新時,首先根據當前Mini-Batch內的b個訓練實例以及參數對應的損失函數的偏導數來進行計算,以獲得參數更新的梯度方向,然后根據SGD算法進行參數更新,以此來達到本步(Step)更新模型參數並逐步尋優的過程。
批歸一化的本質和意義
算法 1:批歸一化變換,在一個 mini-batch 上應用於激活 x。
批歸一化是一種用於訓練神經網絡模型的有效方法,是在深度神經網絡訓練過程中使得每一層神經網絡的輸入保持相同分布的。
BatchNorm的目標是對特征進行歸一化處理(使每層網絡的輸出都經過激活),得到標准差為 1 的零均值狀態。所以其相反的現象是非零均值。這將如何影響模型的訓練: 首先,這可以被理解成非零均值是數據不圍繞 0 值分布的現象,而是數據的大多數值大於 0 或小於 0。結合高方差問題,數據會變得非常大或非常小。在訓練層數很多的神經網絡時,這個問題很常見。如果特征不是分布在穩定的區間(從小到大的值)里,那么就會對網絡的優化過程產生影響。我們都知道,優化神經網絡將需要用到導數計算。其實就是把越來越偏的分布強制拉回比較標准的分布,這樣使得激活輸入值落在非線性函數對輸入比較敏感的區域,這樣輸入的小變化就會導致損失函數較大的變化,意思是這樣讓梯度變大,避免梯度消失問題產生,而且梯度變大意味着學習收斂速度快,能大大加快訓練速度。
假設一個簡單的層計算$y = (W_{x} + b)$,y 在 W 上的導數就是這樣:$d_{y}=dW_{x}$。因此,x 的值會直接影響導數的值(當然,神經網絡模型的梯度概念不會如此之簡單,但理論上,x 會影響導數)。因此,如果 x 引入了不穩定的變化,則這個導數要么過大,要么就過小,最終導致學習到的模型不穩定。而這也意味着當使用批歸一化時,我們可以在訓練中使用更高的學習率。
批歸一化可幫助我們避免 x 的值在經過非線性激活函數之后陷入飽和的現象。也就是說,批歸一化能夠確保激活都不會過高或過低。這有助於權重學習——如果不使用這一方案,某些權重可能永遠不會學習。這還能幫助我們降低對參數的初始值的依賴。
批歸一化也可用作正則化(regularization)的一種形式,有助於實現過擬合的最小化。使用批歸一化時,我們無需再使用過多的 dropout;這是很有助益的,因為我們無需擔心再執行 dropout 時丟失太多信息。但是,仍然建議組合使用這兩種技術。
均值為0,方差為1的標准正態分布代表什么含義:
這意味着在一個標准差范圍內,也就是說64%的概率x其值落在[-1,1]的范圍內,在兩個標准差范圍內,也就是說95%的概率x其值落在了[-2,2]的范圍內。那么這又意味着什么?我們知道,激活值x=WU+B,U是真正的輸入,x是某個神經元的激活值,假設非線性函數是sigmoid,那么看下sigmoid(x)其圖形:
圖:Sigmoid(x)
sigmoid(x)的導數為:G’=f(x)*(1-f(x)),因為f(x)=sigmoid(x)在0到1之間,所以G’在0到0.25之間,其對應的圖如下:
假設沒有經過BN調整前x的原先正態分布均值是-6,方差是1,那么意味着95%的值落在了[-8,-4]之間,那么對應的Sigmoid(x)函數的值明顯接近於0,這是典型的梯度飽和區,在這個區域里梯度變化很慢,為什么是梯度飽和區?請看下sigmoid(x)如果取值接近0或者接近於1的時候對應導數函數取值,接近於0,意味着梯度變化很小甚至消失。而假設經過BN后,均值是0,方差是1,那么意味着95%的x值落在了[-2,2]區間內,很明顯這一段是sigmoid(x)函數接近於線性變換的區域,意味着x的小變化會導致非線性函數值較大的變化,也即是梯度變化較大,對應導數函數圖中明顯大於0的區域,就是梯度非飽和區。
批歸一化的局限性
局限1:BN是嚴重依賴Mini-Batch中的訓練實例的
如果Batch Size比較小則任務效果有明顯的下降。在小的BatchSize意味着數據樣本少,因而得不到有效統計量,也就是說噪音太大。
局限2:對於有些像素級圖片生成任務來說,BN效果不佳
對於圖片分類等任務,只要能夠找出關鍵特征,就能正確分類,這算是一種粗粒度的任務,在這種情形下通常BN是有積極效果的。但是對於有些輸入輸出都是圖片的像素級別圖片生成任務,比如圖片風格轉換等應用場景,使用BN會帶來負面效果,這很可能是因為在Mini-Batch內多張無關的圖片之間計算統計量,弱化了單張圖片本身特有的一些細節信息。
局限3:RNN等動態網絡使用BN效果不佳且使用起來不方便 !
對於RNN來說,盡管其結構看上去是個靜態網絡,但在實際運行展開時是個動態網絡結構,因為輸入的Sequence序列是不定長的,這源自同一個Mini-Batch中的訓練實例有長有短。對於類似RNN這種動態網絡結構,BN使用起來不方便,因為要應用BN,那么RNN的每個時間步需要維護各自的統計量,而Mini-Batch中的訓練實例長短不一,這意味着RNN不同時間步的隱層會看到不同數量的輸入數據,而這會給BN的正確使用帶來問題。假設Mini-Batch中只有個別特別長的例子,那么對較深時間步深度的RNN網絡隱層來說,其統計量不方便統計而且其統計有效性也非常值得懷疑。另外,如果在推理階段遇到長度特別長的例子,也許根本在訓練階段都無法獲得深層網絡的統計量。綜上,在RNN這種動態網絡中使用BN很不方便,而且很多改進版本的BN應用在RNN效果也一般。
局限4:訓練時和推理時統計量不一致
對於BN來說,采用Mini-Batch內實例來計算統計量,這在訓練時沒有問題,但是在模型訓練好之后,在線推理的時候會有麻煩。因為在線推理或預測的時候,是單實例的,不存在Mini-Batch,所以就無法獲得BN計算所需的均值和方差,一般解決方法是采用訓練時刻記錄的各個Mini-Batch的統計量的數學期望,以此來推算全局的均值和方差,在線推理時采用這樣推導出的統計量。雖說實際使用並沒大問題,但是確實存在訓練和推理時刻統計量計算方法不一致的問題。
綜上,共同點就是BN要求計算統計量的時候必須在同一個Mini-Batch內的實例之間進行統計,因此形成了Batch內實例之間的相互依賴和影響的關系。如何從根本上解決這些問題?一個自然的想法是:把對Batch的依賴去掉,轉換統計集合范圍。在統計均值方差的時候,不依賴Batch內數據,只用當前處理的單個訓練數據來獲得均值方差的統計量,這樣因為不再依賴Batch內其它訓練數據,那么就不存在因為Batch約束導致的問題。在BN后的幾乎所有改進模型都是在這個指導思想下進行的。
但是這個指導思路盡管會解決BN帶來的問題,又會引發新的問題,新的問題是:我們目前已經沒有Batch內實例能夠用來求統計量了,此時統計范圍必須局限在一個訓練實例內,一個訓練實例看上去孤零零的無依無靠沒有組織,怎么看也無法求統計量,所以核心問題是對於單個訓練實例,統計范圍怎么算?
====> 請看下一篇:
Layer Normalization、Instance Normalization及Group Normalization