吳恩達深度學習 反向傳播(Back Propagation)公式推導技巧


由於之前看的深度學習的知識都比較零散,補一下吳老師的課程希望能對這塊有一個比較完整的認識。課程分為5個部分(粗體部分為已經看過的):

  1. 神經網絡和深度學習
  2. 改善深層神經網絡:超參數調試、正則化以及優化
  3. 結構化機器學習項目
  4. 卷積神經網絡
  5. 序列模型

第 1 部分講的是神經網絡的基礎,從邏輯回歸到淺層神經網絡再到深層神經網絡。

一直感覺反向傳播(Back Propagation,BP)是這部分的重點,但是當時看的比較匆忙,有些公式的推導理解的不深刻,現在重新回顧一下,一是幫助自己梳理思路加深理解,二是記錄下來以免遺忘。

 

 

1. 符號規定

            圖1 神經網絡示意圖1

一般計算神經網絡層數時不包括輸入層,因此圖1中的網絡層數 L 為4;

n[l] 表示第 l 層的神經元的數量,n[1] = n[2] = 5,n[3] = 3, n[4] =1,n[0] = nx = 3;

z[l] = W[l]·a[l-1] + b[l],w[l] 表示第 l 層的權重,b[l] 表示偏置;

a[l] 表示第 l 層中通過激活函數 g[l] 激活后的值,表示如下:a[l] = g[l](z[l])。

 

 

2. 核對矩陣維數

吳恩達老師推薦的小技巧,通過核對矩陣的維數可以有效地判斷代碼是否有錯。核對矩陣維數對后面的反向傳播公式的推導很有幫助

                 圖2 神經網絡示意圖2

舉個例子:z[1] = W[1]·x+ b[1]

從圖2可以看出:x 的維度是 (2,1),且 z[1] 的維度是 (3,1),由於等式兩邊維度一致,因此可以推出 W[1] 的維度為 (3,2),且 b[1] 也為(3,1)。從正面看,因為第 1 層有 3 個神經元,且有 2 個輸入,因此每個神經元中的參數要分別與兩個輸入相乘,也很容易得出 W[1] 的維度。同理可以推出后面層的參數的維度,總結規律是:

W[l] = (n[l],n[l-1])

a[l] = (n[l],1)

z[l] = b[l] = (n[l],1)

dx 和 x 的維度相同

若有 m 個樣本,將公式向量化之后只需將 a[l] 和 z[l] 改為大寫,並將 1 改為 m 即可(對b,Python的廣播機制將其維數從 1 變為 m )。

 

 

3. 前向傳播和反向傳播

3.1 前向傳播

Input:a[l-1]

Output:a[l], cache(z[l]) (or W[l], b[l])

FP 的兩個公式,比較簡單,直接代入即可(主要根據這兩個公式推導BP):

z[l] = W[l]·a[l-1] + b[l]  ---------  ①

a[l] = g[l](z[l]------------------- ②

 

3.2 反向傳播

Input:da[l]

Output:da[l-1], dW[l], db[l]

BP的公式:

1. 首先求dz[l],由公式②,dz[l] = da[l]*g[l]'(z[l]),根據鏈式求導法則得出,因為*是元素對應相乘,所以兩者順序對結果不影響。

2. 再求dW[l],由公式①,dW[l] = dz[l]·a[l-1]T,因為乘積為點乘,因此兩者順序影響結果。此時,我們可以分析矩陣的維度來判斷順序以及是否要轉置。dW[l] 為 (n[l],n[l-1]),dz[l]為 (n[l],1),a[l-1]為 (n[l-1],1),因此,要得到 dW[l] 的維度,應該將 dz[l] 放在前,並與a[l-1]T作點積運算。(注:吳恩達老師在講課時,寫的是a[l-1],我個人認為此處是筆誤,歡迎大家討論)

3. 同樣根據公式①,容易得出:db[l] = dz[l]

4. 最后,根據公式①,da[l-1] = W[l]T·dz[l],da[l-1] 的維度為 (n[l-1],1),W[l] 的維度為 (n[l],n[l-1]),dz[l]為 (n[l],1),顯然需要將W[l]轉置再與dz[l]作點積。

這樣我們就得到的 Output 的三個值。

 

 

上面是我總結的關於方向傳播中公式推導的一些技巧。可能是剛接觸深度學習,對這些矩陣的維數以及是否要轉置等等還不敏感,在找相關資料時也沒有詳細的解釋(可能是因為太簡單?...),於是自己梳理了一下。當然如果大家有更好的方法,望不吝賜教!

 

參考資料:http://mooc.study.163.com/course/2001281002?tid=2001392029#/info

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM