Adam和學習率衰減(learning rate decay)


本文先介紹一般的梯度下降法是如何更新參數的,然后介紹 Adam 如何更新參數,以及 Adam 如何和學習率衰減結合。

梯度下降法更新參數

梯度下降法參數更新公式:

\[\theta_{t+1} = \theta_{t} - \eta \cdot \nabla J(\theta_t) \]

其中,\(\eta\) 是學習率,\(\theta_t\) 是第 \(t\) 輪的參數,\(J(\theta_t)\) 是損失函數,\(\nabla J(\theta_t)\) 是梯度。

在最簡單的梯度下降法中,學習率 \(\eta\) 是常數,是一個需要實現設定好的超參數,在每輪參數更新中都不變,在一輪更新中各個參數的學習率也都一樣。

為了表示簡便,令 \(g_t = \nabla J(\theta_t)\),所以梯度下降法可以表示為:

\[\theta_{t+1} = \theta_{t} - \eta \cdot g_t \]

Adam 更新參數

Adam,全稱 Adaptive Moment Estimation,是一種優化器,是梯度下降法的變種,用來更新神經網絡的權重。

Adam 更新公式:

\[\begin{aligned} m_{t} &=\beta_{1} m_{t-1}+\left(1-\beta_{1}\right) g_{t} \\ v_{t} &=\beta_{2} v_{t-1}+\left(1-\beta_{2}\right) g_{t}^{2} \\ \hat{m}_{t} &=\frac{m_{t}}{1-\beta_{1}^{t}} \\ \hat{v}_{t} &=\frac{v_{t}}{1-\beta_{2}^{t}} \\ \theta_{t+1}&=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t} \end{aligned} \]

在 Adam 原論文以及一些深度學習框架中,默認值為 \(\eta = 0.001\)\(\beta_1 = 0.9\)\(\beta_2 = 0.999\)\(\epsilon = 1e-8\)。其中,\(\beta_1\)\(\beta_2\) 都是接近 1 的數,\(\epsilon\) 是為了防止除以 0。\(g_{t}\) 表示梯度。

咋一看很復雜,接下一一分解:

  • 前兩行:

\[\begin{aligned} m_{t} &=\beta_{1} m_{t-1}+\left(1-\beta_{1}\right) g_{t} \\ v_{t} &=\beta_{2} v_{t-1}+\left(1-\beta_{2}\right) g_{t}^{2} \end{aligned} \]

這是對梯度和梯度的平方進行滑動平均,使得每次的更新都和歷史值相關。

  • 中間兩行:

\[\begin{aligned} \hat{m}_{t} &=\frac{m_{t}}{1-\beta_{1}^{t}} \\ \hat{v}_{t} &=\frac{v_{t}}{1-\beta_{2}^{t}} \end{aligned} \]

這是對初期滑動平均偏差較大的一個修正,叫做 bias correction,當 \(t\) 越來越大時,\(1-\beta_{1}^{t}\)\(1-\beta_{2}^{t}\) 都趨近於 1,這時 bias correction 的任務也就完成了。

  • 最后一行:

\[\theta_{t+1}=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t} \]

這是參數更新公式。

學習率為 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\),每輪的學習率不再保持不變,在一輪中,每個參數的學習率也不一樣了,這是因為 \(\eta\) 除以了每個參數 \(\frac{1}{1- \beta_2} = 1000\) 輪梯度均方和的平方根,即 \(\sqrt{\frac{1}{1000}\sum_{k = t-999}^{t}g_k^2}\)。而每個參數的梯度都是不同的,所以每個參數的學習率即使在同一輪也就不一樣了。(可能會有疑問,\(t\) 前面沒有 999 輪更新怎么辦,那就有多少輪就算多少輪,這個時候還有 bias correction 在。)

而參數更新的方向也不只是當前輪的梯度 \(g_t\) 了,而是當前輪和過去共 \(\frac{1}{1- \beta_1} = 10\) 輪梯度的平均。

有關滑動平均的理解,可以參考我之前的博客:理解滑動平均(exponential moving average)

Adam + 學習率衰減

在 StackOverflow 上有一個問題 Should we do learning rate decay for adam optimizer - Stack Overflow,我也想過這個問題,對 Adam 這些自適應學習率的方法,還應不應該進行 learning rate decay?

論文 《DECOUPLED WEIGHT DECAY REGULARIZATION》的 Section 4.1 有提到:

Since Adam already adapts its parameterwise learning rates it is not as common to use a learning rate multiplier schedule with it as it is with SGD, but as our results show such schedules can substantially improve Adam’s performance, and we advocate not to overlook their use for adaptive gradient algorithms.

上述論文是建議我們在用 Adam 的同時,也可以用 learning rate decay。

我也簡單的做了個實驗,在 cifar-10 數據集上訓練 LeNet-5 模型,一個采用學習率衰減 tf.keras.callbacks.ReduceLROnPlateau(patience=5),另一個不用。optimizer 為 Adam 並使用默認的參數,\(\eta = 0.001\)。結果如下:


加入學習率衰減和不加兩種情況在 test 集合上的 accuracy 分別為: 0.5617 和 0.5476。(實驗結果取了兩次的平均,實驗結果的偶然性還是有的)

通過上面的小實驗,我們可以知道,學習率衰減還是有用的。(當然,這里的小實驗僅能代表一小部分情況,想要說明學習率衰減百分之百有效果,得有理論上的證明。)

當然,在設置超參數時就可以調低 \(\eta\) 的值,使得不用學習率衰減也可以達到很好的效果,只不過參數更新變慢了。

將學習率從默認的 0.001 改成 0.0001,epoch 增大到 120,實驗結果如下所示:

加入學習率衰減和不加兩種情況在 test 集合上的 accuracy 分別為: 0.5636 和 0.5688。(三次實驗平均,實驗結果仍具有偶然性)

這個時候,使用學習率衰減帶來的影響可能很小。

那么問題來了,Adam 做不做學習率衰減呢?
我個人會選擇做學習率衰減。(僅供參考吧。)在初始學習率設置較大的時候,做學習率衰減比不做要好;而當初始學習率設置就比較小的時候,做學習率衰減似乎有點多余,但從 validation set 上的效果看,做了學習率衰減還是可以有丁點提升的。

ReduceLROnPlateau 在 val_loss 正常下降的時候,對學習率是沒有影響的,只有在 patience(默認為 10)個 epoch 內,val_loss 都不下降 1e-4 或者直接上升了,這個時候降低學習率確實是可以很明顯提升模型訓練效果的,在 val_acc 曲線上看到一個快速上升的過程。對於其它類型的學習率衰減,這里沒有過多地介紹。

Adam 衰減的學習率

從上述學習率曲線來看,Adam 做學習率衰減,是對 \(\eta\) 進行,而不是對 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\) 進行,但有區別嗎?

學習率衰減一般如下:

  • exponential_decay:
    decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

  • natural_exp_decay:
    decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / decay_steps)

  • ReduceLROnPlateau
    如果被監控的值(如‘val_loss’)在 patience 個 epoch 內都沒有下降,那么學習率衰減,乘以一個 factor
    decayed_learning_rate = learning_rate * factor

這些學習率衰減都是直接在原學習率上乘以一個 factor ,對 \(\eta\) 或對 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\) 操作,結果都是一樣的。

References

[1] An overview of gradient descent optimization algorithms -- Sebastian Ruder
[2] Should we do learning rate decay for adam optimizer - Stack Overflow
[3] Tensorflow中learning rate decay的奇技淫巧 -- Elevanth
[4] Loshchilov, I., & Hutter, F. (2017). Decoupled Weight Decay Regularization. ICLR 2019. Retrieved from http://arxiv.org/abs/1711.05101


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM