理論板塊將從以下四個方面對Batch Normalization進行詳解:
- 提出背景
- BN算法思想
- 測試階段如何使用BN
- BN的優勢
理論部分主要參考2015年Google的Sergey Ioffe與Christian Szegedy的論文內容,並輔以吳恩達Coursera課程與其它博主的資料。所有參考內容鏈接均見於文章最后參考鏈接部分。
1 提出背景
1.1 煉丹的困擾
在深度學習中,由於問題的復雜性,我們往往會使用較深層數的網絡進行訓練,相信很多煉丹的朋友都對調參的困難有所體會,尤其是對深層神經網絡的訓練調參更是困難且復雜。在這個過程中,我們需要去嘗試不同的學習率、初始化參數方法(例如Xavier初始化)等方式來幫助我們的模型加速收斂。深度神經網絡之所以如此難訓練,其中一個重要原因就是網絡中層與層之間存在高度的關聯性與耦合性。下圖是一個多層的神經網絡,層與層之間采用全連接的方式進行連接。
我們規定左側為神經網絡的底層,右側為神經網絡的上層。那么網絡中層與層之間的關聯性會導致如下的狀況:隨着訓練的進行,網絡中的參數也隨着梯度下降在不停更新。一方面,當底層網絡中參數發生微弱變化時,由於每一層中的線性變換與非線性激活映射,這些微弱變化隨着網絡層數的加深而被放大(類似蝴蝶效應);另一方面,參數的變化導致每一層的輸入分布會發生改變,進而上層的網絡需要不停地去適應這些分布變化,使得我們的模型訓練變得困難。上述這一現象叫做Internal Covariate Shift。
1.2 什么是Internal Covariate Shift
Batch Normalization的原論文作者給了Internal Covariate Shift一個較規范的定義:在深層網絡訓練的過程中,由於網絡中參數變化而引起內部結點數據分布發生變化的這一過程被稱作Internal Covariate Shift。
這句話該怎么理解呢?我們同樣以1.1中的圖為例,我們定義每一層的線性變換為 ,其中
代表層數;非線性變換為
,其中
為第
層的激活函數。
隨着梯度下降的進行,每一層的參數 與
都會被更新,那么
的分布也就發生了改變,進而
也同樣出現分布的改變。而
作為第
層的輸入,意味着
層就需要去不停適應這種數據分布的變化,這一過程就被叫做Internal Covariate Shift。
1.3 Internal Covariate Shift會帶來什么問題?
(1)上層網絡需要不停調整來適應輸入數據分布的變化,導致網絡學習速度的降低
我們在上面提到了梯度下降的過程會讓每一層的參數 和
發生變化,進而使得每一層的線性與非線性計算結果分布產生變化。后層網絡就要不停地去適應這種分布變化,這個時候就會使得整個網絡的學習速率過慢。
(2)網絡的訓練過程容易陷入梯度飽和區,減緩網絡收斂速度
當我們在神經網絡中采用飽和激活函數(saturated activation function)時,例如sigmoid,tanh激活函數,很容易使得模型訓練陷入梯度飽和區(saturated regime)。隨着模型訓練的進行,我們的參數 會逐漸更新並變大,此時
就會隨之變大,並且
還受到更底層網絡參數
的影響,隨着網絡層數的加深,
很容易陷入梯度飽和區,此時梯度會變得很小甚至接近於0,參數的更新速度就會減慢,進而就會放慢網絡的收斂速度。
對於激活函數梯度飽和問題,有兩種解決思路。第一種就是更為非飽和性激活函數,例如線性整流函數ReLU可以在一定程度上解決訓練進入梯度飽和區的問題。另一種思路是,我們可以讓激活函數的輸入分布保持在一個穩定狀態來盡可能避免它們陷入梯度飽和區,這也就是Normalization的思路。
1.4 我們如何減緩Internal Covariate Shift?
要緩解ICS的問題,就要明白它產生的原因。ICS產生的原因是由於參數更新帶來的網絡中每一層輸入值分布的改變,並且隨着網絡層數的加深而變得更加嚴重,因此我們可以通過固定每一層網絡輸入值的分布來對減緩ICS問題。
(1)白化(Whitening)
白化(Whitening)是機器學習里面常用的一種規范化數據分布的方法,主要是PCA白化與ZCA白化。白化是對輸入數據分布進行變換,進而達到以下兩個目的:
- 使得輸入特征分布具有相同的均值與方差。其中PCA白化保證了所有特征分布均值為0,方差為1;而ZCA白化則保證了所有特征分布均值為0,方差相同;
- 去除特征之間的相關性。
通過白化操作,我們可以減緩ICS的問題,進而固定了每一層網絡輸入分布,加速網絡訓練過程的收斂(LeCun et al.,1998b;Wiesler&Ney,2011)。
(2)Batch Normalization提出
既然白化可以解決這個問題,為什么我們還要提出別的解決辦法?當然是現有的方法具有一定的缺陷,白化主要有以下兩個問題:
- 白化過程計算成本太高,並且在每一輪訓練中的每一層我們都需要做如此高成本計算的白化操作;
- 白化過程由於改變了網絡每一層的分布,因而改變了網絡層中本身數據的表達能力。底層網絡學習到的參數信息會被白化操作丟失掉。
既然有了上面兩個問題,那我們的解決思路就很簡單,一方面,我們提出的normalization方法要能夠簡化計算過程;另一方面又需要經過規范化處理后讓數據盡可能保留原始的表達能力。於是就有了簡化+改進版的白化——Batch Normalization。
2 Batch Normalization
2.1 思路
既然白化計算過程比較復雜,那我們就簡化一點,比如我們可以嘗試單獨對每個特征進行normalizaiton就可以了,讓每個特征都有均值為0,方差為1的分布就OK。
另一個問題,既然白化操作減弱了網絡中每一層輸入數據表達能力,那我就再加個線性變換操作,讓這些數據再能夠盡可能恢復本身的表達能力就好了。
因此,基於上面兩個解決問題的思路,作者提出了Batch Normalization,下一部分來具體講解這個算法步驟。
2.2 算法
在深度學習中,由於采用full batch的訓練方式對內存要求較大,且每一輪訓練時間過長;我們一般都會采用對數據做划分,用mini-batch對網絡進行訓練。因此,Batch Normalization也就在mini-batch的基礎上進行計算。
2.2.1 參數定義
我們依舊以下圖這個神經網絡為例。我們定義網絡總共有 層(不包含輸入層)並定義如下符號:
參數相關:
:網絡中的層標號
:網絡中的最后一層或總層數
:第
層的維度,即神經元結點數
:第
層的權重矩陣,
:第
層的偏置向量,
:第
層的線性計算結果,
:第
層的激活函數
:第
層的非線性激活結果,
樣本相關:
:訓練樣本的數量
:訓練樣本的特征數
:訓練樣本集,
(注意這里
的一列是一個樣本)
:batch size,即每個batch中樣本的數量
:第
個mini-batch的訓練數據,
,其中
2.2.2 算法步驟
介紹算法思路沿襲前面BN提出的思路來講。第一點,對每個特征進行獨立的normalization。我們考慮一個batch的訓練,傳入m個訓練樣本,並關注網絡中的某一層,忽略上標 。
我們關注當前層的第 個維度,也就是第
個神經元結點,則有
。我們當前維度進行規范化:
其中是為了防止方差為0產生無效計算。
下面我們再來結合個具體的例子來進行計算。下圖我們只關注第 層的計算結果,左邊的矩陣是
線性計算結果,還未進行激活函數的非線性變換。此時每一列是一個樣本,圖中可以看到共有8列,代表當前訓練樣本的batch中共有8個樣本,每一行代表當前
層神經元的一個節點,可以看到當前
層共有4個神經元結點,即第
層維度為4。我們可以看到,每行的數據分布都不同。
對於第一個神經元,我們求得 ,
(其中
),此時我們利用
對第一行數據(第一個維度)進行normalization得到新的值
。同理我們可以計算出其他輸入維度歸一化后的值。如下圖:
通過上面的變換,我們解決了第一個問題,即用更加簡化的方式來對數據進行規范化,使得第 層的輸入每個特征的分布均值為0,方差為1。
如同上面提到的,Normalization操作我們雖然緩解了ICS問題,讓每一層網絡的輸入數據分布都變得穩定,但卻導致了數據表達能力的缺失。也就是我們通過變換操作改變了原有數據的信息表達(representation ability of the network),使得底層網絡學習到的參數信息丟失。另一方面,通過讓每一層的輸入分布均值為0,方差為1,會使得輸入在經過sigmoid或tanh激活函數時,容易陷入非線性激活函數的線性區域。
因此,BN又引入了兩個可學習(learnable)的參數 與
。這兩個參數的引入是為了恢復數據本身的表達能力,對規范化后的數據進行線性變換,即
。特別地,當
時,可以實現等價變換(identity transform)並且保留了原始輸入特征的分布信息。
通過上面的步驟,我們就在一定程度上保證了輸入數據的表達能力。
以上就是整個Batch Normalization在模型訓練中的算法和思路。
補充: 在進行normalization的過程中,由於我們的規范化操作會對減去均值,因此,偏置項可以被忽略掉或可以被置為0,即
![]()
2.2.3 公式
對於神經網絡中的第 層,我們有:
3 測試階段如何使用Batch Normalization?
我們知道BN在每一層計算的 與
都是基於當前batch中的訓練數據,但是這就帶來了一個問題:我們在預測階段,有可能只需要預測一個樣本或很少的樣本,沒有像訓練樣本中那么多的數據,此時
與
的計算一定是有偏估計,這個時候我們該如何進行計算呢?
利用BN訓練好模型后,我們保留了每組mini-batch訓練數據在網絡中每一層的 與
。此時我們使用整個樣本的統計量來對Test數據進行歸一化,具體來說使用均值與方差的無偏估計:
得到每個特征的均值與方差的無偏估計后,我們對test數據采用同樣的normalization方法:
另外,除了采用整體樣本的無偏估計外。吳恩達在Coursera上的Deep Learning課程指出可以對train階段每個batch計算的mean/variance采用指數加權平均來得到test階段mean/variance的估計。
4 Batch Normalization的優勢
Batch Normalization在實際工程中被證明了能夠緩解神經網絡難以訓練的問題,BN具有的有事可以總結為以下三點:
(1)BN使得網絡中每層輸入數據的分布相對穩定,加速模型學習速度
BN通過規范化與線性變換使得每一層網絡的輸入數據的均值與方差都在一定范圍內,使得后一層網絡不必不斷去適應底層網絡中輸入的變化,從而實現了網絡中層與層之間的解耦,允許每一層進行獨立學習,有利於提高整個神經網絡的學習速度。
(2)BN使得模型對網絡中的參數不那么敏感,簡化調參過程,使得網絡學習更加穩定
在神經網絡中,我們經常會謹慎地采用一些權重初始化方法(例如Xavier)或者合適的學習率來保證網絡穩定訓練。
當學習率設置太高時,會使得參數更新步伐過大,容易出現震盪和不收斂。但是使用BN的網絡將不會受到參數數值大小的影響。例如,我們對參數 進行縮放得到
。對於縮放前的值
,我們設其均值為
,方差為
;對於縮放值
,設其均值為
,方差為
,則我們有:
,
我們忽略 ,則有:
注:公式中的是當前層的輸入,也是前一層的輸出;不是下標啊旁友們!
我們可以看到,經過BN操作以后,權重的縮放值會被“抹去”,因此保證了輸入數據分布穩定在一定范圍內。另外,權重的縮放並不會影響到對 的梯度計算;並且當權重越大時,即
越大,
越小,意味着權重
的梯度反而越小,這樣BN就保證了梯度不會依賴於參數的scale,使得參數的更新處在更加穩定的狀態。
因此,在使用Batch Normalization之后,抑制了參數微小變化隨着網絡層數加深被放大的問題,使得網絡對參數大小的適應能力更強,此時我們可以設置較大的學習率而不用過於擔心模型divergence的風險。
(3)BN允許網絡使用飽和性激活函數(例如sigmoid,tanh等),緩解梯度消失問題
在不使用BN層的時候,由於網絡的深度與復雜性,很容易使得底層網絡變化累積到上層網絡中,導致模型的訓練很容易進入到激活函數的梯度飽和區;通過normalize操作可以讓激活函數的輸入數據落在梯度非飽和區,緩解梯度消失的問題;另外通過自適應學習 與
又讓數據保留更多的原始信息。
(4)BN具有一定的正則化效果
在Batch Normalization中,由於我們使用mini-batch的均值與方差作為對整體訓練樣本均值與方差的估計,盡管每一個batch中的數據都是從總體樣本中抽樣得到,但不同mini-batch的均值與方差會有所不同,這就為網絡的學習過程中增加了隨機噪音,與Dropout通過關閉神經元給網絡訓練帶來噪音類似,在一定程度上對模型起到了正則化的效果。
另外,原作者通過也證明了網絡加入BN后,可以丟棄Dropout,模型也同樣具有很好的泛化效果。
摘自:https://zhuanlan.zhihu.com/p/34879333