Model-Agnostic Meta-Learning (MAML)模型介紹及算法詳解


paper:https://link.zhihu.com/?target=https%3A//arxiv.org/pdf/1703.03400.pdf

 

MAML在學術界已經是非常重要的模型了,論文Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks自2017年發表至今已經收獲了400+的引用。由於當前網上關於MAML的中文介紹少之又少,可能很多小伙伴對其還不是特別理解。所以今天我整理了這段時間來的學習心得,與大家分享自己對MAML的認識與理解。MAML可以用於Supervised Regression and Classification以及Reinforcement Learning。由於我對強化學習不是特別了解,因此這篇文章,均是基於MAML在Supervised Regression and Classification中的運用。

一、一些相關概念的介紹

在原論文中,作者直接引用了許多元學習相關的概念,例如 meta-learning, model-agnostic, N-way K-shot, tasks等等,其中有些概念在MAML中還有特殊的含義。在此,我盡量用通俗易懂的方式對這些概念為大家做一個介紹。

(1) meta-learning

meta-learning即元學習,也可以稱為“learning to learn”。常見的深度學習模型,目的是學習一個用於預測的數學模型。而元學習面向的不是學習的結果,而是學習的過程。其學習的不是一個直接用於預測的數學模型,而是學習“如何更快更好地學習一個數學模型”。

舉一個現實生活的例子。我們教小朋友讀英語時,可以直接讓他們模仿apple、banana的發音。但是他們很快又會遇到新的單詞,例如strawberry,這是小朋友就需要重新聽你的發音,才能正確地讀出這個新單詞。我們換一種方式,這一次我們不教每個單詞的發音,而是教音標的發音。從此小朋友再遇見新單詞,他們只要根據音標,就可以正確地讀出這個單詞。學習音標的過程,正是一個元學習的過程。

在深度學習中,已經被提出的元學習模型有很多,大致上可以分類為learning good weight initializations,meta-models that generate the parameters of other models 以及learning transferable optimizers。其中MAML屬於第一類。MAML學習一個好的初始化權重,從而在新任務上實現fast adaptation,即在小規模的訓練樣本上迅速收斂並完成fine-tune。

(2) model-agnostic

model-agnostic即模型無關。MAML與其說是一個深度學習模型,倒不如說是一個框架,提供一個meta-learner用於訓練base-learner。這里的meta-learner即MAML的精髓所在,用於learning to learn;而base-learner則是在目標數據集上被訓練,並實際用於預測任務的真正的數學模型。絕大多數深度學習模型都可以作為base-learner無縫嵌入MAML中,而MAML甚至可以用於強化學習中,這就是MAML中model-agnostic的含義。

(3) N-way K-shot

N-way K-shot是few-shot learning中常見的實驗設置。few-shot learning指利用很少的被標記數據訓練數學模型的過程,這也正是MAML擅長解決的問題之一。N-way指訓練數據中有N個類別,K-shot指每個類別下有K個被標記數據。

(4) task

MAML的論文中多次出現名詞task,模型的訓練過程都是圍繞task展開的,而作者並沒有給它下一個明確的定義。要正確地理解task,我們需要了解的相關概念包括[公式], [公式] , support set, query set, meta-train classes, meta-test classes等等。是不是有點眼花繚亂?不要着急,舉個簡單的例子,大家就可以很輕松地掌握這些概念。

我們假設這樣一個場景:我們需要利用MAML訓練一個數學模型模型 [公式] ,目的是對未知標簽的圖片做分類,類別包括 [公式] (每類5個已標注樣本用於訓練。另外每類有15個已標注樣本用於測試)。我們的訓練數據除了 [公式] 中已標注的樣本外,還包括另外10個類別的圖片 [公式] (每類30個已標注樣本),用於幫助訓練元學習模型 [公式] 。我們的實驗設置為5-way 5-shot

關於具體的訓練過程,會在下一節MAML算法詳解中介紹。這里我們只需要有一個大概的了解:MAML首先利用 [公式] 的數據集訓練元模型[公式],再在[公式]的數據集上精調(fine-tune)得到最終的模型 [公式]

此時,[公式]meta-train classes[公式]包含的共計300個樣本,即 [公式] ,是用於訓練 [公式] 的數據集。與之相對的,[公式]meta-test classes[公式]包含的共計100個樣本,即 [公式] ,是用於訓練和測試 [公式] 的數據集。

根據5-way 5-shot的實驗設置,我們在訓練 [公式] 階段,從 [公式] 中隨機取5個類別,每個類別再隨機取20個已標注樣本,組成一個task [公式] 。其中的5個已標注樣本稱為 [公式]support set,另外15個樣本稱為 [公式]query set。這個task [公式] , 就相當於普通深度學習模型訓練過程中的一條訓練數據。那我們肯定要組成一個batch,才能做隨機梯度下降SGD對不對?所以我們反復在訓練數據分布中抽取若干個這樣的task [公式] ,組成一個batch。在訓練 [公式] 階段,tasksupport setquery set的含義與訓練 [公式] 階段均相同。

二、MAML算法詳解

作者在論文中給出的算法流程如下:

MAML算法

該算法實質上是MAML預訓練階段的算法,目的是得到模型 [公式] 。不要被這些數學符號嚇到喔,這個算法的思路其實很簡單。接下來,我們來一行一行地分析這個算法。

首先來看兩個Require

第一個Require指的是在 [公式]task的分布。結合我們在上一小節舉的例子,這里即反復隨機抽取task [公式] ,形成一個由若干個(e.g., 1000個) [公式] 組成的task池,作為MAML的訓練集。有的小伙伴可能要納悶了,訓練樣本就這么多,要組合形成那么多的task,豈不是不同task之間會存在樣本的重復?或者某些task的query set會成為其他task的support set?沒錯!就是這樣!我們要記住,MAML的目的,在於fast adaptation,即通過對大量task的學習,獲得足夠強的泛化能力,從而面對新的、從未見過的task時,通過fine-tune就可以快速擬合。task之間,只要存在一定的差異即可。再強調一下,MAML的訓練是基於task的,而這里的每個task就相當於普通深度學習模型訓練過程中的一條訓練數據。

第二個Require就很好理解啦。step size其實就是學習率,讀過MAML論文的小伙伴一定會對gradient by gradient這個詞有印象。MAML是基於二重梯度的,每次迭代包括兩次參數更新的過程,所以有兩個學習率可以調整。

接下來,就是激動人心的算法流程。

步驟1,隨機初始化模型的參數,沒什么好說的,任何模型訓練前都有這一步。

步驟2,是一個循環,可以理解為一輪迭代過程或一個epoch,當然啦預訓練的過程是可以有多個epoch的。

步驟3,相當於pytorch中的DataLoader,即隨機對若干個(e.g., 4個)task進行采樣,形成一個batch。

步驟4~步驟7,是第一次梯度更新的過程。注意這里我們可以理解為copy了一個原模型,計算出新的參數,用在第二輪梯度的計算過程中。我們說過,MAML是gradient by gradient的,有兩次梯度更新的過程。步驟4~7中,利用batch中的每一個task,我們分別對模型的參數進行更新(4個task即更新4次)。注意這一個過程在算法中是可以反復執行多次的,偽代碼沒有體現這一層循環,但是作者再分析的部分明確提到" using multiple gradient updates is a straightforward extension"。

步驟5,即對利用batch中的某一個task中的support set,計算每個參數的梯度。在N-way K-shot的設置下,這里的support set應該有NK個。作者在算法中寫with respect to K examples,默認對每一個class下的K個樣本做計算。實際上參與計算的總計有NK個樣本。這里的loss計算方法,在回歸問題中,就是MSE;在分類問題中,就是cross-entropy

步驟6,即第一次梯度的更新。

步驟4~步驟7,結束后,MAML完成了第一次梯度更新。接下來我們要做的,是根據第一次梯度更新得到的參數,通過gradient by gradient,計算第二次梯度更新。第二次梯度更新時計算出的梯度,直接通過SGD作用於原模型上,也就是我們的模型真正用於更新其參數的梯度。

步驟8即對應第二次梯度更新的過程。這里的loss計算方法,大致與步驟5相同,但是不同點有兩處。一處是我們不再是分別利用每個task的loss更新梯度,而是像常見的模型訓練過程一樣,計算一個batch的loss總和,對梯度進行隨機梯度下降SGD。另一處是這里參與計算的樣本,是task中的query set,在我們的例子中,即5-way*15=75個樣本,目的是增強模型在task上的泛化能力,避免過擬合support set。步驟8結束后,模型結束在該batch中的訓練,開始回到步驟3,繼續采樣下一個batch。

以上即時MAML預訓練得到 [公式] 的全部過程,是不是很簡單呢?事實上,MAML正是因為其簡單的思想與驚人的表現,在元學習領域迅速流行了起來。接下來,應該是面對新的task,在 [公式] 的基礎上,精調得到 [公式] 的方法。原文中沒有介紹fine-tune的過程,這里我向小伙伴們簡單介紹一下。

fine-tune的過程與預訓練的過程大致相同,不同的地方主要在於以下幾點:

  • 步驟1中,fine-tune不用再隨機初始化參數,而是利用訓練好的 [公式] 初始化參數。
  • 步驟3中,fine-tune只需要抽取一個task進行學習,自然也不用形成batch。fine-tune利用這個task的support set訓練模型,利用query set測試模型。實際操作中,我們會在 [公式] 上隨機抽取許多個task(e.g., 500個),分別微調模型 [公式] ,並對最后的測試結果進行平均,從而避免極端情況。
  • fine-tune沒有步驟8,因為task的query set是用來測試模型的,標簽對模型是未知的。因此fine-tune過程沒有第二次梯度更新,而是直接利用第一次梯度計算的結果更新參數。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM