meta learning - Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks - 1 - 論文學習


 代碼:

github.com/cbfinn/maml

github.com/cbfinn/maml_rl

 

Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

Abstract

我們提出了一種與模型無關的元學習算法,因為它與任何用梯度下降訓練過的模型兼容,適用於各種不同的學習問題,包括分類、回歸和強化學習。元學習的目標是在各種學習任務上訓練一個模型,這樣它就可以用少量的訓練樣本解決新的學習任務。在我們的方法中,模型的參數被明確地訓練,這樣一個新的任務只用少量的梯度步長和少量的訓練數據就能產生良好的泛化性能。實際上,我們的方法使模型易於微調。我們證明了該方法在兩個few-shot圖像分類(即每類的圖像比較少)基准上取得了最先進的性能,在few-shot回歸上產生了良好的結果,並加速了神經網絡策略梯度增強學習的微調。

 

1. Introduction

快速學習是人類智能的一個標志,無論是從幾個例子中識別物體,還是在短短幾分鍾的經歷后快速學習新技能。我們的人工代理應該能夠做同樣的事情,即從少數例子中快速學習和適應,並隨着越來越多的數據可用而能夠繼續適應。這種快速而靈活的學習具有挑戰性,因為代理必須將其以往的經驗與少量的新信息相結合,同時避免對新數據的過擬合。此外,以往經驗和新數據的形式將取決於任務。因此,為了獲得最大的適用性,學習(或元學習)的機制應該對任務和完成任務所需的計算形式具有普遍性。

在這項工作中,我們提出了一個通用的和模型無關的元學習算法,因為它可以直接應用於任何學習問題和用梯度下降進程訓練的模型。我們的重點是深度神經網絡模型,但我們說明了我們的方法是如何以最小的修改來輕松處理不同的架構和不同的問題設置的,包括分類,回歸,和政策梯度強化學習。在元學習中,經過訓練的模型的目標是從少量的新數據中快速學習新的任務,經過元學習者的訓練,模型能夠對大量不同的任務進行學習。我們的方法的關鍵思想是訓練模型的初始參數,這樣當參數使用新任務的少量數據計算的一個或多個梯度步驟更新后,模型在新任務上有最優的性能。不同於之前學習更新函數或學習規則的元學習方法(Schmidhuber, 1987; Bengio et al.1992; Andrychowicz et al.2016; Ravi & Larochelle, 2017),我們的算法不擴大學習參數的數量,在模型架構上也沒有設置限制(例如,要求周期模型(Santoro et al ., 2016)或Siamese網絡(Koch, 2015)),它很容易與全連接層,卷積或循環神經網絡結合。它也可以用於各種損失函數,包括可微監督損失和不可微強化學習目標。

從特征學習的角度來看,訓練一個模型的參數以使幾個梯度步驟,甚至是單個梯度步驟在一個新任務上產生好的結果的過程,可以看作是建立一個廣泛適用於許多任務的內部表征。如果內部表征適用於許多任務,只需稍微微調參數(例如,在前饋模型中主要修改頂層權重)就可以產生良好的結果。實際上,我們的程序優化了易於快速調整的模型,允許適應發生在快速學習的正確空間。從一個動態系統的觀點來看,我們的學習過程可以被看作是最大化新任務的損失函數對參數的敏感性:當敏感性很高時,對參數的微小局部變化可以導致任務損失的巨大改善。

這項工作的主要貢獻是一個用來訓練模型的參數的簡單的模型和任務不確定的元學習算法,這樣少量的梯度更新將導致對新任務的快速學習。我們演示了不同模型類型的算法,包括全連接和卷積網絡,並在幾個不同的領域應用,包括few-shot回歸,圖像分類,和強化學習。我們評估表明,元學習算法與最先進的專門用於監督分類的one-shot學習方法對比,其使用了更少的參數,也可以應用於回歸並且在任務變化時可以加速強化學習的任務可變性,大大優於作為初始化直接預訓練。

 

2. Model-Agnostic Meta-Learning

我們的目標是訓練能夠實現快速適應的模型,這是一種經常被形式化為few-shot學習的問題設置。在本節中,我們將定義問題設置並給出算法的一般形式。

2.1. Meta-Learning Problem Set-Up

少量元學習的目標是訓練一個模型,使其能夠使用少量的數據點和訓練迭代數來快速適應新的任務。為此,模型或學習者在元學習階段對一組任務進行訓練,這樣經過訓練的模型只需使用少量的例子或試驗就能快速適應新的任務。實際上,元學習問題把整個任務當作訓練樣本。在本節中,我們以一般的方式將元學習問題設置形式化,包括不同學習領域的簡單例子。我們將在第3節詳細討論兩個不同的學習領域。

我們考慮了一個模型, 記為f,它將觀察值x映射到輸出值a。在元學習過程中,該模型被訓練成能夠適應大量或無限數量的任務。由於我們想將我們的框架應用於各種各樣的學習問題,從分類到強化學習,我們在下面介紹一個學習任務的一般概念。形式上來說,每個任務包含一個損失函數L、一個初始觀測值q(X1)的分布、一個轉換分布q(Xt+1 |Xt, at)以及一個eisode長度H。在監督學習問題中,長度H=1。模型可能在每個時間t通過選擇輸出at生成長度H的樣本。該損失提供特定任務的反饋,其可能是誤分類損失或馬爾科夫決策過程中的代價函數的形式。

在元學習的場景中,我們考慮一個我們希望我們的模型能夠適應的任務的分布(即這里面有多個不同的任務,用來訓練模型)。在K-shot學習設置中,模型被訓練去學習一個來自任務分布的新任務,該新任務中的訓練數據僅有K個來自qi的樣本,其反饋生成。在元訓練期間,一個任務中采樣得到,模型使用K個樣本訓練,並從來自的對應損失中得到反饋,然后在來自的新樣本(即非前面使用的K個樣本的別的樣本)中進行測試。然后通過考慮在來自qi的新數據上的測試error如何根據參數進行改變來改善模型f。實際上,在采樣的任務上的測試error將作為元學習過程中的訓練error。在元訓練的結尾,從中采樣新任務,從K個樣本中學習后,元性能將用模型的性能來測量。通常,在元訓練期間,用於元測試的任務會被擱置(即在訓練時不會被使用)。

 

2.2. A Model-Agnostic Meta-Learning Algorithm

與之前的研究相反,之前的研究試圖訓練能攝取整個數據集的遞歸神經網絡(Santoro et al., 2016;Duan et al.,2016b)或特征嵌入,在測試時可與非參數方法結合(Vinyals et al.,2016;(Koch, 2015),我們提出了一種方法,可以通過元學習學習任何標准模型的參數,從而為模型的快速適應做好准備。這種方法背后的直覺是一些內部表征比其他的更容易轉移。例如,神經網絡可以學習廣泛適用於中所有任務的內部特征,而不是單個任務。我們如何鼓勵這種通用表征的出現?我們對這個問題采取一個明確的方法:由於該模型將在一個新的任務上使用一個基於梯度的學習規則來進行微調,我們將致力於以基於這種梯度學習規則可以快速優化來自的新任務的這樣一種方式來學習模型,並沒有過擬合的出現。實際上,我們的目標是找到對任務中變化敏感的模型參數,這樣當變化的方向在損失的梯度方向上時,小的參數的變化將在來自的任何任務的損失函數上產生大的改進(參見圖1)。

 

 

我們沒有對模型的形式做任何假設,除了假設它是由一些參數向量Θ參數化的,並且損失函數在Θ中是足夠平滑的,這樣我們可以使用基於梯度的學習技術。

形式上,我們考慮使用一個帶有參數Θ的參數化函數來表示模型。當適應了一個新任務,模型的參數將更新為。在我們的方法中,更新的參數向量是使用一個或多個在任務上的梯度下降更新方法計算得到的。比如,當使用一個梯度更新時如下式子所示:

 

步長α可以被固定為一個超參數或者是可進行元學習得到。為了簡化概念,我們在下面的section中都將只考慮一次梯度更新,但是使用多個梯度更新時一種最直接的擴展方式

 

模型參數通過優化的性能來訓練,其中Θ遍及在從中采樣得到的任務中。更准確的是,元目標函數表示如下: 

 

請注意,元優化是在模型參數Θ中執行的,而目標函數是使用更新的模型參數進行計算的。實際上,我們所提出的方法旨在優化模型參數,使新任務上的一個或少量梯度步驟將在該任務上產生最大的有效效果。

任務的元優化是通過隨機梯度下降法(SGD)實現的,模型參數Θ更新如下所示:

 

其中表示β步長。整個算法被概述在算法1中:

 

所以整個的訓練過程就是:

1.首先隨機初始化整個模型的參數Θ

2.然后從任務分布隨機選取batch_size個任務

3.然后循環訓練這batch_size個任務訓練數據

  • 首先訓練第一個(假設是貓狗分類任務),其中有每個類有K個樣本(所以叫做K-shot,即貓和狗的訓練數據都只有K張圖片)
  • 使用這些數據去訓練模型,然后使用梯度下降方法去更新初始參數Θ,得到Θ'1
  • 接着再循環回3去訓練另一個任務,再更新參數Θ,得到Θ'2,以此往復得到(Θ'3,..., Θ'batch_size_of_tasks)直至batch_size個任務都訓練完

4.最使用訓練中對每個任務訓練得到的參數(Θ'1,..., Θ'batch_size_of_tasks)作為模型的參數,輸入測試任務(測試數據)去計算損失,損失求和再進行梯度下降來最終更新全局的初始化參數Θ

5.然后再循環至2根據設置迭代次數再繼續訓練

 

MAML元梯度更新涉及到一個貫穿梯度的梯度。在計算上,這需要額外的通過f的后向傳播來計算Hessian-vector乘積,這一點得到了諸如TensorFlow等標准深度學習庫的支持(Abadi等,2016)。在我們的實驗中,我們還包括了一個與丟棄這個后向傳播和使用一個一階近似的操作比較,我們將在5.2節中討論。

 

3. Species of MAML

在本節中,我們將討論用於監督學習和強化學習的元學習算法的具體實例。這些域在損失函數的形式以及數據如何由任務生成並呈現給模型方面有所不同,但在這兩種情況下可以應用相同的基本適應機制。

 

3.1. Supervised Regression and Classification

few-shot學習在監督任務領域得到了充分的研究,其目標是通過使用該任務的少量輸入/輸出對來學習一個新函數,並使用來自類似任務的先前數據進行元學習。例如,目標可能是使用的是之前已經看到過許多其他類型對象的模型,在只看到一個或幾個Segway示例之后對Segway的圖像進行分類。同樣地,在few-shot回歸中,目標是在對許多具有類似統計特性的函數進行訓練后,從該函數采樣的少數數據點預測連續值函數的輸出。

為了在2.1節中的元學習定義上下文中公式化回歸和分類問題,我們可以定義horizon H = 1,並下放時間下標如xt, 因為模型接受一個輸入,產生一個輸出,而非輸入和輸出序列。任務從qi中生成K個觀測值x,任務損失用模型對x的輸出與該觀測和任務對應的目標值y之間的誤差表示。

用於監督分類和回歸的兩個常見損失函數是交叉熵和均方誤差(MSE),我們將在下面描述;不過,也可以使用其他監督損失函數。對於使用均方誤差的回歸任務,損失形式為:

其中表示從任務中采樣的輸入/輸出對。在K-shot回歸任務中,為每個任務提供K個輸入/輸出對用於學習

 

同樣的,對於離散的帶有交叉熵損失的分類任務,其損失如下所示: 

根據往常的術語,K-shot分類任務在每個類中使用K個輸入/輸出對,對於N個類的分類任務,總共需要NK個數據點。給定任務的分布,這些損失函數能夠被直接插入到Section 2.2的等式中去實現元學習,如下面的算法2所示:

 

 

 

3.2. Reinforcement Learning

在強化學習(RL)中,few-shot元學習的目標是使agent僅使用測試設置中的少量經驗就能快速獲得一個新測試任務的策略。一項新任務可能涉及實現一個新目標,或者在一個新的環境中實現一個之前訓練過的目標。例如,一個代理可能學會快速找出如何在迷宮中導航,這樣,當面對一個新的迷宮時,它就可以通過少量樣本確定如何可靠地到達出口。在本節中,我們將討論如何將MAML應用於RL的元學習。

 

每一個強化學習任務包含一個初始狀態分布和一個轉換分布,損失與(負)獎勵函數R相關。因此整個網絡是一個帶有Horizon H的馬爾可夫決策過程(MDP),其中學習器被允許去查詢有限數量的樣本軌跡用於few-shot學習。在中,MDP的任何方面都可能在不同任務之間發生變化。正在學習的模型fθ使用的策略是在每個時間step 上將狀態Xt映射到actions at的分布上。任務的損失和模型fΦ的形式如下:

在K-shot強化學習中,K個rollouts來自fθ和任務和相應的獎勵可能會用於適應新任務

由於未知的動態,預期獎勵通常是不可微的,因此我們使用策略梯度方法來估計模型梯度更新和元優化的梯度。因為策略梯度是on-policy的算法,在fθ的適應期間,每個額外的梯度step需要來自當前策略的新樣本。我們在上面的算法3詳細說明了算法。該算法與算法2有着相同的結構,主要的不同在step 5和8,其需要從相關於任務的環境中采樣軌跡。該方法的實際實現可能會使用目前提出的用於策略梯度算法一系列改進,包括狀態或行為獨立的基線和trust regions(Schulman et al., 2015).

 

4. Related Work

忽略

 

5. Experimental Evaluation

...

5.2. Classification

...

在MAML中,當通過元目標中的梯度算子反向傳播元梯度時,使用二階導數會消耗大量的計算量(見方程(1))。在MiniImagenet上,我們與MAML的一階近似進行了比較,其中省略了這些二階導數。值得注意的是,結果方法仍然在后向更新參數值后計算元梯度,用於高效的元學習。令人驚訝的是這兩種方法(即一個省略了二階導數,一個沒省略)的性能幾乎是一樣的,獲得完整的二階導數,這表明大部分的改善MAML來自后向更新參數值時目標函數的梯度,而不是二階更新。過去的研究發現ReLU神經網絡在局部幾乎是線性的(Goodfellow et al., 2015),這表明二階導數在大多數情況下可能接近於零,這部分解釋了一階近似的良好性能。這種近似消除了在額外的向后傳遞中計算Hessian-vector內積的需要,我們發現這樣可以使網絡計算速度提高大約33%。

 

計算步驟像是:

 

詳情可見

https://www.bilibili.com/video/BV1pQ4y1K7cw?p=36

李宏毅機器學習—進階部分

 

然后省略掉二階部分:

 

可見就簡化成了只用考慮i=j的情況:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM