Detailed Derivation of the Forward and Backward Formulas for Batch Normalization and Batch Renormalization



I. BN Forward Pass

Following the derivation in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", the forward pass consists mainly of the following four formulas:

\[\mu_B=\frac{1}{m}\sum_i^mx_i\tag{1}\label{1} \]

\[\delta_B^2=\frac{1}{m}\sum_i^m(x_i-\mu_B)^2\tag{2}\label{2} \]

\[\widehat{x_i}=\frac{x_i-\mu_B}{\sqrt{\delta_B^2+\epsilon}}\tag{3}\label{3} \]

\[y_i=\gamma\widehat{x_i}+\beta\tag{4}\label{4} \]

Take an MLP as an example. Suppose the mini-batch contains \(m\) samples; then \(x_i, i=1,2,\dots,m\) here is one particular activation produced by the \(i\)-th sample at some layer. That is, when the \(m\) samples are fed through the network in one training step, the \(i\)-th sample produces \(N\) activation units at layer \(l\), and \(x_i\) denotes any one of them. Strictly speaking, writing it as \(x_i^l(n)\) would be more explicit.

So BN simply takes the \(n\)-th activation unit of layer \(l\), \(x_i^l(n)\), computes its mean and variance over the batch, and standardizes it to obtain \(\widehat{x_i^l(n)}\). The \(m\) normalized activations then have zero mean and unit variance, which to some extent removes Internal Covariate Shift and reduces how much the marginal distribution of each layer's activations over the training samples changes.
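As a concrete illustration, here is a minimal NumPy sketch of the training-time forward pass in Equations \eqref{1}–\eqref{4}. The `(m, N)` activation layout and all names (`bn_forward`, `cache`, `eps`) are assumptions made for this example, not code from the paper.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """Training-time BN forward for activations x of shape (m, N).

    Each of the N activation units is normalized over the m samples
    of the mini-batch, following Eqs. (1)-(4).
    """
    mu = x.mean(axis=0)                       # Eq. (1): per-unit batch mean
    var = ((x - mu) ** 2).mean(axis=0)        # Eq. (2): per-unit (biased) batch variance
    x_hat = (x - mu) / np.sqrt(var + eps)     # Eq. (3): standardize
    y = gamma * x_hat + beta                  # Eq. (4): scale and shift
    cache = (x, x_hat, mu, var, gamma, eps)   # kept for the backward pass
    return y, cache
```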

II. BN Backward Pass

  • Let the gradient propagated back from the next layer be \(\frac{\partial{L}}{\partial{y_i}}\).
  • We need to compute \(\frac{\partial{L}}{\partial{x_i}}\), \(\frac{\partial{L}}{\partial{\gamma}}\), and \(\frac{\partial{L}}{\partial{\beta}}\).

By the chain rule and Equation \eqref{4}:

\[\frac{\partial{L}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{5} \]

Since \(\frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\) contributes to \(\frac{\partial{L}}{\partial{\gamma}}\) for every \(i=1,2,\dots,m\), over one batch of training \(\frac{\partial{L}}{\partial{\gamma}}\) is defined as:

\[\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{6}\label{6} \]

Similarly:

\[\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\tag{7}\label{7} \]
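In code, Equations \eqref{6} and \eqref{7} are just per-unit sums over the batch dimension. A minimal sketch, continuing the NumPy layout assumed above (`dy` is the upstream gradient \(\frac{\partial{L}}{\partial{y}}\) and `x_hat` the normalized activations, both of shape `(m, N)`; the names are assumptions of this example):

```python
def bn_param_grads(dy, x_hat):
    # Eq. (6): dL/dgamma accumulates dL/dy_i * x_hat_i over the batch
    dgamma = np.sum(dy * x_hat, axis=0)
    # Eq. (7): dL/dbeta accumulates dL/dy_i over the batch
    dbeta = np.sum(dy, axis=0)
    return dgamma, dbeta
```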

Computing \(\frac{\partial{L}}{\partial{x_i}}\), on the other hand, is more involved. By the chain rule and Equation \eqref{3}, viewing \(\widehat{x_i}\) as \(g(x_i,\delta_B^2,\mu_B)\), we have:

\[\frac{\partial{L}}{\partial{x_i}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}(\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\delta_B^2}}\frac{\partial{\delta_B^2}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\mu_B}}\frac{\partial{\mu_B}}{\partial{x_i}}) =\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}(g_1'+g_2'\frac{\partial{\delta_B^2}}{\partial{x_i}}+g_3'\frac{\partial{\mu_B}}{\partial{x_i}}) \tag{8}\label{8} \]

By Equation \eqref{2}, the partial derivative in the second term inside the parentheses can be expanded one step further (viewing \(\delta_B^2\) as \(f(x_i,\mu_B)\)):

\[\frac{\partial{\delta_B^2}}{\partial{x_i}}= \frac{\partial{\delta_B^2}}{\partial{x_i}}+ \frac{\partial{\delta_B^2}}{\partial{\mu_B}} \frac{\partial{\mu_B}}{\partial{x_i}}= f_1'+f_2'\frac{\partial{\mu_B}}{\partial{x_i}} \tag{9}\label{9} \]

Note that the two occurrences of \(\frac{\partial{\delta_B^2}}{\partial{x_i}}\) in Equation \eqref{9} mean different things: the left-hand side is the total derivative of \(\delta_B^2\) with respect to \(x_i\), while \(f_1'\) on the right is the partial derivative taken with \(\mu_B\) held fixed. From Equations \eqref{8} and \eqref{9}, once \(f_1',f_2',g_1',g_2',g_3',\frac{\partial{\mu_B}}{\partial{x_i}}\), and \(\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\) are known, \(\frac{\partial{L}}{\partial{x_i}}\) follows.

The original paper splits Equation \eqref{8} into the following terms (the arrows below indicate that the per-sample contributions are summed over the batch, just as was done for \(\gamma\) and \(\beta\)):

\[\frac{\partial{L}}{\partial{x_i}}= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\partial{\widehat{x_i}}}{\partial{x_i}}+ \frac{\partial{L}}{\partial{\delta_B^2}} \frac{\partial{\delta_B^2}}{\partial{x_i}}+ \frac{\partial{L}}{\partial{\mu_B}} \frac{\partial{\mu_B}}{\partial{x_i}} \tag{10}\label{10} \]

其中:

\[\frac{\partial{L}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \frac{\partial{y_i}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \gamma\tag{10.1}\label{10.1} \]

\[\frac{\partial{\widehat{x_i}}}{\partial{x_i}}=g'_1=\frac{1}{\sqrt{\delta_B^2+\epsilon}} \tag{10.2}\label{10.2} \]

\[\frac{\partial{L}}{\partial{\delta_B^2}}= \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}} \longrightarrow \]

\[\sum_{i=1}^m\frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}} \tag{10.3}\label{10.3} \]

\[\frac{\partial{\delta_B^2}}{\partial{x_i}}=f'_1=\frac{2(x_i-\mu_B)}{m} \tag{10.4}\label{10.4} \]

\[\frac{\partial{L}}{\partial{\mu_B}}= \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_3+ \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2f'_2 \longrightarrow \]

\[\sum_{i=1}^m( \frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-1}{\sqrt{\delta_B^2+\epsilon}} +\frac{\partial{L}}{\partial{\delta_B^2}}\frac{2(\mu_B-x_i)}{m}) \tag{10.5}\label{10.5} \]

\[\frac{\partial{\mu_B}}{\partial{x_i}}=\frac{1}{m} \tag{10.6}\label{10.6} \]

The complete BN backward pass is then given by Equations \eqref{6}, \eqref{7}, and \eqref{10}.
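The term-by-term split in Equations \eqref{10.1}–\eqref{10.6} maps directly onto code. Below is a sketch that continues the NumPy example above (it reuses `np`, the `cache` returned by `bn_forward`, and `bn_param_grads`; these names are assumptions of this example, not reference code):

```python
def bn_backward(dy, cache):
    """BN backward pass assembled from Eqs. (10.1)-(10.6) and (10)."""
    x, x_hat, mu, var, gamma, eps = cache
    m = x.shape[0]

    dx_hat = dy * gamma                                              # Eq. (10.1)
    dvar = np.sum(dx_hat * (mu - x) / 2 * (var + eps) ** (-1.5),     # Eq. (10.3)
                  axis=0)
    dmu = (np.sum(-dx_hat / np.sqrt(var + eps), axis=0)              # Eq. (10.5)
           + dvar * np.sum(2 * (mu - x), axis=0) / m)
    dx = (dx_hat / np.sqrt(var + eps)                                # Eq. (10), using (10.2), (10.4), (10.6)
          + dvar * 2 * (x - mu) / m
          + dmu / m)

    dgamma, dbeta = bn_param_grads(dy, x_hat)                        # Eqs. (6), (7)
    return dx, dgamma, dbeta
```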

III. Batch Renormalization

See the paper "Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models".

Batch Renormalization is a refinement of vanilla BN: it keeps the training-time and inference-time transforms consistent, which alleviates the problems caused by non-i.i.d. data and small minibatches.

1. Forward

The formulas are similar to the original ones, with two additional non-trainable parameters \(r, d\):

\[\mu_B=\frac{1}{m}\sum_i^mx_i\tag{1.1}\label{1.1} \]

\[\sigma_B=\sqrt{\epsilon+\frac{1}{m}\sum_i^m(x_i-\mu_B)^2}\tag{2.1}\label{2.1} \]

\[\widehat{x_i}=\frac{x_i-\mu_B}{\sigma_B}r+d\tag{3.1}\label{3.1} \]

\[y_i=\gamma\widehat{x_i}+\beta\tag{4.1}\label{4.1} \]

\[r=Stop\_Gradient(Clip_{[1/r_{max},\,r_{max}]}(\frac{\sigma_B}{\sigma}))\tag{5.1}\label{5.1} \]

\[d=Stop\_Gradient(Clip_{[-d_{max},\,d_{max}]}(\frac{\mu_B-\mu}{\sigma}))\tag{6.1}\label{6.1} \]


Update moving averages:

\[\mu:=\mu+\alpha(\mu_B-\mu)\tag{7.1}\label{7.1} \]

\[\sigma:=\sigma+\alpha(\sigma_B-\sigma)\tag{8.1}\label{8.1} \]

Inference:

\[y=\gamma\frac{x-\mu}{\sigma}+\beta\tag{9.1}\label{9.1} \]

Compared with vanilla BN, which only accumulates the moving mean and variance during training and uses them solely at inference, BRN uses the moving mean and standard deviation in both training and inference.
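Putting Equations \eqref{1.1}–\eqref{9.1} together, a minimal NumPy sketch of the BRN forward pass might look as follows. The function names, the `(m, N)` layout, and the default values of `r_max`, `d_max`, `alpha` are assumptions for illustration; in a real framework \(r\) and \(d\) would be wrapped in a stop-gradient op, which in plain NumPy simply amounts to treating them as constants.

```python
def brn_forward_train(x, gamma, beta, mu_run, sigma_run,
                      r_max=3.0, d_max=5.0, alpha=0.01, eps=1e-5):
    """Training-time Batch Renormalization forward, Eqs. (1.1)-(8.1)."""
    mu_b = x.mean(axis=0)                                       # Eq. (1.1)
    sigma_b = np.sqrt(eps + ((x - mu_b) ** 2).mean(axis=0))     # Eq. (2.1)

    # Eqs. (5.1), (6.1): r and d are clipped ratios of batch statistics
    # to moving statistics; no gradient flows through them (stop-gradient).
    r = np.clip(sigma_b / sigma_run, 1.0 / r_max, r_max)
    d = np.clip((mu_b - mu_run) / sigma_run, -d_max, d_max)

    x_hat = (x - mu_b) / sigma_b * r + d                        # Eq. (3.1)
    y = gamma * x_hat + beta                                    # Eq. (4.1)

    mu_run = mu_run + alpha * (mu_b - mu_run)                   # Eq. (7.1)
    sigma_run = sigma_run + alpha * (sigma_b - sigma_run)       # Eq. (8.1)
    return y, (x, x_hat, mu_b, sigma_b, r, gamma), mu_run, sigma_run

def brn_inference(x, gamma, beta, mu_run, sigma_run):
    """Inference uses the moving statistics directly, Eq. (9.1)."""
    return gamma * (x - mu_run) / sigma_run + beta
```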

2. Backward

The backward derivation is analogous to that of BN:

\[\frac{\partial{L}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \frac{\partial{y_i}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \gamma\tag{10.11}\label{10.11} \]

\[\frac{\partial{L}}{\partial{\sigma_B}} \longrightarrow\sum_{i=1}^m \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{-r(x_i-\mu_B)}{\sigma_B^2} \tag{10.22}\label{10.22} \]

\[\frac{\partial{L}}{\partial{\mu_B}}\longrightarrow\sum_{i=1}^{m}\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-r}{\sigma_B} \tag{10.33}\label{10.33} \]

\[\frac{\partial{L}}{\partial{x_i}}= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{r}{\sigma_B}+ \frac{\partial{L}}{\partial{\sigma_B}} \frac{x_i-\mu_B}{m\sigma_B}+ \frac{\partial{L}}{\partial{\mu_B}} \frac{1}{m} \tag{10.44}\label{10.44} \]

\[\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{10.55}\label{10.55} \]

\[\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\tag{10.66}\label{10.66} \]
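A matching sketch of the backward pass, implementing Equations \eqref{10.11}–\eqref{10.66} under the same assumptions (the cache tuple is the one returned by the `brn_forward_train` sketch above):

```python
def brn_backward(dy, cache):
    """Batch Renorm backward following Eqs. (10.11)-(10.66)."""
    x, x_hat, mu_b, sigma_b, r, gamma = cache
    m = x.shape[0]

    dx_hat = dy * gamma                                                   # Eq. (10.11)
    dsigma_b = np.sum(dx_hat * (-r) * (x - mu_b) / sigma_b ** 2, axis=0)  # Eq. (10.22)
    dmu_b = np.sum(dx_hat * (-r) / sigma_b, axis=0)                       # Eq. (10.33)
    dx = (dx_hat * r / sigma_b                                            # Eq. (10.44)
          + dsigma_b * (x - mu_b) / (m * sigma_b)
          + dmu_b / m)
    dgamma = np.sum(dy * x_hat, axis=0)                                   # Eq. (10.55)
    dbeta = np.sum(dy, axis=0)                                            # Eq. (10.66)
    return dx, dgamma, dbeta
```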

IV. BN in Convolutional Networks

The derivation above is for an MLP. In a convolutional network, the \(m\) activation units of the BN procedure generalize to positions in the feature maps. Suppose the feature map after some convolution layer is a tensor of shape \([N,H,W,C]\), where \(N\) is the batch size, \(H,W\) are the height and width, and \(C\) is the number of channels. BN for a conv net then takes \(m = N\times H\times W\): for a given channel \(c\), every feature-map pixel within the batch is treated as one \(x_i\), and the BN formulas above apply unchanged. Consequently, each channel of an intermediate activation layer has its own pair of BN parameters \(\gamma,\beta\).

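In code, the per-channel normalization can reuse the MLP-style routine by flattening the batch and spatial dimensions. A minimal sketch assuming an NHWC tensor and the `bn_forward` helper defined earlier (both assumptions of this example):

```python
def bn_forward_conv(x_nhwc, gamma, beta, eps=1e-5):
    """BN over a conv feature map of shape (N, H, W, C).

    Each channel is normalized over all m = N*H*W positions, and
    gamma, beta hold one entry per channel.
    """
    n, h, w, c = x_nhwc.shape
    x_flat = x_nhwc.reshape(n * h * w, c)   # m = N*H*W "samples" per channel
    y_flat, cache = bn_forward(x_flat, gamma, beta, eps)
    return y_flat.reshape(n, h, w, c), cache
```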
