相信每一個剛剛入門神經網絡(現在叫深度學習)的同學都一定在反向傳播的梯度推導那里被折磨了半天。在各種機器學習的課上明明聽得非常明白,神經網絡無非就是正向算一遍Loss,反向算一下每個參數的梯度,然后大家按照梯度更新就好了。問題是梯度到底怎么求呢?課上往往舉的是標量的例子,可是一到你做作業的時候就發現所有的東西都是vectorized的,一個一個都是矩陣。矩陣的微分操作大部分人都是不熟悉的,結果使得很多人在梯度的推導這里直接選擇死亡。我曾經就是其中的一員,做CS231n的Assignment 1里面那幾個簡單的小導數都搞得讓我懷疑人生。
我相信很多人都看了不少資料,比如CS231n的講師Karpathy推薦的這一篇矩陣求導指南http://cs231n.stanford.edu/vecDerivs.pdf,但是經過了幾天的折磨以后,我發現事實上根本就不需要去學習這些東西。在神經網絡中正確計算梯度其實非常簡單,只需要把握好下面的兩條原則即可。這兩條原則非常適合對矩陣微分不熟悉的同學,雖然看起來並不嚴謹,但是有效。
1. 用好維度分析,不要直接求導
神經網絡中求梯度,第一原則是:如果你對矩陣微分不熟悉,那么永遠不要直接計算一個矩陣對另一個矩陣的導數。我們很快就可以看到,在神經網絡中,所有的矩陣對矩陣的導數都是可以通過間接的方法,利用求標量導數的那些知識輕松求出來的。而這種間接求導數的方法就是維度分析。我認為維度分析是神經網絡中求取梯度最好用的技巧,沒有之一。用好維度分析,你就不用一個一個地去分析矩陣當中每個元素究竟是對誰怎么求導的,各種求和完了以后是左乘還是右乘,到底該不該轉置等等破事,簡直好用的不能再好用了。這一技巧在Karpathy的Course Note上也提到了一點。
什么叫維度分析?舉一個最簡單的例子。設某一層的Forward Pass為,X是NxD的矩陣,W是DxC的矩陣,b是1xC的矩陣,那么score就是一個NxC的矩陣。現在上層已經告訴你L對score的導數是多少了,我們求L對W和b的導數。
我們已經知道一定是一個NxC的矩陣(因為Loss是一個標量,score的每一個元素變化,Loss也會隨之變化),那么就有
現在問題來了,score是一個矩陣,W也是個矩陣,矩陣對矩陣求導,怎么求啊?如果你對矩陣微分不熟悉的話,到這里就直接懵逼了。於是很多同學都出門右轉去學習矩陣微分到底怎么搞,看到那滿篇的推導過程就感到一陣惡心,之后就提前走完了從入門到放棄,從深度學習到深度厭學的整個過程。
其實我們沒有必要直接求score對W的導數,我們可以利用另外兩個導數間接地把算出來。首先看看它是多大的。我們知道
一定是DxC的(和W一樣大),而
是NxC的,哦那你瞬間就發現了
一定是DxN的,因為(DxN)x(NxC)=>(DxC),並且你還發現你隨手寫的這個式子右邊兩項寫反了,應該是
。
那好,我們已經知道了是DxN的,那就好辦了。既然score=XW+b,如果都是標量的話,score對W求導,本身就是X;X是NxD的,我們要DxN的,那就轉置一下唄,於是我們就得出了:
完事了。
你看,我們並沒有直接去用諸如這種細枝末節的一個一個元素求導的方式推導
,而是利用
再加上熟悉的標量求導的知識,就把這個矩陣求導給算出來了。這就是神經網絡中求取導數的正確姿勢。
為什么這一招總是有效呢?這里的關鍵點在於Loss是一個標量,而標量對一個矩陣求導,其大小和這個矩陣的大小永遠是一樣的。那么,在神經網絡里,你永遠都可以執行這個“知二求一”的過程,其中的“二”就是兩個Loss對參數的導數,另一個是你不會求的矩陣對矩陣的導數。首先把你沒法直接求的矩陣導數的大小給計算出來,然后利用你熟悉的標量求導的方法大概看看導數長什么樣子,最后湊出那個目標大小的矩陣來就好了。
那呢?我們來看看,
是NxC的,
是1xC的,
看起來像1,那聰明的你肯定想到
其實就是1xN個1了,因為(1xN)x(NxC)=>(1xC)。其實這也就等價於直接對d_score的第一維求個和,把N降低成1而已。
多說一句,這個求和是怎么來的?原因實際上在於所謂的“廣播”機制。你會發現,XW是一個NxC的矩陣,但是b只是一個1xC的矩陣,按理說,這倆矩陣形狀不一樣,是不能相加的。但是我們都知道,實際上我們想做的事情是讓XW的每一行都加上b。也就是說,我們把b的第一維復制了N份,強行變成了一個NxC的矩陣,然后加在了XW上(當然這件事實際上是numpy幫你做的)。那么,當你要回來求梯度的時候,既然每一個b都參與了N行的運算,那就要把每一份的梯度全都加起來求個和的。因為求導法則告訴我們,如果一個變量參與了多個運算,那就要把它們的導數加起來。這里借用一下
的圖,相信大家可以看得更明白。
總之,不要試圖在神經網絡里面直接求矩陣對矩陣的導數,而要用維度分析間接求,這樣可以為你省下很多不必要的麻煩。
2. 用好鏈式法則,不要一步到位
我曾經覺得鏈式法則簡直就是把簡單的問題搞復雜,復合函數求導這種東西高考的時候我們就都會了,還用得着一步一步地往下拆?比如,我一眼就能看出來
,還用得着先把
當成一個中間函數么?
不幸的是,在神經網絡里面,你會發現事情沒那么容易。上面的這些推導只在標量下成立,如果w,x和b都是矩陣的話,我們很容易就感到無從下筆。還舉上面這個例子,設,我們要求
,那么我們直接就可以寫出
L對H的導數,是反向傳播當中上一層會告訴你的,但問題是H對W的導數怎么求呢?
如果你學會了剛才的維度分析法,那么你可能會覺得是一個DxN的矩陣。然后就會發現沒有任何招可以用了。事實上,卡殼的原因在於,
根本不是一個矩陣,而是一個4維的tensor。對這個鬼玩意的運算初學者是搞不定的。准確的講,它也可以表示成一個矩陣,但是它的大小並不是DxN,而且它和
的運算也不是簡單的矩陣乘法,會有向量化等等的過程。有興趣的同學可以參考這篇文章,里面有一個例子講解了如何直接求這個導數:矩陣求導術(下)。
這是一個剛學完反向傳播的初學者很容易踩到的陷阱:試圖不設中間變量,直接就把目標參數的梯度給求出來。如果這么去做的話,很容易在中間碰到這種非矩陣的結構,因為理論上矩陣對矩陣求導求出來是一個4維tensor,不是我們熟悉的二維矩陣。除非你完全掌握了上面那篇reference當中的數學技巧,不然你就只能干瞪眼了。
但是,如果你不直接求取對W的導數,而把當做一個中間變量的話,事情就簡單的多了。因為如果每一步求導都只是一個簡單二元運算的話,那么即使是矩陣對矩陣求導,求出來也仍然是一個矩陣,這樣我們就可以用維度分析法往下做了。
設,則有
利用維度分析:dS是NxC的,dH是NxC的,考慮到,那么容易想到
也是NxC的,也就是
,這是一個element-wise的相乘;所以
;
再求,用上一部分的方法,很容易求得
,所以就求完了。
有了這些結果,我們不妨回頭看看一開始的那個式子:,如果你錯誤地認為
是一個DxN的矩陣的話,再往下運算:
我們已經知道,這兩個矩陣一個是NxC的,一個是DxN的,無論怎么相乘,也得不出DxN的矩陣。矛盾就是出在H對W的導數其實並不是一個矩陣。但是如果使用鏈式法則運算的話,我們就可以避開這個復雜的tensor,只使用矩陣運算和標量求導就搞定神經網絡中的梯度推導。
借助這兩個技巧,已經足以計算任何復雜的層的梯度。下面我們來實戰一個:求Softmax層的梯度。
Softmax層往往是輸出層,其Forward Pass公式為:
,
,
假設輸入X是NxD的,總共有C類,那么W顯然應該是DxC的,b是1xC的。其中就是第i個樣本預測的其正確class的概率。關於softmax的知識在這里就不多說了。我們來求Loss關於W, X和b的導數。為了簡便起見,下面所有的d_xxx指的都是Loss對xxx的導數。
我們首先把Loss重新寫一下,把P代入進去:
不要一步到位,我們把前面一部分和后面一部分分開看。設, rowsum就是每一行的score指數和,因此是Nx1的,那么就有
先看d_score,其大小與score一樣,是NxC的。你會發現如果扔掉前面的1/N不看,d_score其實就是一堆0,然后在每一行那個正確的class那里為-1;寫成python代碼就是
d_score = np.zeros_like(score)
d_score[range(N),y] -= 1
然后看d_rowsum,其實就是,非常簡單。
現在我們關注,需要注意的是我們不要直接求
是什么,兩個都是矩陣,不好求;相反,我們求
是多少。我們會發現上面我們求了一個d_score,這里又求了一個d_score,這說明score這個矩陣參與了兩個運算,這是符合這里Loss的定義的。求導法則告訴我們,當一個變量參與了兩部分運算的時候,把這兩部分的導數加起來就可以了。
這一部分的d_score就很好求了:
,左邊是NxC的,右邊已知的是Nx1的,那么剩下的有可能是1xC的,也有可能是NxC的。這個時候就要分析一下了。我們會發現右邊應該是NxC的,因為每一個score都只影響一個rowsum的元素,因此我們不應該求和。NxC的矩陣就是
自己,所以我們就很容易得出:
# 實際上,d_rowsum往往是一個長度為N的一位數組,因此我們先用np.newaxis把它的shape由N升維到Nx1, # 這樣就可以使用廣播機制(Nx1 * NxC) # 然后用乘號做element wise相乘。 d_score += d_rowsum[:, np.newaxis] * (np.exp(score)) d_score /= N #再把那個1/N給補上
這樣我們就完成了對score的求導,之后score對W, X和b的求導,相信你也就會了。
當然,如果你注意一下的話,你會發現其實第二部分的那個式子就是P矩陣。不過如果你沒有注意到這一點也無所謂,用這套方法也可以求出d_score是多少。
利用同樣的方法,現在看看那個卡住無數人的Batch Normalization層的梯度推導,是不是也感到不那么困難了?
希望本文可以為剛剛入門神經網絡的同學提供一些幫助,如有錯漏歡迎指出。
謝謝!