很早之前看到這篇文章的時候,覺得這篇文章的思想很朴素,沒有讓人眼前一亮的東西就沒有太在意。之后讀到很多Multi-Agent或者並行訓練的文章,都會提到這個算法,比如第一視角多人游戲(Quake III Arena Capture the Flag)的超人表現,NeurIPS2018首屆多智能體競賽(The NeurIPS 2018 Pommerman Competition)的冠軍算法,DeepMind團隊ICLR 2019 conference paper的2V2足球,甚至星際爭霸II里的AlphaStar,都運用了類似方法。所以這里又回過頭記錄一下。
目錄
一、摘要
文章提出了一種簡單的異步優化方法PBT(population based training),主要用來自適應調節超參數。通常的深度學習,超參數都是憑經驗預先設計好的,會花費大量精力且不一定有好的效果,特別是在深度強化學習這種非靜態(non-stationary)的環境中,要想得到SOTA效果,超參數還應隨着環境變化而自適應調整,比如探索率等等。這種基於種群(population)的進化方式,淘汰差的模型,利用(exploit)好的模型並添加隨機擾動(explore)進一步優化,最終得到最優的模型。作者分別從強化學習,監督學習,GAN三個方面做實驗,論證了這個簡單但有效的算法。
作者認為本文主要做了三點改進:(a)訓練過程超參數的自動選擇。(b)模型的在線淘汰和選擇,讓計算資源最大化用在更有希望的模型上(promising models)。(c)超參數在線自適應調節,以適應非靜態場景的超參數規划調節(hyperparameter schedules)。
二、效果展示
- GAN & RL
左邊的gif是GAN在CIFAR-10上的效果,右邊是Feudal Networks(FuN)在 Ms Pacman上的效果。
圖中紅色的點是隨機初始化的模型,也就是所謂的population。再往后,黑色的分支就是效果很差的模型,被淘汰掉。藍色的分支表示效果一直在提升的模型,最終得到的藍色點就是最優的模型。不得不說,DeepMind這可視化效果做的,真的強。
三、方法細節
-
問題分析
神經網絡的訓練受模型結構、數據表征、優化方法等的影響。而每個環節都涉及到很多參數(parameters)和超參數(hyperparameters),對這些參數的調節決定了模型的最終效果。通常的做法是人工調節,但這種方式費時費力且很難得到最優解。
兩種常用的自動調參的方式是並行搜索(parallel search)和序列優化(sequential optimisation)。並行搜索就是同時設置多組參數訓練,比如網格搜索(grid search)和隨機搜索(random search)。序列優化很少用到並行,而是一次次嘗試並優化,比如人工調參(hand tuning)和貝葉斯優化(Bayesian optimisation)。並行搜索的缺點在於沒有利用相互之間的參數優化信息。而序列優化這種序列化過程顯然會耗費大量時間。
還有另一個問題是,對於有些超參數,在訓練過程中並不是一直不變的。比如監督訓練里的學習率,強化學習中的探索度等等。通常的做法是給一個固定的衰減值,而在強化學習這類問題里還會隨不同場景做不同調整。這無疑很難找到一個最優的自動調節方式。 -
具體方法
作者提出了一種很朴素的思想,將並行優化和序列優化相結合。既能並行探索,同時也利用其他更好的參數模型,淘汰掉不好的模型。
如圖所示,(a)中的序列優化過程只有一個模型在不斷優化,消耗大量時間。(b)中的並行搜索可以節省時間,但是相互之間沒有任何交互,不利於信息利用。(c)中的PBT算法結合了二者的優點。
首先PBT算法隨機初始化多個模型,每訓練一段時間設置一個檢查點(checkpoint),然后根據其他模型的好壞調整自己的模型。若自己的模型較好,則繼續訓練。若不好,則替換(exploit)成更好的模型參數,並添加隨機擾動(explore)再進行訓練。其中checkpoint的設置是人為設置每過多少step之后進行檢查。擾動要么在原超參數或者參數上加噪聲,要么重新采樣獲得。作者還寫了幾個公式來規范說明這個問題,看起來逼格更高一點,我個人覺得沒有必要再寫在這里了。 -
偽代碼
偽代碼非常清楚明白。
其中\(\theta\)表示網絡參數,\(h\)表示超參數,\(p\)表示當前模型好壞的指標,\(t\)表示當前第\(t\)代模型(這里說成step應該更准確,多個step之后才生產一代模型,之前理解有點偏差)。整個原理其實和進化算法很像,也和探索利用(exploration vs exploitation)的折中取舍(trade-off)很像。有疑問可以留言交流。
四、實驗結果
-
Toy example
作者舉了一個小例子來說明PBT算法的好處,雖然有點牽強,但是也有一定道理。
作者假設了一個優化函數:\(Q(\theta)=1.2-(\theta_0^2+\theta_1^2)\),目標是求該函數的最大值。我們不知道具體函數,只知道該函數的形式是\(\hat{Q}(\theta|h)=1.2-(h_0\theta_0^2+h_1\theta_1^2)\),其中\(h_0,h_1\)是超參數,\(\theta_0,\theta_1\)是參數。作者對比了PBT,只有替換(exploit)的PBT,只有加隨機擾動(explore)的PBT和網格搜索。作者設置了只有兩個worker的PBT算法,即初始化兩個模型。其中,參數初始化為\(\theta=[0.9,0.9]\),超參數分別設置為\(h=[1,0]\)和\(h=[0,1]\)。每更新5步設置一個checkpoint。
從上圖可以看出,結果顯然是PBT效果好。作者舉的這個例子比較極端,不過也確實能說明一些道理。就是說在訓練過程中超參數也需要不斷修正以找到最優值,而PBT算法剛好可以做到這一點。 -
其他環境效果展示
作者還在一些具體場景上做了實驗,比如強化學習,機器翻譯,對抗網絡等等。這里貼出部分結果,詳細參看原文。- 效果提升展示
- baseline曲線對比
- 對照實驗(ablation experiments)
- 效果提升展示
五、總結
這篇文章思想簡單,效果不錯,實驗結果也在情理之中。除了算法,其算力起到了很重要的作用。比如RL的實驗里worker數量是10-80個,MT里是32個,GAN里是45個,這個算力普通實驗室要做類似工作代價還是比較高的。不過在當前的大環境下,沒有算力確實是寸步難行,特別是RL。