在訓練模型時,我們可以基於梯度使用不同的優化器(optimizer,或者稱為“優化算法”)來最小化損失函數。這篇文章對常用的優化器進行了總結。
BGD
BGD 的全稱是 Batch Gradient Descent,中文名稱是批量梯度下降。顧名思義,BGD 根據整個訓練集計算梯度進行梯度下降
其中,\(J(\theta)\) 是根據整個訓練集計算出來的損失。
- 優點
- 當損失函數是凸函數(convex)時,BGD 能收斂到全局最優;當損失函數非凸(non-convex)時,BGD 能收斂到局部最優;
- 缺點
- 每次都要根據全部的數據來計算梯度,速度會比較慢;
- BGD 不能夠在線訓練,也就是不能根據新數據來實時更新模型;
SGD
SGD 的全稱是 Stochastic Gradient Descent,中文名稱是隨機梯度下降。和 BGD 相反,SGD 每次只使用一個訓練樣本來進行梯度更新:
其中,\(J(\theta;x^{(i)};y^{(i)})\) 是只根據樣本 \((x^{(i)};y^{(i)})\) 計算出的損失。
- 優點
- SGD 每次只根據一個樣本計算梯度,速度較快;
- SGD 可以根據新樣本實時地更新模型;
- 缺點
- SGD 在優化的過程中損失的震盪會比較嚴重;
- SGD 在優化的過程中損失的震盪會比較嚴重;
MBGD
MBGD 的全稱是 Mini-batch Gradient Descent,中文名稱是小批量梯度下降。MBGD 是 BGD 和 SGD 的折中。MBGD 每次使用包含 m 個樣本的小批量數據來計算梯度
其中,\(m\) 為小批量的大小,范圍是 \([1, n]\),\(n\) 為訓練集的大小;\(J(\theta;x^{(i:i+m)};y^{(i:i+m)})\) 是根據第 \(i\) 個樣本到第 \(i+m\) 個樣本計算出來的損失。
當 \(m==1\) 時,MBGD 變為 SGD;當 \(m==n\) 時,MBGD 變為 BGD。
- 優點
- 收斂更加穩定;
- 可以利用高度優化的矩陣庫來加速計算過程;
- 缺點
- 選擇一個合適的學習率比較困難;
- 相同的學習率被應用到了所有的參數,我們希望對出現頻率低的特征進行大一點的更新,所以我們希望對不同的參數應用不同的學習率;
- 容易被困在鞍點(saddle point);
上圖的紅點就是一個鞍點。上面 MBGD 的 3 個缺點也可以說是 SGD 和 BGD 的 3 個缺點。為了解決這 3 個缺點,研究人員提出了 Momentum、Adagrad、RMSprop、Adadelta、Adam 等優化器。在這介紹這些優化器之前,需要介紹一下指數加權平均(Exponentially Weighted Sum),因為這些改進的優化器或多或少都用了它。
指數加權平均
假設用 \(\theta_t\) 表示一年中第 \(t\) 天的溫度,\(t\in[1,365]\)。我們以天為橫軸,以溫度為縱軸,可以得到下圖
如果我們想要獲得這些數據的局部平均或滑動平均,我們可以設置一個變量 \(v_t\),\(v_t\) 的計算方法如下
當 \(t==1\) 時,我們令 \(v_t=0\)。這樣,\(v_t\) 就約等於第 t 天之前 \(\frac{1}{1-\beta}\) 天的平均溫度(局部平均)。例如,當 \(\beta=0.9\) 時,\(v_t\) 就約等於第 \(t\) 天前 \(\frac{1}{1-0.9}=10\) 天的平均溫度。我們計算出 \(v_t\) 可以得到下圖中的紅色曲線
可以看到,\(v_t\) 對原始數據做了平滑,降低了原始數據的震盪程度。
當我們將 \(\beta\) 設為 0.98 並計算 \(v_t\),可以得到下圖中的綠色曲線
偏差修正
當我們將 \(\beta\) 設為 0.98 並使用公式 \(v_t = \beta v_{t-1} + (1-\beta)\theta_t\) 計算 \(v_t\) 並將其畫在坐標系中,我們得到的其實不是上圖中的綠色曲線,而是下圖中的紫色曲線
可以看到,紫色曲線在后半段和藍色曲線是重合的,前半段有一些偏差,而且紫色曲線的剛開始時非常接近於 0 的,原因是我們設置 \(v_1=0\),所以剛開始的 \(v_t\) 會比較接近 0,也就不能代表前 \(\frac{1}{1-\beta}\) 天的平均溫度。為了修正這個偏差,我們對 \(v_t\) 將縮放為 \(\frac{v_t}{1-\beta^t}\),這樣 t 比較小時分母會是一個小於 1 的小數,能對 \(v_t\) 進行放大;隨着 \(t\) 的增大,分母會越來越接近 1,\(\frac{v_t}{1-\beta^t}\) 也就變成了 \(v_t\)。所以上圖中,紫色曲線和綠色曲線在后半段重合。
指數加權平均減小了原始數據的震盪程度,能對原始數據起到平滑的效果。
Momentum
假設模型在時間 \(t\) 的梯度為 \(\Delta J(\theta)\),則 Momentum 的梯度更新方法如下
其中,\(v_t\) 就是模型前 \(\frac{1}{1-\beta}\) 步梯度的平均值,\(\beta\) 通常設為 0.9,\(\alpha\) 為學習率。
也可以換一種寫法,就是將 \((1-\beta)\) 這一項去掉
第一種寫法更容易理解,所以下面的公式都采用第一種寫法。
在上圖中,左圖是不使用 Momentum 的 SGD,而右圖是使用 Momentum 的 SGD。可以看到,Momentum 通過對前面一部分梯度的指數加權平均使得梯度下降的過程更加平滑,減少了震盪,收斂也比普通的 SGD 更快。
NAG
NAG(Nesterov Accelerated Gradient) 對 Momentum 進行了輕微的修改
也就是,在進行梯度更新前,我們先看一下 Momentum 指向的位置,然后在 Momentum 指向的位置計算梯度並進行更新。如下圖
有很多優化器的名稱中包含 Ada ,Ada 的含義是 Adaptive,代表“自適應性的”。名稱中帶有 Ada 的優化器一般意味着能夠自動適應(調節)參數的學習率。
Adagrad
在我們訓練模型的初期我們的學習率一般比較大,因為這時我們的位置離最優點比較遠;當訓練快結束時,我們通常會降低學習率,因為訓練快結束時我們離最優點比較近,這時使用大的學習率可能會跳過最優點。Adagrad 能使得參數的學習率在訓練的過程中越來越小,具體計算方法如下:
其中,\(g_t\) 是模型在 \(t\) 時刻的梯度,\(\sum_tg_t^2\) 是模型前 t 個時刻梯度的平方和,\(\epsilon\) 防止分母為 0,一般將 \(\epsilon\) 設為一個很小的數,例如 \(10^{-8}\)。在訓練的過程中,\(\sqrt{\sum_tg_t^2+\epsilon}\) 會越來越大,\(\frac{\eta}{\sqrt{\sum_tg_t^2+\epsilon}}\) 會越來越小,所以學習率也會越來越小。\(\eta\) 通常設為 0.01。
- 優點
- 自動調節參數的學習率;
- 缺點
- 學習率下降會比較快,可能造成學習提早停止;
Adadelta
Adadelta 對 Adagrad 做了輕微的修改,使其比 Adagrad 更加穩定。Adadelta 的計算方法如下:
其中,\(E[g^2]_t\) 表示前 \(t\) 個梯度平方和的期望,也就是梯度平方和的指數加權平均。Adadelta 把 Adagrad 分母中的梯度平方和換成了梯度平方的指數加權平均,這使得 Adadelta 學習率的下降速度沒有 Adagrad 那么快。
RMSprop
RMSprop 的全稱是 Root Mean Squre propogation,也就是均方根(反向)傳播。RMSprop 可以看做是 Adadelta 的一個特例
Adadelta 中使用了上式來計算 \(E[g_t^2]\)。當參數 \(\beta=0.5\) 時,\(E[g_t^2]\) 就變成了梯度平方和的平均數,再求根的話,就變成了 RMS,也就是
RMSprop 中參數的更新方法為
Adam
Adam 的全稱是 Adaptive Moment Estimation,其可看作是 Momentum + RMSprop。Adam 使用梯度的指數加權平均(一階矩估計)和梯度平方的指數加權平均(二階矩估計)來動態地調整每個參數的學習率。
其中,\(m_t、n_t\) 分別是梯度的指數加權平均(一階矩估計)和梯度平方的指數加權平均(二階矩估計)。然后,對\(m_t\) 和 \(n_t\) 進行偏差修正
\(m_t、n_t\) 分別是梯度的一階矩估計和二階矩估計,可以看做是對期望 \(E[g]_t\) 和 \(E[g^2]_t\) 的估計。通過偏差修正,\(\hat m_t\) 和 \(\hat n_t\) 可以看做是為期望的無偏估計。最后,梯度的更新方法為
在使用中,\(\beta\) 通常設為 0.9,\(\gamma\) 通常設為 0.999,\(\epsilon\) 通常設為 \(10^{-8}\)。
參考
1、ruder.io/optimizing-gradient-descent/
2、towardsdatascience.com/stochastic-gradient-descent-with-momentum-a84097641a5d
3、akyrillidis.github.io/notes/AdaDelta
4、zhuanlan.zhihu.com/p/22252270
5、jiqizhixin.com/graph/technologies/173c1ba6-0a13-45f6-9374-ec0389124832
6、https://www.cnblogs.com/guoyaohua/p/8542554.html
7、吳恩達《深度學習》課程:https://www.bilibili.com/video/BV1gb411j7Bs?p=60