為什么要使用backpropagation?
梯度下降不用多說,如果不清楚的可以參考梯度下降算法。
神經網絡的參數集合theta,包括超級多組weight和bais。
要使用梯度下降,就需要計算每一個參數的梯度,但是神經網絡常常有數以萬計,甚至百萬的參數,所以需要使用backpropagation來高效地計算梯度。
backpropagation的推導
backpropagation背后的原理其實很簡單,就是求導的鏈式法則。
我們從上面的公式開始推導。以其中一個神經元為例。
如上面的紅框中所示,根據鏈式法則,l對w的偏導數,等於z對w的偏導數乘以l對z的偏導數。
l對w的梯度可以分為兩部分:
前向傳播:對所有參數求梯度;
后向傳播:對所有激活函數的輸入z求梯度;
前向傳播的梯度求法簡單,就前一層的輸入z對w求偏導數,直接求出就是對應的輸入xi。
只要知道了激活函數的輸出值,就可以輕易算出z/w的梯度,這個過程就是前向傳播。
后向傳播比較復雜,需要再使用鏈式法則,如紅框中所示。l/z的梯度分解為a/z和l/a的梯度。
z對應當前節點的輸入,a對應當前節點的輸出。
a對z的導數圖像如上所示,現在關鍵就是求l對a的偏導數。
為了求出l對a的偏導數,繼續使用鏈式法則,關聯上后面的兩個神經元。
a通過z’和z''間接影響l,l/a的梯度應該是它所連接的所有神經元的梯度之和,不止是上面說的兩項。
z'/a和z''/a的偏導數根據前向傳播計算,分別是w3和w4.
現在問題就轉化成了,求紅框中的兩個問號的梯度/
現在假設兩個問號梯度已知,就可以求出之前l對z的梯度了。
這樣看上去有形成了一個新的網絡,一個新的neural,輸入是l/z'和l/z''的梯度,分別乘上對應權重w3,w4,
經過激活函數(乘以sigma(z)的導數)的作用,輸出l/z的梯度。
現在來看看怎么可以求出l對z的梯度。
第一種情況:當z‘和z’‘為輸出層時。根據鏈式法則,y/z的梯度可以根據對應的激活函數算出了,l/y的梯度是根據Cost function算出來的,這樣問題就解決了。
第二種情況:不是輸出層。就是說還有后續的神經元節點連接,往后繼續使用鏈式法則求導,直至輸出層。
循環計算l對z的梯度,直到輸出層,出現case1的情況,問題也就解決了。
所以,我們就可以從輸出層開始,反向計算l對每層z的梯度,在結合前向傳播得到的梯度,就可以計算出梯度下降所需的梯度了。
而且,反向傳播的復雜度和前向傳播是一樣的,這樣就大大提升了梯度計算的效率。后一層的梯度,乘以相應的w,相加再乘上σ‘(z),就得到了當前層的l/z的梯度。
最后結果就是這樣的: