paper:https://link.zhihu.com/?target=https%3A//arxiv.org/pdf/1703.03400.pdf
MAML在學術界已經是非常重要的模型了,論文Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks自2017年發表至今已經收獲了400+的引用。由於當前網上關於MAML的中文介紹少之又少,可能很多小伙伴對其還不是特別理解。所以今天我整理了這段時間來的學習心得,與大家分享自己對MAML的認識與理解。MAML可以用於Supervised Regression and Classification以及Reinforcement Learning。由於我對強化學習不是特別了解,因此這篇文章,均是基於MAML在Supervised Regression and Classification中的運用。
一、一些相關概念的介紹
在原論文中,作者直接引用了許多元學習相關的概念,例如 meta-learning, model-agnostic, N-way K-shot, tasks等等,其中有些概念在MAML中還有特殊的含義。在此,我盡量用通俗易懂的方式對這些概念為大家做一個介紹。
(1) meta-learning
meta-learning即元學習,也可以稱為“learning to learn”。常見的深度學習模型,目的是學習一個用於預測的數學模型。而元學習面向的不是學習的結果,而是學習的過程。其學習的不是一個直接用於預測的數學模型,而是學習“如何更快更好地學習一個數學模型”。
舉一個現實生活的例子。我們教小朋友讀英語時,可以直接讓他們模仿apple、banana的發音。但是他們很快又會遇到新的單詞,例如strawberry,這是小朋友就需要重新聽你的發音,才能正確地讀出這個新單詞。我們換一種方式,這一次我們不教每個單詞的發音,而是教音標的發音。從此小朋友再遇見新單詞,他們只要根據音標,就可以正確地讀出這個單詞。學習音標的過程,正是一個元學習的過程。
在深度學習中,已經被提出的元學習模型有很多,大致上可以分類為learning good weight initializations,meta-models that generate the parameters of other models 以及learning transferable optimizers。其中MAML屬於第一類。MAML學習一個好的初始化權重,從而在新任務上實現fast adaptation,即在小規模的訓練樣本上迅速收斂並完成fine-tune。
(2) model-agnostic
model-agnostic即模型無關。MAML與其說是一個深度學習模型,倒不如說是一個框架,提供一個meta-learner用於訓練base-learner。這里的meta-learner即MAML的精髓所在,用於learning to learn;而base-learner則是在目標數據集上被訓練,並實際用於預測任務的真正的數學模型。絕大多數深度學習模型都可以作為base-learner無縫嵌入MAML中,而MAML甚至可以用於強化學習中,這就是MAML中model-agnostic的含義。
(3) N-way K-shot
N-way K-shot是few-shot learning中常見的實驗設置。few-shot learning指利用很少的被標記數據訓練數學模型的過程,這也正是MAML擅長解決的問題之一。N-way指訓練數據中有N個類別,K-shot指每個類別下有K個被標記數據。
(4) task
MAML的論文中多次出現名詞task,模型的訓練過程都是圍繞task展開的,而作者並沒有給它下一個明確的定義。要正確地理解task,我們需要了解的相關概念包括,
, support set, query set, meta-train classes, meta-test classes等等。是不是有點眼花繚亂?不要着急,舉個簡單的例子,大家就可以很輕松地掌握這些概念。
我們假設這樣一個場景:我們需要利用MAML訓練一個數學模型模型 ,目的是對未知標簽的圖片做分類,類別包括
(每類5個已標注樣本用於訓練。另外每類有15個已標注樣本用於測試)。我們的訓練數據除了
中已標注的樣本外,還包括另外10個類別的圖片
(每類30個已標注樣本),用於幫助訓練元學習模型
。我們的實驗設置為5-way 5-shot。
關於具體的訓練過程,會在下一節MAML算法詳解中介紹。這里我們只需要有一個大概的了解:MAML首先利用 的數據集訓練元模型
,再在
的數據集上精調(fine-tune)得到最終的模型
。
此時,即meta-train classes,
包含的共計300個樣本,即
,是用於訓練
的數據集。與之相對的,
即meta-test classes,
包含的共計100個樣本,即
,是用於訓練和測試
的數據集。
根據5-way 5-shot的實驗設置,我們在訓練 階段,從
中隨機取5個類別,每個類別再隨機取20個已標注樣本,組成一個task
。其中的5個已標注樣本稱為
的support set,另外15個樣本稱為
的query set。這個task
, 就相當於普通深度學習模型訓練過程中的一條訓練數據。那我們肯定要組成一個batch,才能做隨機梯度下降SGD對不對?所以我們反復在訓練數據分布中抽取若干個這樣的task
,組成一個batch。在訓練
階段,task、support set、query set的含義與訓練
階段均相同。
二、MAML算法詳解
作者在論文中給出的算法流程如下:
MAML算法
該算法實質上是MAML預訓練階段的算法,目的是得到模型 。不要被這些數學符號嚇到喔,這個算法的思路其實很簡單。接下來,我們來一行一行地分析這個算法。
首先來看兩個Require。
第一個Require指的是在 中task的分布。結合我們在上一小節舉的例子,這里即反復隨機抽取task
,形成一個由若干個(e.g., 1000個)
組成的task池,作為MAML的訓練集。有的小伙伴可能要納悶了,訓練樣本就這么多,要組合形成那么多的task,豈不是不同task之間會存在樣本的重復?或者某些task的query set會成為其他task的support set?沒錯!就是這樣!我們要記住,MAML的目的,在於fast adaptation,即通過對大量task的學習,獲得足夠強的泛化能力,從而面對新的、從未見過的task時,通過fine-tune就可以快速擬合。task之間,只要存在一定的差異即可。再強調一下,MAML的訓練是基於task的,而這里的每個task就相當於普通深度學習模型訓練過程中的一條訓練數據。
第二個Require就很好理解啦。step size其實就是學習率,讀過MAML論文的小伙伴一定會對gradient by gradient這個詞有印象。MAML是基於二重梯度的,每次迭代包括兩次參數更新的過程,所以有兩個學習率可以調整。
接下來,就是激動人心的算法流程。
步驟1,隨機初始化模型的參數,沒什么好說的,任何模型訓練前都有這一步。
步驟2,是一個循環,可以理解為一輪迭代過程或一個epoch,當然啦預訓練的過程是可以有多個epoch的。
步驟3,相當於pytorch中的DataLoader,即隨機對若干個(e.g., 4個)task進行采樣,形成一個batch。
步驟4~步驟7,是第一次梯度更新的過程。注意這里我們可以理解為copy了一個原模型,計算出新的參數,用在第二輪梯度的計算過程中。我們說過,MAML是gradient by gradient的,有兩次梯度更新的過程。步驟4~7中,利用batch中的每一個task,我們分別對模型的參數進行更新(4個task即更新4次)。注意這一個過程在算法中是可以反復執行多次的,偽代碼沒有體現這一層循環,但是作者再分析的部分明確提到" using multiple gradient updates is a straightforward extension"。
步驟5,即對利用batch中的某一個task中的support set,計算每個參數的梯度。在N-way K-shot的設置下,這里的support set應該有NK個。作者在算法中寫with respect to K examples,默認對每一個class下的K個樣本做計算。實際上參與計算的總計有NK個樣本。這里的loss計算方法,在回歸問題中,就是MSE;在分類問題中,就是cross-entropy。
步驟6,即第一次梯度的更新。
步驟4~步驟7,結束后,MAML完成了第一次梯度更新。接下來我們要做的,是根據第一次梯度更新得到的參數,通過gradient by gradient,計算第二次梯度更新。第二次梯度更新時計算出的梯度,直接通過SGD作用於原模型上,也就是我們的模型真正用於更新其參數的梯度。
步驟8即對應第二次梯度更新的過程。這里的loss計算方法,大致與步驟5相同,但是不同點有兩處。一處是我們不再是分別利用每個task的loss更新梯度,而是像常見的模型訓練過程一樣,計算一個batch的loss總和,對梯度進行隨機梯度下降SGD。另一處是這里參與計算的樣本,是task中的query set,在我們的例子中,即5-way*15=75個樣本,目的是增強模型在task上的泛化能力,避免過擬合support set。步驟8結束后,模型結束在該batch中的訓練,開始回到步驟3,繼續采樣下一個batch。
以上即時MAML預訓練得到 的全部過程,是不是很簡單呢?事實上,MAML正是因為其簡單的思想與驚人的表現,在元學習領域迅速流行了起來。接下來,應該是面對新的task,在
的基礎上,精調得到
的方法。原文中沒有介紹fine-tune的過程,這里我向小伙伴們簡單介紹一下。
fine-tune的過程與預訓練的過程大致相同,不同的地方主要在於以下幾點:
- 步驟1中,fine-tune不用再隨機初始化參數,而是利用訓練好的
初始化參數。
- 步驟3中,fine-tune只需要抽取一個task進行學習,自然也不用形成batch。fine-tune利用這個task的support set訓練模型,利用query set測試模型。實際操作中,我們會在
上隨機抽取許多個task(e.g., 500個),分別微調模型
,並對最后的測試結果進行平均,從而避免極端情況。
- fine-tune沒有步驟8,因為task的query set是用來測試模型的,標簽對模型是未知的。因此fine-tune過程沒有第二次梯度更新,而是直接利用第一次梯度計算的結果更新參數。