On First-Order Meta-Learning Algorithms
Abstract
本文考慮元學習問題,其中存在任務分布,我們希望得到一個當面對一個從這個分布中采樣的以前未被發現(即以前訓練的時候沒使用過的)的任務時,也能表現良好的agent(即學習很快)。我們分析了一組學習參數初始化的算法,這些算法可以在新的任務上快速微調,僅使用一階導數進行元學習更新。這一族包含和推廣了一階MAML,這是忽略二階導數得到的MAML的近似。它還包括Reptile,即我們在這里介紹的一個新算法,它通過重復采樣一個任務,對它進行訓練,並將初始化移動到該任務的訓練權重。我們擴展了Finn等人的結果,表明一階元學習算法在一些已確立的few-shot分類基准上表現良好,我們提供了理論分析,旨在理解這些算法的工作原理。
1 Introduction
雖然機器學習系統已經在許多任務上超越了人類,但它們通常需要多得多的數據才能達到相同的性能水平。例如,Schmidt等人[17,15]的研究表明,人類主體可以根據一些樣本圖像識別新的對象類別。Lake等人[12]指出,在Frostbite的Atari游戲中,人類新手在15分鍾后就能在游戲中取得顯著進展,但double-dueling-DQN [19]需要超過1000倍的經驗才能獲得相同的分數。
把人類比作從頭開始學習的算法是不完全公平的,因為人類在完成任務時,大腦和DNA中已經編碼了大量的先驗知識。他們不是從零開始學習,而是對一組已有的技能進行微調和重組。上述由Tenenbaum和他的合作者所引用的工作表明,人類的快速學習能力可以被解釋為貝葉斯推理,而開發具有人類水平學習速度的算法的關鍵是使我們的算法更加貝葉斯。然而,在實踐中,開發(從第一原理)利用深度神經網絡並且在計算上可行的貝葉斯機器學習算法是具有挑戰性的。
元學習是最近出現的一種從少量數據中學習的方法。元學習並不是試圖模仿貝葉斯推理(這可能是難以計算的),而是尋求使用任務數據集直接優化快速學習算法。具體地說,我們假設可以訪問任務分布,例如,每個任務都是一個分類問題。從這個分布中,我們對任務的訓練集和測試集進行采樣。我們的算法接受訓練集,它必須產生一個在測試集上平均表現良好的agent。由於每個任務對應一個學習問題,因此在一個任務上表現良好就對應着快速學習。
各種不同的元學習方法已經被提出,每種方法都有其優缺點。其中一種方法是將學習算法編碼在遞歸網絡的權值中,但在測試時不進行梯度下降。該方法是由Hochreiter等人[8]提出的,他們使用LSTMs進行下一步預測,並在近期進行了一系列后續工作,例如,Santoro等人[16]進行few-shot分類,Duan等人[3]進行POMDP設置。
第二種方法是學習網絡的初始化,然后在測試時對新任務進行微調。這種方法的一個經典示例是使用大型數據集(例如ImageNet[2])進行預訓練,並對較小的數據集進行微調(例如不同種類鳥類[20]的數據集)。然而,這種經典的預訓練方法並不能保證學習有利於微調的初始化,而且要獲得良好的性能,需要特別的技巧。最近,Finn等人[4]提出了一種稱為MAML的算法,該算法通過在微調過程微分來直接優化了與初始化相關的性能。在這種方法中,即使在接收到樣本外數據時,學習者也會采用合理的基於梯度的學習算法,從而使其比基於RNN的方法[5]泛化得更好。另一方面,由於MAML需要在優化過程中進行微分,因此它不適用於需要在測試時執行大量梯度步驟的問題。作者還提出了一種稱為一階MAML (FOMAML)的變體,它的定義是忽略二階導數項,避免了這個問題,但代價是丟失一些梯度信息。然而,令人驚訝的是,他們發現在Mini-ImageNet數據集[18]上,FOMAML的表現幾乎和MAML一樣好。(之前的元學習[1,13]的工作就預示了這個結果,在通過梯度下降進行微分時忽略了二階導數,沒有不良影響。)在這項工作中,我們在此基礎上進行了擴展,並探索了基於一階梯度信息的元學習算法的潛力,這是出於對那些使用依賴於高階梯度(如全MAML)的技術過於繁瑣的問題的潛在適用性。
我們的貢獻如下:
- 我們指出,一階MAML[4]的實現比在本文之前被廣泛認為的要簡單。
- 我們介紹了Reptile,一個與FOMAML密切相關的算法,它同樣易於實現。Reptile與聯合訓練(即訓練以減少訓練任務的預期損失)十分相似,它作為一種元學習算法,尤其令人驚訝。與FOMAML不同的是,Reptile不需要為每個任務進行訓練測試分割,這可能會使它在某些情況下成為更自然的選擇。它也與[7]的fast weight/slow weight的舊觀念有關。
- 我們提供了一個適用於一階MAML和Reptile的理論分析,表明它們都優化了任務內的泛化。
- 基於對Mini-ImageNet[18]和Omniglot[11]數據集的經驗評估,我們為最佳實踐提供了一些見解。
2 Meta-Learning an Initialization
我們考慮了MAML[4]的優化問題:找到一組初始參數Φ,使得對於一個具有相應的損失的隨機采樣的任務
,在k次更新后,學習者的損失較小。那就是:
其中是使用從任務
中采樣的數據更新Φ參數k次的操作符。在few-shot學習中,U對應於在從任務
中采樣的數據batches中實現梯度下降或Adam[10]
MAML解決了方程(1)的一個版本,它基於額外的假設:對於給定的任務,內循環優化使用訓練樣本A,而損失使用測試樣本B計算。這樣,MAML優化泛化,類似於交叉驗證。省略上標k,我們把它記為:
MAML的工作方式是通過隨機梯度下降來優化這個損失,即如下的計算:
在等式(4)中,是更新操作
的Jacobian矩陣。
對應於添加一系列梯度向量到初始向量中,即
。(在Adam中,梯度也按元素重新調整,但這不會改變結論。)一階MAML(FOMAML)將這些梯度(二階)看作常量,然而,它使用恆等操作來替換Jacobian
。因此,FOMAML在外循環優化中使用的梯度是
。因此FOMAML能使用特別簡單的方法來實現:(1)采樣任務
;(2)使用更新操作符,產生
,這是在訓練集A上得到的;(3)計算在
的梯度
,這是在測試集B上得到的;(4)將gFOMAML添加到外循環操作器中,更新參數
3 Reptile
在本節中,我們描述了一個新的一階基於梯度的元學習算法,稱為Reptile。和MAML一樣,Reptile學習神經網絡模型參數的初始化,這樣當我們在測試時優化這些參數時,學習是快速的 —— 即該模型從測試任務的少量示例中歸納而來。Reptile算法如下:
該方法與MAML的不同主要在於最后一步更新Φ上
MAML的訓練可以分為兩個層次:內層優化和外層優化,內層優化就與普通的訓練一樣,假設網絡初始參數為θ0,在數據集A(訓練數據)上采用SGD的方式進行訓練后得到參數θ′。如果是普通的訓練,那么就會接着采樣一個數據集B(測試數據),然后以θ′作為初始參數繼續訓練了。MAML同樣采集一個數據集B,然后用在數據集A上訓練得到的模型fθ′處理數據集B上的樣本,並計算損失。不同的是,MAML利用該損失計算得到的梯度對θ0進行更新:
也就是說MAML的目標是訓練得到一個好的初始化參數θ,使其能夠在處理其他任務時很快的收斂到一個較好的結果。在梯度計算過程中會涉及到二階導數計算,MAML利用一階導數近似方法(FOMAML)進行處理,發現結果相差並不大,但計算量會減少很多
回到本文,本文提出的算法就是在FOMAML,進一步簡化參數更新的方式,甚至連損失梯度都不需要計算了,直接利用θ0−θ′(即下面的)作為梯度對參數進行更新,即:
可能有人會覺得這樣做,不是相當於退化成普通的訓練過程了嗎,因為θ′還是利用SGD方式得到的,然后讓θ0沿着θ0−θ′的方向更新,就得到θ1。如果說在訓練數據集A中只有一個訓練樣本,或者說只經過一個batch的訓練,那么本文的算法的確會退化為普通的SGD訓練,但如果每個數據集都進行不止一個Batch的訓練,二者就不相同了。
如下圖所示:
如果k=0,那么Reptile的確就等於普通的SGD訓練了
在最后一步,不是簡單地在方向更新Φ,我們將
看作一個梯度,並將其插入如Adam[10]的自適應算法中。(實際上,我們將在Section 5.1中討論,將Reptile梯度定義為
會更自然,其中α是使用在SGD操作中的步長。)我們還定義了該算法的並行或batch版本,能在每個迭代中評估n個任務並更新初始化為:
其中,是第i個任務的更新參數
該算法看起來與在期望損失上聯合訓練極其相似。甚至,如果我們定義U為一步的梯度下降(k=1),那么該算法對應於在期望損失上的隨機梯度下降:
然而,如果我們在局部最小化中執行多個梯度更新(k>1),那么該期望更新並不對應於在期望損失
上梯度下降一步(即當k>1時,U的期望更新將不等於損失函數期望更新,而是將包含損失函數的二次甚至更高階微分項,Reptile的收斂點與最小化E(L)不同)。相反,該更新包含來自
的二階和更高階導數的重要部分,如我們將在Section 5.1中分析的一樣。甚至,Reptile收斂到了一個與最小的預期損失
十分不同的結果
其他部分如步長參數ε和任務采樣,Reptile的batched版本都與SimuParallelSGD[21]算法相似。SimuParallelSGD是一個communication-efficient的分布式優化方法,其中worker在本地進行梯度更新,且不經常平均他們的參數,而不是使用平均梯度的標准方法。
4 Case Study: One-Dimensional Sine Wave Regression
作為一個簡單的案例研究,讓我們考慮一維正弦波回歸問題,它是對Finn等人[4]做了一些修改。這個問題是有指導意義的,因為通過設計,聯合訓練不能學習一個非常有用的初始化;然而,元學習方法可以。
- 任務
使用正弦波函數
的振幅a和相位b來定義。通過采樣
得到任務分布
- 采樣p個點
- learner找到
並預測整個函數f(x)
- 損失是在整個[-5,5]間隔中使用的L2損失
我們用50個等距點x來計算這個積分
首先注意到由於隨機相位b,平均函數在任何地方都是0,即。因此在期望損失
上訓練是無用的,因為損失會使用零函數f(x) = 0來最小化
另一方面,MAML和Reptile給了我們在任務上訓練前的輸出近似為f(x) = 0的初始化,但是在采樣點
上訓練后的網絡的內部特征表征近似於目標函數
。這個學習過程在下圖中展示。從圖1可以看出,經過Reptile訓練后,網絡可以快速收斂到一個采樣的正弦波,並推斷出遠離采樣點的值。作為比較,我們還展示了MAML和一個隨機初始化的網絡在同一任務上的行為。
5 Analysis
在這個部分,我們提供了兩個可替換的有關為什么Reptile能夠運作的解釋
5.1 Leading Order Expansion of the Update
在這里,我們將使用泰勒級數展開來近似Reptile和MAML執行的更新。我們將說明這兩種算法包含相同的主導階項:第一個項最小化預期損失(聯合訓練),第二個和更有趣的項最大化任務內泛化。具體來說,它最大化了來自同一任務的不同小批量的梯度之間的內積。如果不同batches的梯度有正的內積,那么在一個batches上采取的梯度step可以改善另一個batches的性能。
與MAML的討論和分析不同,我們不考慮每個任務的訓練集和測試集;相反,我們只假設每個任務給我們一個k個損失函數序列L1,L2,…,Lk;例如,不同minibatches的分類損失。我們將使用以下定義:
k個損失函數序列L1,L2,…,Lk表示的是k個steps求得的損失,比如一開始參數為Φ1,然后經過一個step,參數變為Φ2,此時對應的損失為L2;再經過一次step,參數變為Φ3,此時對應的損失為L3;...; 經過k-1次后參數就變為Φk,此時對應的損失為Lk。
首先計算如下帶有的SGD梯度:
這個式子的作用就是如何將參數為Φi的一階損失的計算變為由Φ1表示的式子,說明SGD是如何計算的梯度,得到具有一階和二階的式子
這一部分就是batch_size個任務訓練數據如果使用SGD計算梯度的過程
接下來,我們將粗略估計MAML的梯度。定義為更新minibatch i的參數向量的操作符:
這一部分表示的是在測試數據上訓練,全局更新參數的步驟
這表示的是在訓練某個任務時,會訓練K個steps,一個step對應一個minibatch,然后更新一次參數,這樣一步步運行第k個step計算的損失就是Lk(Φk),然后該損失對Φ1求導,這樣就能夠得到訓練這個任務的梯度更新,即從初始化參數Φ1的變化
接下來則是解釋主導階:
這樣梯度就變為了只與初始參數Φ1相關的一階、二階式子
為便於說明,讓我們考慮k = 2的情況,稍后我們將給出一般公式。三種算法的梯度為:
gMAML的結果是使用k=2代入等式(24)得到的
gReptile的結果就是使用k=2代入等式(16)得到的
可見Reptile僅在訓練數據中使用SGD就能夠得到和MAML類似的效果,因此后面直接使用看作梯度輸入優化器Adam優化參數即可
其中如這樣的項就是leading-order;
這樣的項就是次leading-order
正如我們將在下一段中所展示的,像這樣的項可以使計算在不同小批量上的梯度之間的內積最大化,而像
這樣的單一梯度項則使我們在聯合訓練問題中達到最小值。
當我們在minibatch采樣下得到三種算法梯度gFOMAML, gReptile, and gMAML的期望時,我們僅留下兩類叫做AvgGrad和AvgGradInner的項。在下面的等式中,表示我們在任務
得到的期望,兩個minibatched分別定義為L1和L2
- AvgGrad被定義為期望損失的梯度:
(−AvgGrad)是能夠將帶到“聯合訓練”問題的最小值的方向;任務的期望損失
更有意思的項是AvgGradInner,定義如下:
因此(−AvgGradInner)是增加給定任務的不同minibatches間梯度內積的方向,可改善泛化能力
回想我們梯度表達式,我們能得到如下用於meta-gradients的表達式,使用的是k=2的SGD,三種算法的梯度期望為:
實際上,這三個梯度表達式首先都會將我們帶到任務期望損失的最小值,然后更高階的AvgGradInner項能通過最大化給定任務梯度間的內積來實現更快的學習
最后,我們能夠擴展這三個計算到通用情況上,即 k>=2:
當k=2時,AvgGradInner項和AvgGradInner項的相關系數比率是。可是,該比率將隨着stepsize α和迭代次數k線性增加。注意泰勒級數近似只對小的αk可用。
5.2 Finding a Point Near All Solution Manifolds
在這里,我們認為Reptile收斂於一個解Φ,這個解接近(歐幾里得距離)每個任務的最優解的manifold。這是一種非正式的論證,不像前面的泰勒級數分析那樣嚴肅。
讓Φ表示網絡的初始化,並讓表示任務
的最優參數集。我們希望找到使所有任務的距離
小的Φ:
我們展示了Reptile對應於在該目標上實行SGD
給定non-pathological集,然后對於幾乎所有點
,平方距離
的梯度是
,其中
是Φ到S的投射(最近點)。因此:
Reptile的每個迭代對應於采樣一個任務和實行一個隨機梯度下降:
實際上,我們並不能計算,其定義為
的一個最小值。但是我們能夠使用梯度下降局部最小化該損失。因此,在Reptile中,我們在一開始用Φ初始化的
上使用k個steps的梯度下降的結果來替
6 Experiments
6.1 Few-Shot Classification
我們在兩個流行的few-shot分類任務上評估了我們的方法:Omniglot[11]和Mini-ImageNet[18]。這些數據集使我們的方法容易與其他few-shot的學習方法,如MAML,相比較。
在few-shot分類任務中,我們有一個包含許多類C的元數據集D,其中每個類C本身是一組示例實例{c1, c2,…,cn}。如果我們做的是K-shot, N-way分類,那么我們通過從總類C中選N個類,然后為每個類選擇K + 1個例子來采樣任務。我們將這些示例分割為一個訓練集和一個測試集,其中測試集包含每個類的單個示例。模型可以看到整個訓練集,然后它必須能夠分類從測試集中隨機選擇的樣本。例如,如果你訓練模型用於5-shot,5-way分類,然后你將給模型25個樣本(每類5個樣本,有5個類),並讓它分類第26個樣本。
除了上面的設置之外,我們還嘗試了傳導(transductive)設置,其中模型一次對整個測試集進行分類。在我們的傳導實驗中,信息通過batch normalization[9]在測試樣品之間共享。在我們的非傳導(non-transductive)實驗中,batch normalization統計使用所有的訓練樣本和單一的測試樣本來計算。我們注意到Finn等人[4]使用傳導來評估MAML。
在我們的實驗中,我們使用與Finn等[4]相同的CNN架構和數據預處理。在整個實驗中,我們在內循環中使用Adam optimizer[10],在外循環中使用vanilla SGD。對於Adam,我們將其設置為β1 = 0,因為我們發現momentum會全面降低性能。在訓練過程中,我們沒有對Adam的滾動矩數據進行重置或插值;相反,我們讓它在每個內循環訓練步驟中自動更新。但是,在評估測試集時,我們備份並重置了Adam統計數據,以避免信息泄漏。
在Omniglot和Mini-ImageNet上的結果如表1和表2所示。雖然MAML、FOMAML和Reptile在所有這些任務上都具有非常相似的性能,但Reptile在Mini-ImageNet上的性能略好於替代方案,在Omniglot上的性能略差。傳導似乎在所有情況下都能提高性能,這表明進一步的研究應該密切關注在測試過程中對batch normalization的使用。
6.2 Comparing Different Inner-Loop Gradient Combinations
在本實驗中,我們在每個內循環中使用4個不重疊的mini-batches,產生梯度g1、g2、g3和g4。然后,我們比較了使用不同的gi的線性組合進行外循環更新時的學習性能。注意,兩步Reptile對應g1 + g2,兩步FOMAML對應g2。
為了更容易地比較不同的線性組合,我們用幾種方法簡化了實驗設置。首先,我們在內部和外部循環中都使用vanilla SGD。其次,我們沒有使用meta-batches。第三,我們把實驗限制在5-shot 5-way的Omniglot。通過這些簡化,我們不必過多地擔心超參數或優化器的影響。
圖3顯示了各種內循環梯度組合的學習曲線。對於一個以上項的梯度組合,我們對內部梯度進行求和和平均來校正有效步長增加。
正如預期的那樣,只使用第一個梯度g1是相當無效的,因為它等於優化所有任務的預期損失。令人驚訝的是,兩步Reptile(即g1 + g2,綠色)明顯比兩步FOMAML(即g2,紅色)更糟糕,這可能是因為兩步Reptile中對比AvgGrad,給AvgGradInner的權重更少(公式(34)和(35))。最重要的是,所有的方法都會隨着mini-batches數量的增加而改進。當使用所有梯度的累加(爬行類)而不是只使用最終梯度(FOMAML)時,這種改進更為顯著。這也表明Reptile可以從執行許多內部循環步驟中獲益,這與6.1節中找到的最佳超參數一致。
6.3 Overlap Between Inner-Loop Mini-Batches
Reptile和FOMAML都在內部循環中使用隨機優化。對這個優化過程的微小更改可能導致最終性能的大變化。本節探討了Reptile和FOMAML對內部循環超參數的敏感性,還顯示了如果mini-batches選擇錯誤,FOMAML的性能會顯著下降。
本節中的實驗將研究shared-tail FOMAML —— 其最終的內循環mini-batch與早期的內循環batches來自同一個數據集, 和seperate-tail FOMAML —— 其最后的mini-batch來自一組不相關的數據,這兩個FOMAML之間的差異。將seperate-tail FOMAML看作是一種近似於MAML的方法,可以認為是更正確的方法(Finn等人[4]使用了它),因為訓練時間優化類似於測試時間優化(測試集與訓練集不重疊)。事實上,我們發現seperate-tail FOMAML明顯優於shared-tail FOMAML。正如我們將展示的,當用於計算元梯度(gFOMAML = gk)的數據與之前的批次顯著重疊時,shared-tail FOMAML的性能會下降;然而,Reptile和seperate-tail FOMAML能維護性能,並且對內部循環超參數不是很敏感。
圖4a顯示,當通過訓練數據循環(shared-tail,cycle)選擇minibatches時,shared-tail FOMAML最多執行4次內循環迭代,但在5次迭代時性能下降,其中最終的minibatch(用於計算gFOMAML = gk)與之前的minibatches重疊。當我們使用隨機抽樣方法代替循環選取方法(shared-tail,replacement)時,shared-tail FOMAML退化得更緩慢。我們推測這是因為在最后一批中仍然出現了一些之前沒有出現的樣品。效果是隨機的,所以曲線更平滑是有道理的。
圖4b顯示了類似的現象,但是這里我們將內部循環固定為4次迭代,並改變batch大小。對於大於25的batch size,shared-tail FOMAML的最后一個內部循環batch必須包含以前batches的樣本。與圖4a相似,在這里我們觀察到,隨機抽樣下的shared-tail FOMAML比循環下的shared-tail FOMAML退化更緩慢。
在這兩種參數掃描中,隨着內循環迭代次數或batch size的變化,seperate-tail FOMAML和Reptile的性能不會下降。
對於上述發現有幾種可能的解釋。例如,我們可以假設,在這些實驗中,shared-tail FOMAML的效果更差,只是因為它的有效步長遠低於seperate-tail FOMAML。然而,圖4c表明情況並非如此:在一次徹底掃描中,對於每一個步長選擇,性能都同樣糟糕。另一種假設是,shared-tail FOMAML表現不佳的原因是,在一個樣本上經過幾個內循環步驟后,該樣本的損失梯度並不包含關於該樣本的非常有用的信息。換句話說,最初的幾個SGD步驟可能會使模型接近局部最優,然后進一步的SGD步驟可能只是在這個局部最優附近反彈。
7 Discussion
在測試時執行梯度下降的元學習算法由於其簡單性和泛化特性[5]很具有吸引力。微調的有效性(例如,在ImageNet[2]上訓練的模型)給了我們對這些方法更多的信心。本文提出了一種新的算法——Reptile,其訓練過程與聯合訓練只有細微的不同,只使用一階梯度信息(如一階MAML)。
對於Reptile的工作原理,我們給出了兩個理論解釋。首先,通過用泰勒級數近似更新,我們證明了SGD自動給出了MAML計算的同樣類型的二階項。這個項調整初始權值以最大化同一任務上不同小批量的梯度之間的點積,即它鼓勵在同一任務的小批量之間泛化梯度。我們還提供了第二個非正式的論點,即Reptile找到了一個接近(歐氏距離)所有訓練任務的最優解manifold的點。
雖然本文研究的是元學習設置,但第5.1節中的泰勒級數分析在一般情況下可能對隨機梯度下降有一定的影響。這表明,在進行隨機梯度下降時,我們會自動執行類似MAML的更新,從而最大化不同小批量之間的泛化。這個觀察結果部分地解釋了為什么微調(例如,從ImageNet到更小的數據集[20])效果很好。這一假設表明,聯合訓練加上微調將繼續成為元學習在各種機器學習問題上的強大基礎。
8 Future Work
我們看到了未來工作的幾個有希望的方向:
- 理解SGD在多大程度上自動優化泛化,以及這種效果是否能在非元學習設置中被放大。
- 在強化學習設置中應用Reptile。到目前為止,我們得到了消極的結果,因為聯合訓練是一個強大的基線,所以Reptile的一些修改可能是必要的。
- 探索是否可以通過更深層次的分類器架構來提高Reptile的few-shot學習性能。
- 探索正則化是否可以提高few-shot學習性能,因為目前訓練和測試錯誤之間存在很大的差距。
- 評估Reptile在[14]的few-shot密度建模任務的效果。