大幅減少GPU顯存占用:可逆殘差網絡(The Reversible Residual Network)


前序:

  Google AI最新出品的論文Reformer 在ICLR 2020會議上獲得高分,論文中對當前暴熱的Transformer做兩點革新:一個是局部敏感哈希(LSH);一個是可逆殘差網絡代替標准殘差網絡。本文主要介紹變革的第二部分,可逆殘差網絡。先從神經網絡的反向傳播講起,然后是標准殘差網絡,最后自然過渡到可逆殘差網絡。讀完本文相信你會對神經網絡的架構發展有一個非常清晰的認識。

一、背景介紹

  當前所有的神經網絡都采用反向傳播的方式來訓練,反向傳播算法需要存儲網絡的中間結果來計算梯度,而且其對內存的消耗與網絡單元數成正比。這也就意味着,網絡越深越廣,對內存的消耗越大,這將成為很多應用的瓶頸。由於GPU的顯存受限,使得網絡結構難以達到最優,因為有些網絡結構可能達到上千層的深度。如果采用並行GPU的話,價格既昂貴又比較復雜,同時也不適合個人研究。


  上面是torchsummary截圖,forword和bacword pass size就是需要保存的中間變量大小,可以看出這部分占據了大部分顯存。如果不存儲中間層結果,那么就可以大幅減少GPU的顯存占用,有助於訓練更深更廣的網絡。多倫多大學的Aidan N.Gomez和Mengye Ren提出了可逆殘差神經網絡,當前層的激活結果可由下一層的結果計算得出,也就是如果我們知道網絡層最后的結果,就可以反推前面每一層的中間結果。這樣我們只需要存儲網絡的參數和最后一層的結果即可,激活結果的存儲與網絡的深度無關了,將大幅減少顯存占用。令人驚訝的是,實驗結果顯示,可逆殘差網絡的表現並沒有顯著下降,與之前的標准殘差網絡實驗結果基本旗鼓相當。

  如果你已經對很多計算細節遺忘不清楚了,沒關系,下面我們將先從BP反向傳播、標准殘差網絡一步步講起,本文的目的就是要帶你從頭到尾搞清楚。首先我們溫故一下多元復合函數求導公式

二、神經網絡的反向傳播(BP)

符號表示:

X1,X2,X3:表示3個輸入層節點

Wtji:表示從t-1層到t層的權重參數,j表示t層的第j個節點,i表示t-1層的第i個節點

ati:表示t層的第i個激活后輸出結果

g(x):表示激活函數

正向傳播計算過程

<隱藏層>

<輸出層>

反向傳播

以單個樣本為例,假設輸入向量是[x1,x2,x3],目標輸出值是[y1,y2],代價函數用L表示。反向傳播的總體原理就是根據總體輸出誤差,反向傳播回網絡,通過計算每一層節點的梯度,利用梯度下降法原理,更新每一層的網絡權重w和偏置b,這也是網絡學習的過程。誤差反向傳播的優點就是可以把繁雜的導數計算以數列遞推的形式來表示, 簡化了計算過程。

 以平方誤差來計算反向傳播的過程,代價函數表示如下:

根據導數的鏈式法則反向求解隱藏->輸出層、輸入層->隱藏層的權重表示:

 引入新的誤差求導表示形式,稱為神經單元誤差

l=2,3表示第幾層,j表示某一層的第幾個節點。替換表示后如下:

所以我們可以歸納出一般的計算公式:

從上述公式可以看出,如果神經單元誤差δ可以求出來,那么總誤差對每一層的權重w和偏置b的偏導數就可以求出來,接下來就可以利用梯度下降法來優化參數了

求解每一層的δ:

 輸出層

隱藏層

 

也就是說,我們根據輸出層的神經誤差單元δ就可以直接求出隱藏層的神經誤差單元,進而省去了隱藏層的繁雜的求導過程,我們可以得出更一般的計算過程:

從而得出l層神經單元誤差和l+1層神經單元誤差的關系。這就是誤差反向傳播算法,只要求出輸出層的神經單元誤差,其它層的神經單元誤差就不需要計算偏導數了,而可以直接通過上述公式得出

 

 三、殘差網絡(Residual Network)

殘差網絡主要可以解決兩個問題:1)梯度消失問題;2)網絡退化問題。其結構如下圖

上述結構就是一個兩層網絡組成的殘差塊,殘差塊可以由2、3層甚至更多層組成,但是如果是一層的,就變成線性變換了,沒什么意義了。上述圖可以寫成公式如下:

F(x)=W* ReLU(W* X)

所以在第二層進入激活函數ReLU之前F(x)+X組成新的輸入,也叫恆等映射,就是在這個殘差塊輸入是X的情況下輸出依然是X,這樣其目標就是學習讓F(X)=0。

為什么要額外加一個X呢,而不是讓模型直接學習F(x)=X?

  因為讓F(x)=0比較容易,初始化參數W非常小接近0,就可以讓輸出接近0,同時輸出如果是負數,經過第一層Relu后輸出依然0,都能使得最后的F(X)=0,也就是有多種情況都可以使得F(x)=0;但是讓F(x)=x確實非常難的,因為參數都必須剛剛好才能使得最后輸出為X。

恆等映射有什么作用?

  恆等映射就可以解決網絡退化的問題,當網絡層數越來越深的時候,網絡的精度卻在下降,也就是說網絡自身存在一個最優的層度結構,太深太淺都能使得模型精度下降。有了恆等映射存在,網絡就能夠自己學習到哪些層是冗余的,就可以無損通過這些層,理論上講再深的網絡都不影響其精度,解決了網絡退化問題。

為什么可以解決梯度消失問題呢?

  以兩個殘差塊的結構實例圖來分析,其中每個殘差塊有2層神經網絡組成,如下圖:

 

假設激活函數ReLU用g(x)函數來表示,樣本實例是[X1,Y1],即輸入是X1,目標值是Y1,損失函數還是采用平方損失函數,則每一層的計算如下: 

 

下面我們對第一個殘差塊的權重參數求導,根據鏈式求導法則,公式如下:

 

我們可以看到求導公式中多了一個+1項,這就將原來的鏈式求導中的連乘變成了連加狀態,可以有效避免梯度消失了

 

四、可逆殘差網絡(Reversible Residual Network)

1)可逆塊結構

 可逆神經網絡將每一層分割成兩部分,分別為x1和x2,每一個可逆塊的輸入是(x1,x2),輸出是(y1,y2)。其結構如下:

正向計算圖示:

公式表示:

                  

 

逆向計算圖示:

公式表示:

                       

其中F和G都是相似的殘差函數,參考上圖殘差網絡。可逆塊的跨距只能為1,也就是說可逆塊必須一個接一個連接,中間不能采用其它網絡形式銜接,否則的話就會丟失信息,並且無法可逆計算了,這點與殘差塊不一樣。如果一定要采取跟殘差塊相似的結構,也就是中間一部分采用普通網絡形式銜接,那中間這部分的激活結果就必須顯式的存起來。

2)不用存儲激活結果的反向傳播

   為了更好地計算反向傳播的步驟,我們修改一下上述正向計算和逆向計算的公式:

  盡管z1和y1的值是相同的,但是兩個變量在圖中卻代表不同的節點,所以在反向傳播中它們的總體導數是不一樣的。Z1的導數包含通過y2產生的間接影響,而y2的導數卻不受y2的任何影響。

  在反向傳播計算流程中,先給出最后一層的激活值(y1,y2)和誤差傳播的總體導數(dL/dy1,dL/dy2),然后要計算出其輸入值(x1,x2)和對應的導數(dL/dx1,dL/dx2),以及殘差函數F和G中權重參數的總體導數,求解步驟如下:

 3)計算開銷

  一個N個連接的神經網絡,正向計算的理論加乘開銷為N,反向傳播求導的理論加乘開銷為2N(反向求導包含復合函數求導連乘),而可逆網絡多一步需要反向計算輸入值的操作,所以理論計算開銷為4N,比普通網絡開銷約多出33%左右。但是在實際操作中,正向和反向的計算開銷在GPU上差不多,可以都理解為N。那么這樣的話,普通網絡的整體計算開銷為2N,可逆網絡的整體開銷為3N,也就是多出了約50%。

 

參考論文:The Reversible Residual Network:Backpropagation Without Storing Activations


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM