論文筆記:Deep Residual Learning


之前提到,深度神經網絡在訓練中容易遇到梯度消失/爆炸的問題,這個問題產生的根源詳見之前的讀書筆記。在 Batch Normalization 中,我們將輸入數據由激活函數的收斂區調整到梯度較大的區域,在一定程度上緩解了這種問題。不過,當網絡的層數急劇增加時,BP 算法中導數的累乘效應還是很容易讓梯度慢慢減小直至消失。這篇文章中介紹的深度殘差 (Deep Residual) 學習網絡可以說根治了這種問題。下面我按照自己的理解淺淺地水一下 Deep Residual Learning 的基本思想,並簡單介紹一下深度殘差網絡的結構。

基本思想

回到最開始的問題,為什么深度神經網絡會難以訓練?根源在於 BP 的時候我們需要逐層計算導數並將這些導數相乘。這些導數如果太小,梯度就容易消失,反之,則會爆炸。我們沒法從 BP 算法的角度出發讓這個相乘的導數鏈消失,因此,可行的方法就是控制每個導數的值,讓它們盡量靠近 1,這樣,連乘后的結果不會太小,也不會太大。

現在,我們就從導數入手,看看如何實現上面的要求。由於梯度消失的問題比梯度爆炸更常見,因此只針對梯度消失這一點進行改進。

假設我們理想中想讓網絡學習出來的函數是 \(F(x; {W_i})\),但由於它的導數 \(\frac{\partial F}{\partial x}\) 太小,所以訓練的時候梯度就消失了。所謂太小,就是說 \(\frac{\partial F}{\partial x} \approx 0\),那么,我們何不在這個導數的基礎上加上 1 或者減去 1,這樣梯度不就變大了嗎?(這里的 1 是為了滿足之前提到的梯度靠近 1 這一要求,事實上,只要能防止梯度爆炸,其他數值也是可以的,不過作者在之后的實驗中證明,1 的效果最好)

按照這種思路,我們現在想構造一個新的函數,讓它的導數等於 \(\frac{\partial F}{\partial x}+1\)。由這個導數反推回去,很自然地就得到一個我們想要的函數:\(H(x)=F(x)+x\),它的導數為:\(\frac{\partial H}{\partial x} = \frac{\partial F}{\partial x}+1\)。這個時候你可能會想,如果將原來的 \(F(x)\) 變成 \(H(x)\),那網絡想要提取的特征不就不正確了嗎,這個網絡還有什么用?不錯,我們想要的最終函數是 \(F(x; {W_i})\),這個時候再加個 \(x\) 上去,結果肯定不是我們想要的。但是,為什么一定要讓網絡學出 \(F(x; {W_i})\)?為什么不用 \(H(x)\) 替換原本的 \(F(x;{W_i})\),而將網絡學習的目標調整為:\(F(x)=H(x)-x\)?要知道,神經網絡是可以近似任何函數的,只要讓網絡學出這個新的 \(F(x)\),那么我們自然也就可以通過 \(H(x)=F(x)+x\) 得到最終想要的函數形式。作者認為,通過這種方式學習得到的 \(H(x)\) 函數,跟當初直接讓網絡學習出的 \(F(x, {W_i})\),效果上是等價的,但前者卻更容易訓練。

==================== UPDATE 2018.1.23 =====================

時隔幾個月重新看這篇文章,發現當初的理解存在一個巨大的問題,在此,對那些被我誤導的同學深深道歉🙇。

這里的問題在於,BP 算法中我們要計算的是參數 \(W\)\(b\) 的導數,所以導數的形式不應該是 \(\frac{\partial F}{\partial x}\),而是 \(\frac{F}{W_i}\)(bias 同理)。這樣一來,我之前對殘差網絡改進梯度消失問題的理解就錯了。不過,我依然固執地認為,殘差學習是為了解決深度網絡中梯度消失的問題,只是要換種方式理解。

對於最簡單的神經網絡(假設退化成一條鏈):

\(C\) 是網絡的 loss 函數,\(z^l\) 表示第 l 層激活函數的輸入,\(a^l\) 表示第 l 層激活函數的輸出(\(a^0\) 就是網絡最開始的輸入了),則 \(a^l = \sigma(z^l)\)\(z^l=a^{l-1}*w^l\)\(W^l\) 是第 l 層的權重參數,簡單起見,不考慮 bias)。\(\delta^l\) 是第 l 層的誤差。

根據 BP 算法,先計算誤差項:

\[\delta^3=\frac{\partial C}{\partial a^3}\frac{\partial a^3}{\partial z^3}=\frac{\partial C}{\partial a^3}\sigma'(z^3) \\ \delta^2=\delta^3 \sigma'(z^2)w^3=\frac{\partial C}{\partial a^3}\sigma'(z^3)\sigma'(z^2)w^3 \\ \delta^1=\delta^2\sigma'(z^1)w^2=\frac{\partial C}{\partial a^3}\sigma'(z^3)\sigma'(z^2)w^3\sigma'(z^1)w^2 \]

然后根據誤差項計算 \(w\) 的導數:

\[\frac{\partial C}{\partial w^3}=\delta^3a^2 \\ \frac{\partial C}{\partial w^2}=\delta^2a^1 \\ \frac{\partial C}{\partial w^1}=\delta^1a^0 \]

一般來說,梯度的消失是這些項的累乘造成的:\(\sigma'(z^3)\sigma'(z^2)w^3\sigma'(z^1)w^2\)(因為 \(\sigma'(z^l)\)\(w^l\) 一般都小於 1)。

那殘差網絡做了那些修改呢?其實就是簡單地在激活函數的輸出后面,加入上一層的輸入:

假設原本的網絡是要學習一個 \(H(x)\) 函數,那現在這個網絡依然是要學習 \(H(x)\)。只不過,原本的網絡要學習的是整個 \(H(x)\),而殘差網絡中,和原本網絡相同的那部分結構,要學習的就只是 \(H(x)-x\)。換句話說,它要學習的東西只是一個微小的變化,因此訓練起來相對更容易一些。

另一方面,我們沿用之前對導數的分析思路,看看殘差網絡的梯度會發生什么變化。

首先,殘差網絡的前向傳播發生了變化:

\[z^1=a^0 \\ a^1=\sigma(z^1)+a^0 \\ z^2=a^1w^2 \\ a^2=\sigma(z^2)+a^1 \\ z^3=a^2w^3 \\ a^3=\sigma(z^3)+a^2 \]

反向傳播計算的誤差項為:

\[\delta^3=\frac{\partial C}{\partial z^3}=\frac{\partial C}{\partial a^3}\frac{\partial a^3}{\partial z^3}=\frac{\partial C}{\partial a^3}[\sigma'(z^3)+\frac{\partial a^2}{\partial z^3}] \\ \delta^2=\delta^3 w^3 \frac{\partial a^2}{\partial z^2}=\frac{\partial C}{\partial a^3}[\sigma'(z^3)+\frac{\partial a^2}{\partial z^3}]w^3 [\sigma'(z^2)+\frac{\partial a^1}{\partial z^2}] \\ \vdots \]

由於 \(z^3=a^2w^3\),所以 \(a^2=\frac{z^3}{w^3}\),故 \(\frac{\partial a^2}{\partial z^3}=\frac{1}{w^3}\),同理 \(\frac{\partial a^1}{\partial z^2}=\frac{1}{w^2}\)。代入到上式中就變成:

\[\delta^3=\frac{\partial C}{\partial a^3}[\sigma'(z^3)+\frac{1}{w^3}] \\ \delta^2=\frac{\partial C}{\partial a^3}[\sigma'(z^3)+\frac{1}{w^3}]w^3 [\sigma'(z^2)+\frac{1}{w^2}]=\frac{\partial C}{\partial a^3}[\sigma'(z^3)w^3+1] [\sigma'(z^2)+\frac{1}{w^2}] \\ \vdots \]

對比之前沒加殘差結構的網絡,這個新的網絡結構中,誤差項 \(\delta^l\) 減小為 0 的可能性降低了。以 \(\delta^2\) 為例,原本的 \(\delta^2=\frac{\partial C}{\partial a^3}\sigma'(z^3)\sigma'(z^2)w^3\),而現在,連乘的項變成了 \([\sigma'(z^3)w^3+1]\)\([\sigma'(z^2)+\frac{1}{w^2}]\),由於 \(\sigma'(z^l)\)\(w^l\) 一般都小於 1,所以這兩項的值會略大於 1,這樣,無論連乘多少項,梯度都不會縮小到 0。

**==================================================**

上面所說的 \(F(x)=H(x)-x\) 就是所謂的殘差 (residual),而式子內的 \(x\) 在論文中被稱為 Identity Mapping,因為 x 可以看作是由自己到自己的映射函數。基於此,我們可以得到一個新的網絡結構,如同開篇的圖片所示,這個網絡結構跟普通的網絡結構類似,但在輸出那里多加了一個 Identity Mapping,相當於在網絡原有輸出的基礎上加一個 x,這樣便得到我們想要的函數 \(H(x)\)。作者將這種相加稱為 shortcut connection,意思就是說,\(x\) 沒有經過中間的變換操作,像「短路」一樣直接跳到輸出那里和 \(F(x)\) 相加。需要注意的是,這個網絡結構的輸入並不一定是原始的數據,它可以是前面一層網絡的輸出結果。同理,網絡的輸出也可以繼續輸入到后面層的網絡中。

我們用一個式子來表示這個網絡:\(y=F(x,{W_i})+x\),其中 \(F(x,{W_i})=W_2 \sigma(W_1x)\) (這里忽略了 bias)。在論文中,這里的 \(\sigma\) 函數采用的是 ReLu。得到 \(y\) 后,作者又對其做了一次 ReLu 操作,然后再進入下一層網絡。

Talk is cheap,show you the code(這里用 tensorflow 表示一下上圖那個網絡結構):

# 假設 x 是該網絡結構的輸入
c1 = tf.layers.conv2d(x, kernel, [w, h], strides=[s,s])
b1 = tf.layers.batch_normalization(c1, training=is_training)
h1 = tf.nn.relu(b1)
c2 = tf.layers.conv2d(h1, kernel, [w, h], strides=[s,s])
b2 = tf.layers.batch_normalization(c2, training=is_training)
r = b2 + x
y = tf.nn.relu(r)

因為 \(x\)\(F(x)\) 是直接相加的,所以它們的維度必須相同,不同的情況下,需要對 \(x\) 的維度進行調整。可以通過做一次線性變換將它投影到所需的維度空間,也可以用其他簡單粗暴的方法。比如,當維度太高時,可以用 pooling 的方法降低維度。而維度較低時,作者在實驗中則是直接補 0 來擴展維度。

深度殘差網絡

好了,了解了殘差網絡的基本思路和簡單的網絡結構后,下面我們可以將它拓展到更深的網絡結構中。

下圖是一個普通的網絡和改造后的殘差網絡:

左邊的網絡是沒有添加殘差層的網絡,作者稱它為 plain network,意思就是這個網絡很「平」(每次看到這個名字我總是會浮出一些邪惡的想法~囧~)。右邊的則是一個完整的深度殘差網絡,它其實就是由前文所說的小的網絡結構組成的,虛線表示要對 \(x\) 的維度進行擴增。作者在兩個網絡中都加了 Batch Normalization(具體加在卷積層之后,激活層之前),我想目的大概是要在之后的實驗中凸顯 residual learning 優於 BN 的效果吧。

下面分析一下 identity mapping 對殘差網絡所起的作用,通過這個最簡單的映射來了解 residual learning 不同於一般網絡的地方。

首先,給出最通用的網絡結構:

這里其實就是將之前的 \(x\) 換成 \(h(x)\),將最后的 ReLu 換成 \(f(x)\)。因為事實上,\(h(x)\)\(f(x)\) 的形式是很自由的,\(h(x)\) 可以是 \(x\)\(2x\)\(x^2\),只要能防止梯度消失或爆炸即可。而 \(f(x)\) 也可以是其他各種激活函數。

不過,因為我們是要從 identity mapping 着手,所以這里還是令 \(h(x)=x\)\(f(x)=x\)

然后,我們用類推出:

到了這一步,可以發現,在 identity mapping 中,殘差網絡的輸出其實就是在原始輸入 \(x_l\) 的基礎上,加上后面的一堆「殘差」。如果對其求導,則可以得出:

我們發現,導數的形式也很類似,也是最后一層的導數加上前面的一堆「殘差」導數,而這一步是殘差網絡中梯度不容易消失的原因。

作者經過對比實驗發現,identity mapping 的效果要好於其他的 mapping,具體的實驗細節請參考 tutorial 和后續的一篇論文 Identity Mappings in Deep Residual Networks。換句話說,使用 residual network 時,最好用上 identity mapping。

論文中的實驗

實驗部分,我只講一下 ImageNet 的結果。

作者分別用 18 層和 34 層的網絡做了兩組對比實驗(兩組網絡除了殘差外,其他結構相同,並且都加了 BN 層。在對 \(x\) 升維時,直接使用 0 進行 padding,換句話說,殘差網絡的參數和 plain 的一樣。34 層的網絡見上一部分的說明),並分析了它們在 ImageNet 訓練集上的誤差下降情況:

上圖中,左圖是 plain 網絡,右圖是 ResNet。注意,訓練剛開始的時候,ResNet 的誤差下降的速度比 plain 網絡要快,也就是說,殘差網絡的訓練速度快於 plain 網絡。對於 18 層的網絡而言,兩者最終的准確率持平,但對於 34 層的網絡,使用殘差的結果要好於一般的網絡。另外,我們再看看驗證集上的情況:

這個結果表明,當網絡層數不多時,plain 網絡和殘差網絡除了訓練速度不一樣外,對最終的結果影響不大。但如果層數比較深,殘差網絡可以提升准確率。作者在這里提出一個問題:既然我們已經在網絡中加了 BN,那導致 plain 網絡准確率降不下來的原因應該不會是梯度消失。但又會是其他什么原因呢?作者在論文中稱這種問題為 degradation problem,即退化問題。它指的是隨着網絡層數增加,在梯度沒有消失的情況下導致的網絡訓練緩慢或訓練停止的問題。當然啦,按照我自己的理解和猜測,就如這篇文章開篇所講的那樣,梯度消失是由兩個方面導致,而 BN 只是將數據從激活函數的收斂區調整到梯度更大的區域,但導數相乘后的累積效應仍然會使梯度變小,所以才導致這里所說的退化問題。不過具體的原因,還有待進一步研究。

參考


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM