SWATS算法剖析(自動切換adam與sgd)


SWATS算法剖析(自動切換adam與sgd)

SWATS是ICLR在2018的高分論文,提出的一種自動由Adam切換為SGD而實現更好的泛化性能的方法。

論文名為Improving Generalization Performance by Switching from Adam to SGD,下載地址為:

作者指出,基於歷史梯度平方的滑動平均值的如adam等算法並不能收斂到最優解,因此在泛化誤差上可能要比SGD等方法差,因此提出了一種轉換機制,試圖讓算法自動在經過一定輪次的adam學習后,轉而由SGD去執行接下來的操作。

算法本身思想很簡單,就是采用adam這種無需操心learning rate的方法,在開始階段進行梯度下降,但是在學習到一定階段后,由SGD接管。這里前面的部分與常規的adam實現區別不大,重要的是在切換到sgd后,這個更新的learning rate如何計算。 整個算法步驟流程如下:

 

 

熟悉adam的應該能熟悉藍色的部分,這個就是adam的原生實現過程。

作者比較trick的地方就是14行到24行這一部分。這一部分作者做了部分推導,[公式]作為最后的切換learning rate。

算法的整個實現邏輯並不復雜,這里列出自己實現時遇到的一些問題。

填坑 & 問題

  1. 在上面的算法流程第12行,有個[公式],這個在整個流程中未介紹如何實現,本人閱讀論文后,發現應該是學習率衰減的設計。一如很多深度學習策略一樣,這里可以設置經過若干輪迭代后,學習率降為原來的1/N。在論文中,作者使用了在150輪后,將學習速率減少10倍。即[公式]
  2. 上面說了[公式]的更新,我們通過公式推導,其實能發現[公式][公式]有一定的關系,自己代碼實現的版本,發現切換的時機很大程度上和[公式]有關,因為切換涉及到第17行的一個比較過程,[公式][公式]本身都與[公式]相關,當[公式]降一個量級時,[公式]|本身也會更接近[公式]。其有些類似正比關系,因此一般都是在經過一定輪次的衰減后,才能觸發SGD切換時機。這一點目前本人實現驗證是這樣,未深入推理。
  3. 這個[公式]還有個坑,就是實現該算法,開始不太清楚這個k到底指的是epoch,還是指的經歷的batch數量。最后按照常規學習率衰減應該是按照epoch來算的,因此推測其k應該為epoch。
  4. 還有和大坑是[公式]作為學習率,在切換到SGD后應一直不變,該值為標量,因此應該如常用eta等學習率一樣,為正值,因此需要在17行加個約束,即[公式]。(該場景難以復現,之前有次更新發現不設置為正值時,導致切換sgd后准確度大減)

總結

通過若干的對比,該論文變相增加了一些超參數,所以實際使用有待商榷。自己的數據集上經常就在還未滿足切換條件就已經收斂了。 目前已做了相應的實現,放在scalaML中,位置為,使用見。最后想要查看切換過程的話,建議將early_stop設置為false,然后將學習率衰減系數設置低一點。 代碼目前僅支持二分類。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM