2021-03-04
數值求導和自動求導
早在高中階段,我們就開始接觸導數,了解過常用函數的求導公式。大學時,我們進一步懂得了用極限定義導數,比如,函數
在
處的導數定義為
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1mJTI3JTVDbGVmdCUyOCt4KyU1Q3JpZ2h0JTI5KyUzRCslNUNsaW1fJTdCJTVDZXBzaWxvbislNUNyaWdodGFycm93KzAlN0QlN0IlNUNmcmFjJTdCZiU1Q2xlZnQlMjgreCslMkIrJTVDZXBzaWxvbislNUNyaWdodCUyOSstK2YrJTVDbGVmdCUyOCt4KyU1Q3JpZ2h0JTI5JTdEJTdCJTVDZXBzaWxvbiU3RCU3RCslNUMlNUM=.png)
然而,這個定義式似乎從來沒有派上過用場,始終束之高閣。因為對我們來說,這個式子是沒法計算的,
趨近於
,超出了我們手工計算的能力范疇。另一方面,各種求導公式和求導法則可以讓我們直接求出導函數的解析解,所以完全不需要根據定義來計算某個點的導數。
但出了校園,形勢就變得微妙起來。很多時候,我們無法求得導函數的解析解,或者求解析解的代價太大。這種情況下,數值求導和自動求導應運而生,這是數學與計算機科學碰撞出的火花,我們將會看到它在工程領域綻放的光芒。
數值求導
如果函數
是個黑盒子,即我們不知道它的解析形式,但給定任意的輸入
,我們都能得到唯一的輸出
。這就是典型的計算機程序的特點,函數的內部實現被封裝在庫中,調用方得不到庫的源碼,這時候要想求函數的導數,就只能用數值求導方法。
本文介紹的數值求導方法,也是最常用的數值求導方法,稱為有限差分(Finite-Difference)。該方法正是源自文章開頭介紹的導數定義式。應用該式時,我們需要去掉
運算符,用一個實際的小量
代入,但是這樣做勢必會引入誤差。顯然,
越小,誤差也會越小,但到底誤差與
之間是怎樣的關系,線性關系還是二次關系,就需要嚴謹的推導才能得知。
使用泰勒定理展開
可得
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1mJTI4eCUyQiU1Q2Vwc2lsb24lMjkrJTNEK2YlMjh4JTI5KyUyQitmJTI3JTI4eCUyOSU1Q2Vwc2lsb24rJTJCKyU1Q2ZyYWMlN0IxJTdEJTdCMiU3RGYlMjclMjclMjh4KyUyQit0JTVDZXBzaWxvbiUyOSU1Q2Vwc2lsb24lNUUyJTJDKyU1Q3F1YWQrJTVDdGV4dCU3QnNvbWUrJTdEK3QrJTVDaW4rJTI4MCUyQysxJTI5KyU1Q3RhZyU3QjcuMiU3RCslNUMlNUMr.png)
如果函數的二階導有限,即存在使得
成立的
,我們就可以把式(7.2)做一次放縮,得到
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUMlN0MrZiUyOHglMkIlNUNlcHNpbG9uJTI5Ky0rZiUyOHglMjkrLStmJTI3JTI4eCUyOSU1Q2Vwc2lsb24rJTVDJTdDKyU1Q2xlKyUyOEwlMkYyJTI5KyU1Q2Vwc2lsb24lNUUyKyU1Q3RhZyU3QjcuMyU3RCslNUMlNUM=.png)
進而求得導數
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1mJTI3JTI4eCUyOSsrJTNEKyU1Q2ZyYWMlN0JmJTI4eCslMkIrJTVDZXBzaWxvbiUyOSstK2YlMjh4JTI5JTdEJTdCJTVDZXBzaWxvbiU3RCslMkIrJTVDZGVsdGFfJTVDZXBzaWxvbiUyQyslNUNxdWFkKyU1Q3RleHQlN0J3aGVyZSslN0QrJTdDKyU1Q2RlbHRhXyU1Q2Vwc2lsb24lN0MrJTVDbGUrJTI4TCUyRjIlMjklNUNlcHNpbG9uKyU1Q3RhZyU3QjcuNCU3RCslNUMlNUM=.png)
式(7.4)把式(7.3)中的不等號轉化為一個小量
,該小量的上界由
決定。此時,我們可以說,有限差分的結果與真實導數值之間的誤差在
這個量級。
看起來似乎讓
越小越好,但實際並沒有這么簡單。計算機的精度是有上限的,任何運算都不可能得到完全精確的結果,而是會存在舍入誤差。假設用來計算式
的量都用雙精度浮點數表示,我們來看看計算機中這些數會有怎樣的舍入誤差。在學習C語言的時候一般會提到,雙精度浮點數由64個二進制位組成,IEEE規定它們的排布如下

最高位是符號位,接下來的11位是指數位,剩下的52位是小數位。現在我們考慮一個問題,對於任意一個可以用浮點數表示的實數,它的真實值與其浮點數表示的值之間相差多少?咋看上去,這個問題的答案似乎是浮點數所能表示的最接近0的實數,也就是指數位取0、小數位取1對應的數。但實際上,由於有效數字位數只有52位,指數越大,有效數字的最低位對應的大小就越大,所以答案是不確定的。我們只能確保真實值與其浮點數表示的值的前52位二進制位是完全相同的,換算成10進制大概是15位。舉例來說,
與其浮點數表示相差大約1,而1與其浮點數表示相差大約
。由此推廣到任意一個數
,它的浮點數可以表示為
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUN0ZXh0JTdCZmwlN0QlMjh4JTI5JTNEeCUyODElMkIlNUNlcHNpbG9uJTI5JTJDKyU1Q3F1YWQrJTVDdGV4dCslN0Ird2hlcmUrJTdEJTdDJTVDZXBzaWxvbiU3QyslNUNsZXErJTVDbWF0aGJmJTdCdSU3RCslNUN0YWclN0JBLjU4JTdEKyU1QyU1Qw==.png)
其中
。可以發現,
其實表示的是相對精度。現在回到式(7.4),如果我們用計算機計算的結果代替理想值,即用
替代
,用
替代
,結果勢必會變化,由式(A.58)可知
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrJTdDKyU1Q3RleHQlN0JmbCU3RCUyOGYlMjh4JTI5JTI5Ky0rZiUyOHglMjklN0MrKyUyNiUzRCslN0NmJTI4eCUyOSslMkIrZiUyOHglMjklNUNlcHNpbG9uKy0rZiUyOHglMjklN0MrJTVDJTVDKyUyNiUzRCslN0NmJTI4eCUyOSslNUNlcHNpbG9uJTdDKyU1QyU1QyslMjYlNUNsZStMX2YrJTdDJTVDZXBzaWxvbiU3QyslNUMlNUMrJTI2JTVDbGUrTF9mKyU1Q21hdGhiZiU3QnUlN0QrJTVDZW5kJTdCYWxpZ25lZCU3RCslNUMlNUM=.png)
其中,
是函數值的界限。同理,
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lN0MrJTVDdGV4dCU3QmZsJTdEJTI4ZiUyOHglMkIlNUNlcHNpbG9uJTI5JTI5Ky0rZiUyOHglMkIlNUNlcHNpbG9uJTI5JTdDKyU1Q2xlK0xfZislNUNtYXRoYmYlN0J1JTdEKyU1QyU1Qw==.png)
這意味着
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUN0ZXh0JTdCZmwlN0QlMjhmJTI4eCUyOSUyOSslM0QrZiUyOHglMjkrJTJCKyU1Q2RlbHRhXyU3QiU1Q2Vwc2lsb25fMSU3RCUyQyslNUNxdWFkKyU1Q3RleHQlN0J3aGVyZSslN0QrJTdDJTVDZGVsdGFfJTdCJTVDZXBzaWxvbl8xJTdEJTdDKyU1Q2xlK0xfZislNUNtYXRoYmYlN0J1JTdEKyU1QyU1QyslNUN0ZXh0JTdCZmwlN0QlMjhmJTI4eCslMkIrJTVDZXBzaWxvbiUyOSUyOSslM0QrZiUyOHgrJTJCKyU1Q2Vwc2lsb24lMjkrJTJCKyU1Q2RlbHRhXyU3QiU1Q2Vwc2lsb25fMiU3RCUyQyslNUNxdWFkKyU1Q3RleHQlN0J3aGVyZSslN0QrJTdDJTVDZGVsdGFfJTdCJTVDZXBzaWxvbl8yJTdEJTdDKyU1Q2xlK0xfZislNUNtYXRoYmYlN0J1JTdEKyU1QyU1Qw==.png)
此時,我們再按照式(7.4)的計算方式計算導數的浮點數表示
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrJTVDdGV4dCU3QmZsJTdEJTI4ZiUyNyUyOHglMjklMjkrKyslMjYlM0QrJTVDZnJhYyU3QiU1Q3RleHQlN0JmbCU3RCUyOGYlMjh4KyUyQislNUNlcHNpbG9uJTI5JTI5Ky0rJTVDdGV4dCU3QmZsJTdEJTI4ZiUyOHglMjklMjklN0QlN0IlNUNlcHNpbG9uJTdEKyUyQislNUNkZWx0YV8lNUNlcHNpbG9uJTVDJTVDKyUyNiUzRCslNUNmcmFjJTdCZiUyOHgrJTJCKyU1Q2Vwc2lsb24lMjkrLStmJTI4eCUyOSslMkIrJTVDZGVsdGFfJTdCJTVDZXBzaWxvbl8yJTdEKy0rJTVDZGVsdGFfJTdCJTVDZXBzaWxvbl8xJTdEJTdEJTdCJTVDZXBzaWxvbiU3RCslMkIrJTVDZGVsdGFfJTVDZXBzaWxvbislNUMlNUMrJTI2JTNEKyU1Q2ZyYWMlN0JmJTI4eCslMkIrJTVDZXBzaWxvbiUyOSstK2YlMjh4JTI5JTdEJTdCJTVDZXBzaWxvbiU3RCslMkIrJTVDZGVsdGErJTVDZW5kJTdCYWxpZ25lZCU3RCslMkMrJTVDcXVhZCslNUN0ZXh0JTdCd2hlcmUrJTdEKyU3QyslNUNkZWx0YSU3QyslNUNsZSslMjhMJTJGMiUyOSU1Q2Vwc2lsb24rJTJCKzJMX2YlNUNtYXRoYmYlN0J1JTdEJTJGJTVDZXBzaWxvbislNUN0YWclN0I3LjUlN0QrJTVDJTVD.png)
現在,誤差不只和
有關,還與計算機的相對精度
有關。如果我們把
對
求導,令導數等於0,可以求得當
時,
取極小值。這個結論說明,
並不是越小越好。實際應用中,我們假設
不會太大也不會太小,所以取
即可。
這種有限差分方法稱為前向差分(forward-difference),因為只計算了當前點和前面的點的函數值,可以想象,這種方法算出來的結果總是會有一些偏差。更好的方法是中心差分(central-difference),定義如下
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1mJTI3JTI4eCUyOSsrJTVDYXBwcm94KyU1Q2ZyYWMlN0JmJTI4eCslMkIrJTVDZXBzaWxvbiUyOSstK2YlMjh4Ky0rJTVDZXBzaWxvbiUyOSU3RCU3QjIlNUNlcHNpbG9uJTdEKyU1Q3RhZyU3QjcuNyU3RCslNUMlNUM=.png)
如果我們仍按照前面的方式分別對
和
泰勒展開,可以發現式(7.7)的近似誤差在
數量級。
那是不是中心差分一定比前向差分好呢?未必。實際中,函數的自變量
通常是一個向量。此時導數也是個向量,有限差分方法需要分別對自變量的各個分量求導。假設
是
維向量,那么前向差分需要計算
次函數值,而中心差分則需要計算
次函數值。可見,中心差分精度高的代價是計算量大,實際使用時需要有所取舍。
稀疏雅克比
目標函數的導數又稱為雅克比矩陣。對於標量函數來說,雅可比矩陣是個行向量,形式如下
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNuYWJsYStmJTI4eCUyOSUzRCU1Q2xlZnQlNUIlNUNiZWdpbiU3QmFycmF5JTdEJTdCJTdEKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK2YlN0QlN0IlNUNwYXJ0aWFsK3hfMSU3RCslMjYrJTVDZnJhYyU3QiU1Q3BhcnRpYWwrZiU3RCU3QiU1Q3BhcnRpYWwreF8yJTdEKyUyNislNUNjZG90cyslMjYrJTVDZnJhYyU3QiU1Q3BhcnRpYWwrZiU3RCU3QiU1Q3BhcnRpYWwreF9uJTdEKyU1Q2VuZCU3QmFycmF5JTdEJTVDcmlnaHQlNUQrJTVDJTVD.png)
剛剛提到過,這種函數的前向差分需要計算
次函數值。而對於向量函數,雅可比矩陣是個
的矩陣,
是目標函數的維度,形式如下
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmVxdWF0aW9uJTJBJTdEKyU1Q25hYmxhK2YlMjh4JTI5KyUzRCsrJTVDYmVnaW4lN0JibWF0cml4JTdEKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK2ZfMSU3RCU3QiU1Q3BhcnRpYWwreF8xJTdEKyUyNislNUNmcmFjJTdCJTVDcGFydGlhbCtmXzElN0QlN0IlNUNwYXJ0aWFsK3hfMiU3RCslMjYrJTVDY2RvdHMrJTI2KyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK2ZfMSU3RCU3QiU1Q3BhcnRpYWwreF9uJTdEKyU1QyU1QyslNUNmcmFjJTdCJTVDcGFydGlhbCtmXzIlN0QlN0IlNUNwYXJ0aWFsK3hfMSU3RCslMjYrJTVDZnJhYyU3QiU1Q3BhcnRpYWwrZl8yJTdEJTdCJTVDcGFydGlhbCt4XzIlN0QrJTI2KyU1Q2Nkb3RzKyUyNiU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK2ZfMiU3RCU3QiU1Q3BhcnRpYWwreF9uJTdEKyU1QyU1QyslNUN2ZG90cysrJTI2KyU1Q3Zkb3RzKyslMjYrKyUyNislNUN2ZG90cysrJTVDJTVDKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK2ZfbSU3RCU3QiU1Q3BhcnRpYWwreF8xJTdEKyUyNislNUNmcmFjJTdCJTVDcGFydGlhbCtmX20lN0QlN0IlNUNwYXJ0aWFsK3hfMiU3RCslMjYrJTVDY2RvdHMrJTI2KyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK2ZfbSU3RCU3QiU1Q3BhcnRpYWwreF9uJTdEKyslNUNlbmQlN0JibWF0cml4JTdEKyU1Q2VuZCU3QmVxdWF0aW9uJTJBJTdEKyU1QyU1Qw==.png)
這時候,矩陣的每一個元素都需要計算一次函數值(這里所說的函數值是指向量函數中的結果向量的每一維),總共需要計算
次。
於是,輸入變量和目標函數的維度越高,雅可比的計算量越大,數值求導很容易變得不再可行。好在某些實際問題中,雅可比矩陣會呈現出稀疏的特征,比如下面這個樣子
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmJtYXRyaXglN0QrJUMzJTk3KyUyNislQzMlOTcrJTI2KyUyNislMjYrJTI2KyU1QyU1QyslQzMlOTcrJTI2KyVDMyU5NyslMjYrJUMzJTk3JTI2KyUyNislMjYrJTVDJTVDKyUyNislQzMlOTclMjYrJUMzJTk3KyUyNislQzMlOTcrJTI2KyUyNislNUMlNUMrJTI2KyUyNislQzMlOTclMjYrJUMzJTk3JTI2KyVDMyU5NyUyNislNUMlNUMrJTI2KyUyNislMjYrJUMzJTk3JTI2KyVDMyU5NyUyNislQzMlOTcrJTVDJTVDKyUyNislMjYrJTI2KyUyNislQzMlOTcrJTI2KyVDMyU5NysrJTVDZW5kJTdCYm1hdHJpeCU3RCslNUN0YWclN0I3LjEzJTdEKyU1QyU1Qw==.png)
只有對角的條帶狀元素有值,其它位置都是0。如何利用這一特點加速雅可比計算呢?
把有限微分公式推廣到向量情況,對於雅可比矩陣中的任意一個值
,需要取
,其中
是第
維的單位基向量。仔細觀察式(7.13)的第一列和第四列,可以發現,
只對導數中的
和
分量產生影響,
只對導數中的
、
和
分量產生影響,這說明
和
關於導數值是相互獨立的。這樣我們就可以取
,一次性計算出
和
。
這就是稀疏雅可比帶來的效率提升。推廣到一般情況,即使雅可比矩陣的形式不是條帶狀,也存在一些方法可以找到相互獨立的若干個分量,從而加速計算。不過可以預見的是,對任意形狀的雅可比矩陣搜索獨立分量並不容易,它本質上是一個圖指派問題(graph assignment)。
條帶狀的雅可比矩陣並不罕見,如果你接觸過視覺SLAM,就會發現視覺SLAM中的雅可比矩陣都是條帶狀的。這是因為在整個時間序列中,某個時間點的誤差項只和少數的若干個位姿和路標點發生關聯,你不可能在任何位置觀測到所有的路標點。
既然雅可比矩陣可以有限差分,海森矩陣也可以。而且海森矩陣是對稱的,所以我們只需要計算一半的元素即可。海森矩陣也有常見的稀疏形式,通常是箭頭狀,利用這一稀疏性也可以大大減少計算量。由於篇幅限制,本文就不詳細介紹了。
自動求導
數值求導很方便,不需要知道目標函數的表達式就可以計算。但是大多數情況是,我們知道目標函數的表達式,但這個表達式太復雜,很難手動計算導數的解析形式。或者目標函數是隨時變化的,無法事先求出。一個典型的例子就是神經網絡,每次我們更改網絡的層數、維度、激活函數、損失函數等等,都會使得目標函數發生變化。手動求導是不可行的,數值求導雖然可行但精度難以保證,且巨大的自變量維度會導致雅可比矩陣計算量突破天際。於是,自動求導的出現解決了這個問題。
很多人搞不清楚數值求導、符號求導和自動求導的區別。本文沒有提到符號求導,這是一種直接用計算機計算導數解析解的方法,存在於MATLAB、Mathematica等數值計算軟件中。而自動求導不同於符號求導,它並不計算解析解,而是利用鏈式法則將復雜的函數拆分成一個個獨立的計算單元,分別對這些小單元求導,最后再合並得到完整的導數。將函數拆分為多個計算單元后,我們稱之為計算圖(computational graph),這是一個有向無環圖,標記了數據的流向。我們用一個例子來詳細介紹這一方法。
目標函數
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1mJTI4eCUyOSUzRCU1Q2xlZnQlMjh4XyU3QjElN0QreF8lN0IyJTdEKyU1Q3Npbit4XyU3QjMlN0QlMkJlJTVFJTdCeF8lN0IxJTdEK3hfJTdCMiU3RCU3RCU1Q3JpZ2h0JTI5KyUyRit4XyU3QjMlN0QrJTVDdGFnJTdCNy4yNiU3RCslNUMlNUM=.png)
拆分為多個計算單元
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFycmF5JTdEJTdCbCU3RCt4XyU3QjQlN0QlM0R4XyU3QjElN0QrJTJBK3hfJTdCMiU3RCslNUMlNUMreF8lN0I1JTdEJTNEJTVDc2luK3hfJTdCMyU3RCslNUMlNUMreF8lN0I2JTdEJTNEZSU1RSU3QnhfJTdCNCU3RCU3RCslNUMlNUMreF8lN0I3JTdEJTNEeF8lN0I0JTdEKyUyQSt4XyU3QjUlN0QrJTVDJTVDK3hfJTdCOCU3RCUzRHhfJTdCNiU3RCUyQnhfJTdCNyU3RCslNUMlNUMreF8lN0I5JTdEJTNEeF8lN0I4JTdEKyUyRit4XyU3QjMlN0QrJTVDZW5kJTdCYXJyYXklN0QrJTVDdGFnJTdCNy4yNyU3RCslNUMlNUM=.png)
按照計算順序將其表示成有向圖如下

我們最終的目的是求導數
。以第一項
為例,根據鏈式法則
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwrZiU3RCU3QiU1Q3BhcnRpYWwreF8xJTdEKyUyNiUzRCslNUNmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4XzglN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwreF84JTdEJTdCJTVDcGFydGlhbCt4XzElN0QrJTVDJTVDKyUyNiUzRCslNUNmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4XzglN0QrJTVDbGVmdCUyOCslNUNmcmFjJTdCJTVDcGFydGlhbCt4XzglN0QlN0IlNUNwYXJ0aWFsK3hfNyU3RCslNUNmcmFjJTdCJTVDcGFydGlhbCt4XzclN0QlN0IlNUNwYXJ0aWFsK3hfMSU3RCslMkIrJTVDZnJhYyU3QiU1Q3BhcnRpYWwreF84JTdEJTdCJTVDcGFydGlhbCt4XzYlN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwreF82JTdEJTdCJTVDcGFydGlhbCt4XzElN0QrJTVDcmlnaHQlMjkrJTVDJTVDKyUyNiUzRCslNUNmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4XzglN0QrJTVDbGVmdCUyOCslNUNmcmFjJTdCJTVDcGFydGlhbCt4XzglN0QlN0IlNUNwYXJ0aWFsK3hfNyU3RCslNUNsZWZ0JTI4KyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3hfNyU3RCU3QiU1Q3BhcnRpYWwreF80JTdEKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3hfNCU3RCU3QiU1Q3BhcnRpYWwreF8xJTdEKyU1Q3JpZ2h0JTI5KyUyQislNUNmcmFjJTdCJTVDcGFydGlhbCt4XzglN0QlN0IlNUNwYXJ0aWFsK3hfNiU3RCslNUNsZWZ0JTI4KyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3hfNiU3RCU3QiU1Q3BhcnRpYWwreF80JTdEKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3hfNCU3RCU3QiU1Q3BhcnRpYWwreF8xJTdEKyU1Q3JpZ2h0JTI5KyU1Q3JpZ2h0JTI5KyU1QyU1QyslNUNlbmQlN0JhbGlnbmVkJTdEKyU1QyU1Qw==.png)
這樣分解之后,每一個需要計算的項都是相鄰節點之間簡單函數的導數。而且計算圖從左到右對應着上式從內到外,我們只需要一次前向傳播(forward sweep)即可得出結果。
自動求導就是這么簡單,但還不止於此。如果我們換種方式使用鏈式法則,像下面這樣
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwrZiU3RCU3QiU1Q3BhcnRpYWwreF8xJTdEKyUyNiUzRCslNUNmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4XzQlN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwreF80JTdEJTdCJTVDcGFydGlhbCt4XzElN0QrJTVDJTVDKyUyNiUzRCsrJTVDbGVmdCUyOCslNUNmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4XzYlN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwreF82JTdEJTdCJTVDcGFydGlhbCt4XzQlN0QrJTJCKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK2YlN0QlN0IlNUNwYXJ0aWFsK3hfNyU3RCslNUNmcmFjJTdCJTVDcGFydGlhbCt4XzclN0QlN0IlNUNwYXJ0aWFsK3hfNCU3RCslNUNyaWdodCUyOSslNUNmcmFjJTdCJTVDcGFydGlhbCt4XzQlN0QlN0IlNUNwYXJ0aWFsK3hfMSU3RCU1QyU1QyslMjYlM0QrKyU1Q2xlZnQlMjgrJTVDbGVmdCUyOCslNUNmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4XzglN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwreF84JTdEJTdCJTVDcGFydGlhbCt4XzYlN0QrJTVDcmlnaHQlMjklNUNmcmFjJTdCJTVDcGFydGlhbCt4XzYlN0QlN0IlNUNwYXJ0aWFsK3hfNCU3RCsrJTJCKyU1Q2xlZnQlMjgrJTVDZnJhYyU3QiU1Q3BhcnRpYWwrZiU3RCU3QiU1Q3BhcnRpYWwreF84JTdEKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3hfOCU3RCU3QiU1Q3BhcnRpYWwreF83JTdEKyU1Q3JpZ2h0JTI5JTVDZnJhYyU3QiU1Q3BhcnRpYWwreF83JTdEJTdCJTVDcGFydGlhbCt4XzQlN0QrKyU1Q3JpZ2h0JTI5KyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3hfNCU3RCU3QiU1Q3BhcnRpYWwreF8xJTdEKyU1QyU1QyslNUNlbmQlN0JhbGlnbmVkJTdEKyU1QyU1Qw==.png)
有沒有發現和上面的差別?第一種鏈式法則是把計算圖從右到左依次展開的,而第二種鏈式法則是把計算圖從左到右依次展開的。這樣的結果是,第二種方法需要先計算最右側的導數,然后依次向左計算。因此它不僅需要一次前向傳播,還需要一次反向傳播(reverse sweep)。
在自動求導中,第一種方法稱為前向模式(forward mode),第二種方法稱為逆向模式(reverse mode)。看起來似乎逆向模式計算更復雜,而且需要保存前向傳播產生的中間變量,但實際的深度學習框架中都采用的是逆向模式,這是為什么呢?
這是由目標函數的形式決定的。在深度學習以及大部分優化問題中,目標函數都是標量函數
,自變量很多,但函數值只有一個。這種情況下,使用前向模式,需要分別對每個自變量做一次前向傳播,也就是
次前向傳播。如果使用逆向模式,雖然從上面的公式看起來好像也需要
次計算,但實際上計算的時候是從括號內部向外逐項計算的,對應到計算圖上是從右向左計算。這和從左向右很不一樣,因為從右向左的起始點只有一個,因此算出來的結果最終對左側的所有輸入節點都是通用的。也就是說,從右向左計算的中間節點的值可以共享給所有輸入變量,從而減少計算次數。
這個道理可以推廣到向量目標函數
。使用前向模式,我們需要完整地計算
次前向傳播。使用逆向模式,我們需要完整地計算
次反向傳播。因此,當
時,使用逆向模式比較好,當
時,使用前向模式比較好。
與數值求導中的稀疏雅可比類似,自動求導時也可以考慮稀疏性。仍然是找到計算圖中相互獨立的輸入節點,在前向傳播時同時考慮這些輸入即可。
總結
數值求導和自動求導是數值最優化中的基本操作,了解它們的原理,可以幫助我們更深刻地理解優化算法的性能。本文的講解比較粗淺,旨在幫助大家了解基本概念,感興趣的同學可以繼續查閱相關資料。
見:https://zhuanlan.zhihu.com/p/109755675
