鄭重聲明:原文參見標題,如有侵權,請聯系作者,將會撤銷發布!
ICML 2017
Abstract
我們提出了一種與模型無關的元學習算法,從某種意義上說,該算法可與通過梯度下降訓練的任何模型兼容,並適用於各種不同的學習問題,包括分類,回歸和RL。元學習的目標是針對各種學習任務訓練模型,以便僅使用少量訓練樣本即可解決新的學習任務。在我們的方法中,對模型的參數進行了顯式訓練,以使少量梯度步驟和來自新任務的少量訓練數據將對該任務產生良好的泛化性能。實際上,我們的方法訓練出的模型易於微調。我們證明了這種方法在兩個小樣本圖像分類基准上產生了最先進的性能,在小樣本回歸上產生了良好的結果,並加快了使用神經網絡策略進行策略梯度RL的微調。
1. Introduction
快速學習是人類智能的標志,無論是從幾個示例中識別物體還是僅需幾分鍾的經驗即可快速學習新技能。我們的人工智能體應該能夠做到這一點,僅從幾個示例中快速學習與適應,並在有更多數據可用時繼續適應。這種快速且靈活的學習具有挑戰性,因為智能體必須將其先前的經驗與少量的新信息整合在一起,同時還要避免過擬合新數據。此外,先前經驗和新數據的形式將取決於任務。因此,為了獲得最大的適用性,學會學習(或元學習)的機制應針對任務以及完成任務所需的計算形式。
在這項工作中,我們提出了一種通用且與模型無關的元學習算法,從某種意義上說,它可以直接應用於使用梯度下降過程訓練的任何學習問題和模型。我們的重點是深度神經網絡模型,但我們說明了如何通過最少的修改輕松地處理不同的結構和問題設置,包括分類,回歸和策略梯度RL。在元學習中,訓練后的模型的目標是從少量新數據中快速學習新任務,並且元學習器對模型進行訓練,使其能夠學習大量不同的任務。我們方法的主要思想是訓練模型的初始參數,以便在通過一個或多個梯度步驟更新參數后,該模型在新任務上具有最優性能,而該梯度步驟是根據該新任務中的少量數據計算得出的。與先前學習更新函數或學習規則的元學習方法(Schmidhuber, 1987; Bengio et al., 1992; Andrychowicz et al., 2016; Ravi&Larochelle, 2017)不同,我們的算法不會擴展學習參數的數量,也不對模型架構施加約束(例如,通過要求使用循環模型(Santoro et al., 2016)或孿生網絡(Koch, 2015)),它可以很容易地與全連接,卷積或循環神經網絡組合。它也可以用於多種損失函數,包括可微的監督損失和不可微的RL目標。
從特征學習的角度來看,訓練模型參數以使幾個梯度步驟甚至單個梯度步驟可以在新任務上產生良好結果的過程可以從構建一個廣泛適用於許多任務的內部表征的角度來看。如果內部表征適合許多任務,則只需微調參數(例如,通過主要修改前饋模型中的頂層權重)即可產生良好的效果。實際上,我們的過程針對易於快速調整的模型進行了優化,從而可以在合適的空間自適應以進行快速學習。從動態系統的角度來看,我們的學習過程可以看作是使新任務的損失函數對參數的敏感性最大化:當敏感性高時,較小的參數局部更改會導致任務損失的較大改進 。
這項工作的主要貢獻是一種用於元學習的簡單模型和與任務無關的算法,該算法可訓練模型的參數,以使少量的梯度更新可快速學習新任務。我們在不同的模型類型(包括全連接和卷積網絡)上以及在幾個不同的領域(包括小樣本回歸,圖像分類和RL)中演示了該算法。我們的評估表明,我們的元學習算法與專門為監督分類而設計的最新one-shot學習方法相比具有優勢,同時使用的參數更少,但它也可以很容易地應用於回歸分析並可以加速RL。在存在任務可變性的情況下,其性能遠勝於直接預訓練的初始化。
2. Model-Agnostic Meta-Learning
我們的目標是訓練可以實現快速適應的模型,這個問題通常被確定為小樣本學習。在本節中,我們將定義問題的設置並介紹算法的通用形式。
2.1. Meta-Learning Problem Set-Up
小樣本元學習的目標是訓練僅需幾個數據點和訓練迭代就可以快速適應新任務的模型。為此,在元學習階段對一組任務訓練模型或學習器,以使訓練后的模型僅使用少量示例或試驗即可快速適應新任務。實際上,元學習問題將整個任務視為訓練示例。在本節中,我們以通用方式將這種元學習的問題形式化,包括不同學習領域的簡短示例。我們將在第3節中詳細討論兩個不同的學習領域。
我們考慮一個模型(表示為 f ),該模型將觀測值x映射到輸出a。在元學習期間,訓練模型以使其能夠適應大量或無限數量的任務。由於我們希望將我們的框架應用於從分類到RL的各種學習問題,因此我們在下面介紹學習任務的通用概念。形式上,每個任務T = { L(x1, a1, ... , xH, aH), q(x1), q(xt+1|xt, at), H }由損失函數L,初始觀測值的分布q(x1),轉換分布q(xt+1|xt, at)和回合長度H組成。在i.i.d.監督學習問題,長度H = 1。該模型可以通過在每個時間 t 選擇一個輸出at來生成長度為H的樣本。損失L(x1, a1, ... , xH, aH) → R提供特定於任務的反饋,在Markov決策過程中可能以誤分類損失或成本函數的形式出現。
在我們的元學習場景中,我們考慮希望模型能夠適應的任務上的分布p(T)。在K-shot學習設置中,訓練模型以僅從qi和Ti生成的反饋LTi的K個樣本中學習從p(T)提取的新任務Ti。在元訓練期間,從p(T)采樣任務Ti,用K個樣本訓練模型,並從Ti的相應損失LTi收到反饋,然后在Ti的新樣本上進行測試。然后通過考慮來自qi的新數據的測試誤差如何相對於參數變化來改進模型 f。實際上,采樣任務Ti上的測試誤差充當了元學習過程的訓練誤差。在元訓練結束時,會從p(T)中采樣新任務,並根據模型在K個樣本中學習后的性能來衡量元性能。通常,用於元測試的任務會在元訓練期間保留。
2.2. A Model-Agnostic Meta-Learning Algorithm
與先前的研究相反,后者試圖訓練可吸收整個數據集的RNN (Santoro et al., 2016; Duan et al., 2016b)或可以在測試時與非參數化方法結合的特征嵌入(Vinyals et al., 2016; Koch, 2015),我們提出了一種方法,該方法可以通過元學習來學習任何標准模型的參數,從而為快速適應該模型做准備。這種方法背后的直覺是,某些內部表征比其他內部表征更易於遷移。例如,神經網絡可能會學習廣泛適用於p(T)中所有任務的內部特征,而不是單個任務。我們如何鼓勵這種通用表征的出現?我們針對此問題采取了明確的方法:由於將對新任務使用基於梯度的學習規則來對該模型進行微調,因此我們將以這樣一種方式來學習模型:基於梯度的學習規則可以使從p(T)抽取的新任務快速發展,而不要過擬合。實際上,我們的目標是找到對任務的變化敏感的模型參數,當沿該損失的梯度方向變化時,參數的微小變化將對從p(T)中抽取的任何任務的損失函數產生較大的改進(參見圖1)。除了假設模型由某些參數向量θ參數化之外,我們不對模型的形式進行任何假設,並且損失函數在θ中足夠平滑,因此可以使用基於梯度的學習技術。
形式上,我們考慮一個由參數化函數fθ表示的模型。當適應新任務Ti時,模型的參數θ變為。在我們的方法中,使用任務Ti上的一個或多個梯度下降更新來計算更新的參數矢量
。例如,當使用一個梯度更新時,
步長α可以固定為超參數或由元學習得到。為了簡化表示,本節的其余部分將考慮一個梯度更新,但是使用多個梯度更新是一個直接的擴展。
通過針對跨越從p(T)采樣的任務關於θ優化的性能來訓練模型參數。更具體地,元目標如下:
請注意,對模型參數θ執行元優化,而使用更新后的模型參數θ'計算目標。實際上,我們提出的方法旨在優化模型參數,以使新任務上的一個或少量梯度步驟將對該任務產生最大的有效行為。
跨任務的元優化是通過隨機梯度下降(SGD)執行的,因此模型參數θ的更新如下:
其中β是元步長。一般情況下,完整算法在算法1中概述。
MAML元梯度更新涉及二階梯度。在計算上,這需要額外的反向遍歷 f 來計算Hessian向量乘積,這由標准深度學習庫(如TensorFlow)支持(Abadi et al., 2016)。在我們的實驗中,我們還進行了與去除該反向通過並使用一階近似的比較,這將在5.2節中討論。
3. Species of MAML
在本節中,我們將討論用於監督學習和RL的元學習算法的特定實例。這些域在損失函數的形式以及任務如何生成數據並將其呈現給模型方面有所不同,但是在兩種情況下都可以應用相同的基本適應機制。
3.1. Supervised Regression and Classification
在監督任務的領域中,小樣本學習已有大量研究,其目標是使用相似任務的先驗數據進行元學習,從而僅從該任務的幾個輸入/輸出對中學習新函數。例如,目標可能是在僅看到一個或幾個Segway實例之后,使用先前已看到許多其他類型目標的模型對Segway圖像進行分類。同樣,在小樣本回歸中,目標是在對許多具有相似統計特性的函數進行訓練后,僅從該函數采樣的幾個數據點中預測連續價值函數的輸出。
(省略)
3.2. Reinforcement Learning
在RL中,小樣本元學習的目標是使智能體僅使用少量測試設置經驗就可以快速獲取新測試任務的策略。新任務可能涉及達成新目標或在新環境中成功達成先前訓練的目標。例如,智能體可能會學會快速找出迷宮的導航方式,以便面對新的迷宮時,可以確定如何僅用幾個樣本就能可靠地到達出口。在本節中,我們將討論如何將MAML應用於RL的元學習。
每個RL任務Ti包含初始狀態分布qi(x1)和轉換分布qi(xt+1|xt, at),並且損失LTi對應於(負)獎勵函數R。因此,整個任務是一個馬爾可夫決策過程(MDP)的范圍為H,允許學習者查詢有限數量的樣本軌跡以進行小樣本學習。MDP的任何方面都可能因p(T)中的任務而異。正在學習的模型fθ是在每個時間步驟t∈{1, ... , H}從狀態xt映射到動作at分布的策略。任務Ti和模型fΦ的損失如下:
在K-shot RL中,可以使用fθ和任務Ti的K部署(x1, a1, ... , xH)和相應的獎勵R(xt, at)來適應新任務Ti。由於期望獎勵通常由於未知的動態而無法區分,因此我們使用策略梯度方法來估計模型梯度更新和元優化的梯度。由於策略梯度是基於策略的算法,因此在適應fθ期間的每個額外梯度步驟都需要來自當前策略的新樣本。我們在算法3中詳細介紹了該算法。該算法與算法2的結構相同,主要區別在於步驟5和步驟8需要從與任務Ti對應的環境中采樣軌跡。該方法的實際實現還可以使用最近針對策略梯度算法提出的各種改進,包括狀態或動作相關的基准和信任區域(Schulman et al., 2015)。
4. Related Work
5. Experimental Evaluation
我們的實驗評估的目的是回答以下問題:(1) MAML是否可以快速學習新任務?(2) MAML是否可以用於多個不同領域的元學習,包括監督回歸,分類和RL?(3) 通過MAML學到的模型是否可以通過額外梯度更新和/或示例來繼續改進?
我們考慮的所有元學習問題都需要在測試時適應新任務。在可能的情況下,我們將結果與一個oracle進行比較,該oracle將任務的身份(作為問題相關的表征)作為額外輸入,作為模型性能的上限。所有實驗均使用TensorFlow(Abadi et al., 2016)進行,它允許在元學習過程中通過梯度更新自動進行微分。該代碼可在線獲得1。
1 用於回歸和監督實驗的代碼在github.com/cbfinn/maml上,用於RL實驗的代碼在github.com/cbfinn/maml_rl上
5.1. Regression
5.2. Classification
5.3. Reinforcement Learning
為了評估關於RL問題的MAML,我們根據rllab基准套件中的仿真連續控制環境構造了幾組任務(Duan et al., 2016a)。我們在下面討論各個領域。在所有領域中,由MAML訓練的模型都是一個神經網絡策略,具有兩個大小為100的隱含層,具有ReLU非線性。使用原始策略梯度(REINFORCE)(Williams, 1992)計算梯度更新,我們使用信任區域策略優化(TRPO)作為元優化器(Schulman et al., 2015)。為了避免計算三階導數,我們使用有限差分來計算TRPO的Hessian矢量積。對於學習和元學習更新,我們使用Duan et al. (2016a)提出的標准線性特征基准,它針對批中的每個采樣任務在每次迭代時分別擬合。我們將三種基准模型進行比較:(a) 在所有任務上預先訓練一個策略,然后進行微調;(b) 從隨機初始化的權重中訓練一個策略;(c) 一個接收任務參數作為輸入的oracle策略,對於以下任務,該輸入對應於智能體的目標位置,目標方向或目標速度。(a)和(b)的基准模型通過梯度下降和手動調整的步長進行微調。可以在sites.google.com/view/maml上查看學習到的策略的視頻。
2D Navigation. 在我們的第一個元RL實驗中,我們研究了一組任務,其中點智能體必須以二維方式移動到不同的目標位置,並為單位平方內的每個任務隨機選擇。觀測值是當前的2D位置,並且動作與速度指令相對應,該速度指令被裁剪為[-0.1, 0.1]。獎勵是到目標的負平方距離,並且當智能體在目標的0.01內或在H = 100的范圍內時,回合終止。使用MAML對策略進行了訓練,以在使用20條軌跡更新1個策略梯度后最大化性能。附錄A.2中針對此問題和以下RL問題的額外超參數設置。在我們的評估中,我們將適應性與一項新任務的適應性進行了比較,該任務具有多達4個梯度更新,每個梯度更新有40個樣本。圖4中的結果顯示了使用MAML初始化的模型,對同一組任務的常規預訓練,隨機初始化以及接收目標位置作為輸入的oracle策略的適應性能。結果表明MAML可以學習一個模型,該模型可以在單個梯度更新中更快地適應,並且可以通過額外更新繼續改進。
Locomotion. 為了研究MAML如何更好地解決更復雜的深度RL問題,我們還使用MuJoCo仿真器研究了對高維運動任務的適應性(Todorov et al., 2012)。這些任務需要兩個仿真機器人——平面獵豹和3D四足動物("螞蟻")在特定方向或特定速度下運行。在目標速度實驗中,獎勵是智能體當前速度和目標之間的負絕對值,對於獵豹,均勻隨機地在0.0和2.0之間選擇值,而螞蟻在0.0和3.0之間。在目標方向實驗中,獎勵是向前或向后速度的大小,是針對p(T)中的每個任務隨機選擇的。范圍是H = 200,對於所有問題,每個梯度步驟都有20個部署,但螞蟻前進/后退任務每步需要40個部署。圖5中的結果表明MAML學習了一個模型,該模型甚至可以僅通過單個梯度更新就可以快速適應其速度和方向,並且可以通過更多的梯度步驟來不斷改進。結果還表明,在這些具有挑戰性的任務上,MAML初始化明顯優於隨機初始化和預訓練。實際上,在某些情況下,預訓練比隨機初始化要差,這是之前的RL工作中觀察到的事實(Parisotto et el., 2016)。
6. Discussion and Future Work
我們介紹了一種基於梯度下降來學習易適應的模型參數的元學習方法。我們的方法有很多好處。它很簡單,並且不會引入任何學習到的元學習參數。它可以與任何適合基於梯度訓練的模型表征以及任何可微的目標(包括分類,回歸和RL)結合使用。最后,由於我們的方法僅產生權重初始化,因此盡管我們展示了分類的最新結果(每個類僅包含一個或五個示例),但可以使用任意數量的數據和任意數量的梯度步驟進行自適應。我們還表明,我們的方法可以使用策略梯度和非常適量的經驗來適應RL智能體。
重用過去任務中的知識可能是產生高容量可擴展模型(例如深度神經網絡)的關鍵要素,這些模型適合使用小數據集進行快速訓練。我們認為,這項工作是邁向一種簡單通用的元學習技術的第一步,該技術可以應用於任何問題和模型。在該領域的進一步研究可以使多任務初始化成為深度學習和RL的標准要素。
A. Additional Experiment Details
在本節中,我們提供實驗設置和超參數的其他詳細信息。
A.1. Classification
A.2. Reinforcement Learning
在所有RL實驗中,使用α = 0.1的單個梯度步驟對MAML策略進行了訓練。在評估過程中,我們發現在第一個梯度步驟后將學習率減半產生了出色的性能。因此,自適應期間的步長對於第一步設置為α = 0.1,對於以后的所有步驟設置為α = 0.05。對於每個域,手動調整了基准方法的步長。在2D導航中,我們使用的元批處理大小為20;在運動問題中,我們使用了40個任務的元批處理大小。對MAML模型進行了多達500次元迭代的訓練,並使用訓練期間平均回報最高的模型進行評估。對於螞蟻目標速度任務,我們在每個時間步驟都添加了正獎勵,以防止螞蟻結束該回合。
B. Additional Sinusoid Results
C. Additional Comparisons
C.1. Multi-task baselines
C.2. Context vector adaptation