如果能二秒內在腦袋里解出下面的問題,本文便結束了。
已知:
,其中
。
求:
,
,
。
到這里,請耐心看完下面的公式推導,無需長久心里建設。
首先,反向傳播的數學原理是“求導的鏈式法則” :
設
和
為
的可導函數,則
。
接下來介紹
- 矩陣、向量求導的維數相容原則
- 利用維數相容原則快速推導反向傳播
- 編程實現前向傳播、反向傳播
- 卷積神經網絡的反向傳播
快速矩陣、向量求導
這一節展示如何使用鏈式法則、轉置、組合等技巧來快速完成對矩陣、向量的求導
一個原則維數相容,實質是多元微分基本知識,沒有在課本中找到下列內容,維數相容原則是我個人總結:
維數相容原則:通過前后換序、轉置 使求導結果滿足矩陣乘法且結果維數滿足下式:
如果
,
,那么
。
利用維數相容原則解上例:
step1:把所有參數當做實數來求導,
,
依據鏈式法則有
,
,![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjJTdCJTVDcGFydGlhbCtKJTdEJTdCJTVDcGFydGlhbCt5JTdEJTNELTIlMjhYdy15JTI5.png)
可以看出除了
,
和
的求導結果在維數上連矩陣乘法都不能滿足。
step2:根據step1的求導結果,依據維數相容原則做調整:前后換序、轉置
依據維數相容原則
,但
中
、
,自然得調整為
;
同理:
,但
中
、
,那么通過換序、轉置我們可以得到維數相容的結果
。
對於矩陣、向量求導:
- “當做一維實數使用鏈式法則求導,然后做維數相容調整,使之符合矩陣乘法原則且維數相容”是快速准確的策略;
- “對單個元素求導、再整理成矩陣形式”這種方式整理是困難的、過程是緩慢的,結果是易出錯的(不信你試試)。
如何證明經過維數相容原則調整后的結果是正確的呢?直覺!簡單就是美...
快速反向傳播
神經網絡的反向傳播求得“各層”參數
和
的導數,使用梯度下降(一階GD、SGD,二階LBFGS、共軛梯度等)優化目標函數。
接下來,展示不使用下標的記法(
,
or
)直接對
和
求導,反向傳播是鏈式法則和維數相容原則的完美體現,對每一層參數的求導利用上一層的中間結果完成。
這里的標號,參考UFLDL教程 - Ufldl
前向傳播:
(公式1)
(公式2)
為第
層的中間結果,
為第
層的激活值,其中第
層包含元素:輸入
,參數
、
,激活函數
,中間結果
,輸出
。
設神經網絡的損失函數為
(這里不給出具體公式,可以是交叉熵、MSE等),根據鏈式法則有:
這里記
,其中
、
可由 公式1 得出,
加轉置符號
是根據維數相容原則作出的調整。
如何求
? 可使用如下遞推(需根據維數相容原則作出調整):
其中
、
。
那么我們可以從最頂層逐層往下,便可以遞推求得每一層的![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNkZWx0YSslNUUlN0IlMjhsJTI5JTdEKyUzRCslNUNmcmFjJTdCJTVDcGFydGlhbCtKJTI4VyUyQ2IlMjklN0QlN0IlNUNwYXJ0aWFsK3olNUUlN0IlMjhsJTI5JTdEJTdE.png)
注意:
是逐維求導,在公式中是點乘的形式。
反向傳播整個流程如下:
1) 進行前向傳播計算,利用前向傳播公式,得到隱藏層和輸出層 的激活值。
2) 對輸出層(第
層),計算殘差:
(不同損失函數,結果不同,這里不給出具體形式)
3) 對於
的隱藏層,計算:
4) 計算各層參數
、
偏導數:
編程實現
大部分開源library(如:caffe,Kaldi/src/{nnet1,nnet2})的實現通常把
、
作為一個layer,激活函數
作為一個layer(如:sigmoid、relu、softplus、softmax)。
反向傳播時分清楚該層的輸入、輸出即能正確編程實現,如:
(公式1)
(公式2)
(1)式AffineTransform/FullConnected層,以下是偽代碼:
注: out_diff =
是上一層(Softmax 或 Sigmoid/ReLU的 in_diff)已經求得:
(公式 1-1)
(公式 1-2)
(公式 1-3)
(2)式激活函數層(以Sigmoid為例)
注:out_diff =
是上一層AffineTransform的in_diff,已經求得,
在實際編程實現時,in、out可能是矩陣(通常以一行存儲一個輸入向量,矩陣的行數就是batch_size),那么上面的C++代碼就要做出變化(改變前后順序、轉置,把函數參數的Vector換成Matrix,此時Matrix out_diff 每一行就要存儲對應一個Vector的diff,在update的時候要做這個batch的加和,這個加和可以通過矩陣相乘out_diff*input(適當的轉置)得到。
如果熟悉SVD分解的過程,通過SVD逆過程就可以輕松理解這種通過乘積來做加和的技巧。
丟掉那些下標記法吧!
卷積層求導
卷積怎么求導呢?實際上卷積可以通過矩陣乘法來實現(是否旋轉無所謂的,對稱處理,caffe里面是不是有image2col),當然也可以使用FFT在頻率域做加法。
那么既然通過矩陣乘法,維數相容原則仍然可以運用,CNN求導比DNN復雜一些,要做些累加的操作。具體怎么做還要看編程時選擇怎樣的策略、數據結構。
快速矩陣、向量求導之維數相容大法已成。

![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiaWd0cmlhbmdsZWRvd25fJTdCVyU1RSU3QiUyOGwlMjklN0QlN0RKJTI4VyUyQ2IlMjklM0QlNUNmcmFjJTdCJTVDcGFydGlhbCtKJTI4VyUyQ2IlMjklN0QlN0IlNUNwYXJ0aWFsK3olNUUlN0IlMjhsJTJCMSUyOSU3RCU3RCslNUNmcmFjJTdCJTVDcGFydGlhbCt6JTVFJTdCJTI4bCUyQjElMjklN0QlN0QlN0IlNUNwYXJ0aWFsK1clNUUlN0IlMjhsJTI5JTdEJTdEJTNEJTVDZGVsdGErJTVFJTdCJTI4bCUyQjElMjklN0QlMjhhKyU1RSU3QiUyOGwlMjklN0QlMjklNUVUKw==.png)
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiaWd0cmlhbmdsZWRvd25fJTdCYiU1RSU3QiUyOGwlMjklN0QlN0RKJTI4VyUyQ2IlMjklM0QlNUNmcmFjJTdCJTVDcGFydGlhbCtKJTI4VyUyQ2IlMjklN0QlN0IlNUNwYXJ0aWFsK3olNUUlN0IlMjhsJTJCMSUyOSU3RCU3RCslNUNmcmFjJTdCJTVDcGFydGlhbCt6JTVFJTdCJTI4bCUyQjElMjklN0QlN0QlN0IlNUNwYXJ0aWFsK2IlNUUlN0IlMjhsJTI5JTdEJTdEJTNEJTVDZGVsdGErJTVFJTdCJTI4bCUyQjElMjklN0Q=.png)
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNkZWx0YSslNUUlN0IlMjhsJTI5JTdEJTNEJTVDZnJhYyU3QiU1Q3BhcnRpYWwrSiU3RCU3QiU1Q3BhcnRpYWwreiU1RSU3QiUyOGwlMjklN0QlN0QlM0QlNUNmcmFjJTdCJTVDcGFydGlhbCtKJTdEJTdCJTVDcGFydGlhbCt6JTVFJTdCJTI4bCUyQjElMjklN0QlN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwreiU1RSU3QiUyOGwlMkIxJTI5JTdEJTdEJTdCJTVDcGFydGlhbCthJTVFJTdCJTI4bCUyOSU3RCU3RCslNUNmcmFjJTdCJTVDcGFydGlhbCthJTVFJTdCJTI4bCUyOSU3RCU3RCU3QiU1Q3BhcnRpYWwreiU1RSU3QiUyOGwlMjklN0QlN0QlM0QrJTBBJTI4JTI4VyU1RSU3QiUyOGwlMjklN0QlMjklNUUlN0JUJTdEJTVDZGVsdGErJTVFJTdCJTI4bCUyQjElMjklN0QlMjkrJTVDY2RvdCsrZiUyNyUyOHolNUUlN0IlMjhsJTI5JTdEJTI5.png)
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNkZWx0YSslNUUlN0IlMjhsJTI5JTdEJTNEJTVDZnJhYyU3QiU1Q3BhcnRpYWwrSiU3RCU3QiU1Q3BhcnRpYWwreiU1RSU3QiUyOGwlMjklN0QlN0QlM0QlNUNmcmFjJTdCJTVDcGFydGlhbCtKJTdEJTdCJTVDcGFydGlhbCt6JTVFJTdCJTI4bCUyQjElMjklN0QlN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwreiU1RSU3QiUyOGwlMkIxJTI5JTdEJTdEJTdCJTVDcGFydGlhbCthJTVFJTdCJTI4bCUyOSU3RCU3RCU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK2ElNUUlN0IlMjhsJTI5JTdEJTdEJTdCJTVDcGFydGlhbCt6JTVFJTdCJTI4bCUyOSU3RCU3RCUzRCUwQSUyOCUyOFclNUUlN0IlMjhsJTI5JTdEJTI5JTVFJTdCVCU3RCU1Q2RlbHRhKyU1RSU3QiUyOGwlMkIxJTI5JTdEJTI5KyU1Q2Nkb3QrZiUyNyUyOHolNUUlN0IlMjhsJTI5JTdEJTI5.png)
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiaWd0cmlhbmdsZWRvd25fJTdCVyU1RSU3QiUyOGwlMjklN0QlN0RKJTI4VyUyQ2IlMjklM0QlNUNmcmFjJTdCJTVDcGFydGlhbCtKJTI4VyUyQ2IlMjklN0QlN0IlNUNwYXJ0aWFsK3olNUUlN0IlMjhsJTJCMSUyOSU3RCU3RCslNUNmcmFjJTdCJTVDcGFydGlhbCt6JTVFJTdCJTI4bCUyQjElMjklN0QlN0QlN0IlNUNwYXJ0aWFsK1clNUUlN0IlMjhsJTI5JTdEJTdEJTNEJTVDZGVsdGErJTVFJTdCJTI4bCUyQjElMjklN0QlMjhhKyU1RSU3QiUyOGwlMjklN0QlMjklNUVU.png)


![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1pbiU1Q19kaWZmKyUzRCslNUNmcmFjJTdCJTVDcGFydGlhbCtKJTdEJTdCJTVDcGFydGlhbCt6JTVFJTdCJTI4bCUyQjElMjklN0QlN0QrJTNEKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK0olN0QlN0IlNUNwYXJ0aWFsK2ElNUUlN0IlMjhsJTJCMSUyOSU3RCU3RCslNUNmcmFjJTdCJTVDcGFydGlhbCthJTVFJTdCJTI4bCUyQjElMjklN0QlN0QlN0IlNUNwYXJ0aWFsK3olNUUlN0IlMjhsJTJCMSUyOSU3RCU3RCslM0Qrb3V0JTVDX2RpZmYrJTVDY2RvdCtvdXQrJTVDY2RvdCslMjgxLW91dCUyOQ==.png)