【筆記】MAML-模型無關元學習算法


論文信息:

Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017: 1126-1135.

一、摘要

  • 元學習的目標是在各種學習任務上訓練一個模型,這樣它就可以使用少量的訓練樣本來解決新的學習任務。

  • 本文提出了一種與模型無關的元學習算法,它適用於任何基於梯度下降進行訓練的模型,並且適用於各種學習問題,如分類(Classification)、回歸(Regression)和強化學習(Reinforcement Learning)

  • 在本文提出的方法中,模型的參數被顯式地訓練,模型在處理新任務時,只需幾次的梯度更新以及少量的訓練數據就能取得良好的泛化性能。

  • 該方法在兩種few-shot圖像分類基准(Omniglot和 MiniImagenet)上取得了較好的性能,在few-shot回歸上取得了較好的效果,並利用神經網絡策略加速了策略梯度強化學習的微調。

二、背景

  • 顯式訓練與隱式訓練

    參考顯函數與隱函數的概念:

    • 隱函數:能確定y與x之間關系的方程,F(x,y)=0。x與y混雜在一起。有些隱函數可顯化為顯函數。
    • 顯函數:用y=f(x)表示的函數。x與y明顯區分。
    • 函數是方程,方程不一定是函數。因為函數需要實現一個數域到另一個數域的映射,而方程只要是含有未知數的等式即可。

    這樣模型參數的顯式訓練與隱式訓練就可以理解為因果區分與因果混雜的情況。

    • 隱式訓練:沒有明確的表達式來對目標參數進行更新。
    • 顯式訓練:存在明確的表達式來更新目標參數。
  • 參數方法與非參數方法

    • 參數方法(parametric method):根據先驗知識假定模型服從某種分布,然后利用訓練集估計出模型參數。這種方法中模型的參數固定,不隨數據點的變化而變化。
    • 非參數方法(parametric method):基於記憶訓練集,在預測新樣本值時每次都會重新訓練數據,得到新的參數值。參數的數目隨數據點的變化而變化。
  • Hessian Matrix(海森矩陣)

    • 海塞矩陣(Hessian Matrix),又譯作海森矩陣,是一個多元函數的二階偏導數構成的方陣。

    • 處理一元函數極值問題,如\(f(x)=x^2\) ,我們會先求一階導數,即 \(f^{\prime}(x)=2x\) ,然后根據費馬定理——極值點處的一階導數一定等於 0。但這僅是一個必要條件,而非充分條件。如 \(f(x)=x^3\),顯然只檢查一階導數是不足以下定論的。所以進行二次求導,得出以下規律:

      • 如果一階導數\(f^{\prime}(x)=0\) 且二階導數\(f^{\prime \prime}(x_0)<0\) ,則\(f(x)\) 在此點處取得局部極大值;
      • 如果一階導數\(f^{\prime}(x)=0\) 且二階導數\(f^{\prime \prime}(x_0)>0\) ,則\(f(x)\) 在此點處取得局部極小值;
      • 如果一階導數\(f^{\prime}(x)=0\) 且二階導數\(f^{\prime \prime}(x_0)=0\) ,則無法確定
    • 處理多元函數極值問題,則需要首先對每個變量求偏導,令其為零,定位極值點的可能位置,然后利用二階導數判斷是極大值還是極小值。\(n\) 元函數有 \(n^2\) 個二階導數,因此構成海森矩陣

      \[ \mathbf{H}=\begin{bmatrix} \frac{\partial^2f}{\partial x_1^2} & \frac{\partial^2f}{\partial x_1\partial x_2} & \cdots &\frac{\partial^2f}{\partial x_1\partial x_n} \\ \frac{\partial^2f}{\partial x_2\partial x_1} & \frac{\partial^2f}{\partial x_2^2} & \cdots &\frac{\partial^2f}{\partial x_2\partial x_n} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial^2f}{\partial x_n\partial x_1}&\frac{\partial^2f}{\partial x_n\partial x_2}&\cdots &\frac{\partial^2f}{\partial x_n^2} \end{bmatrix} \]

      • 海森矩陣的極值判斷階段如下:
        • 如果是正定矩陣,則臨界點處是一個局部極小值
        • 如果是負定矩陣,則臨界點處是一個局部極大值
        • 如果是不定矩陣,則臨界點處不是極值
  • 元學習問題引入

    • 元學習過程實際上是一個創造一個高級代理的過程,這個代理在處理新任務的新數據時,能將先驗知識整合進來並且能避免過擬合,即在不同任務之間具備泛化能力。

      高級代理可以理解為創造模型的模型,或者是一組模型參數,它能夠根據不同的任務生成不同的模型參數,這套模型參數能夠在新任務給定的新數據上快速的學習,適應任務的需要。

    • 為了得到具有快速適應能力的模型,元學習訓練一般以Few-Shot Learning(少樣本學習)的形式進行。

      • Few-Shot,可以分為1~k shot,即在訓練過程中提供給模型1~k個樣本數據,讓模型進行學習。

      • 注意與Small Sample Learning(SSL,小樣本學習)進行區分。后者的范圍比前者更加廣泛,具體參見Small Sample Learning in Big Data Era

    • 通過少量的樣本數據構建成一個任務,然后讓元學習模型在許多依此法創建的任務上進行訓練學習,這樣,經過訓練的元學習模型就能憑借少量的數據和幾次的訓練快速適應新的任務了。

      實際上,元學習模型的訓練過程是以一整個一整個的任務作為”訓練數據樣本“的。

  • 元學習問題的公式化表達

    • 概念定義

      • 定義一個模型,用\(f\)表示。模型\(f\)能實現觀察值\(x\)到輸出值\(a\)的映射。

      • 定義單個任務\(T\)

        \[\mathcal{T= \left\{ L(\mathrm{x_1,a_1,\dots,x_H,a_H}),q(\mathrm{x_1}),q(\mathrm{x_{t+1}|x_t,a_t}),\mathrm{H}\right\}} \]

        • \(\mathcal{L}\)表示損失函數,\(\mathcal{L(\mathrm{x_1,a_1,\dots,x_H,a_H})}\rightarrow \mathbb{R}\)
        • \(\mathcal{q}(\mathrm{x_1})\)表示初始觀測變量的分布。
        • \(\mathcal{q(\mathrm{x_{t+1}|x_t,a_t})}\)表示轉移分布
        • \(\mathrm{H}\)表示跨度(Episode Length),對於i.i.d(獨立同分布)監督學習問題,H=1。
      • 期望模型適應的任務的分布\(p(\mathcal{T})\)

    • 學習過程

      • 初始化:隨機初始化元學習模型參數\(\theta\),各子任務模型的初始化參數是對\(\theta\)的拷貝。

      • 元訓練

        1. \(p(\mathcal{T})\)中抽取任務\(\mathcal{T_i}\)
        2. \(\mathcal{q(i)}\)中抽取\(\mathrm{K}\)個樣本;
        3. 用這\(\mathrm{K}\)個樣本對任務\(\mathcal{T_i}\)進行訓練,得到相應的損失\(\mathcal{L_{T_i}}\),並對該任務的模型參數進行梯度更新;
        4. 在新的數據樣本上測試更新后的網絡,得到錯誤情況。
      • 元測試

        1. 根據各個任務更新后的網絡的表現(test error)求初始化參數的梯度,並對元學習模型的參數其進行更新;
        2. 測試其在元測試集任務上的表現,即為元學習模型的最終表現。

三、介紹

  • 本文提出的MAML算法的關鍵思想:訓練模型的初始化參數,使模型能在來自新任務的少量數據上對參數執行數次(1~多次)的梯度更新后能得到最佳的表現。

    • 特征學習的角度理解——MAML算法試圖建立一種模型的內部表示,這種內部表示廣泛適用於許多任務。這樣在處理新的任務時,只需對模型參數進行簡單的微調就能產生較好的結果。

    • 動態系統的角度理解——MAML的學習過程就是要讓新任務的損失函數對於參數的敏感度最大化。當具有較高的敏感度時,參數的微小的局部變化就可以導致任務損失的巨大提升。

      動態系統:若系統在t0時刻的響應y(t0),不僅與t0時刻作用於系統的激勵有關,而且與區間(-∞,t0)內作用於系統的激勵有關,這樣的系統稱為動態系統。

  • 本文的主要貢獻包括以下幾個方面:

    1. 提出了一種元學習的簡單模型以及與任務無關的算法,通過訓練模型參數,使得模型參數只要經過少量次數的梯度更新就能實現在新任務上的快速學習。
    2. 在不同的模型,如全連接和卷積網絡,以及不同領域上,如少樣本回歸、圖片分類和增強學習上驗證了本文提出的算法。
    3. 本文提出的方法通過使用少量參數,能夠與目前最先進的專門用於監督分類的one-shot 學習算法媲美,並且能夠應用於回歸任務和加速任務可變情況下的強化學習過程。

四、實現

  • MAML算法的實現直覺(Intuition)是模型的某些內部表示更容易在不同的任務之間轉換。比如存在某種內部表示能夠適用於任務分布\(\mathcal{p(T)}\)中的所有任務而不是某一個具體的任務。由於最終模型會在新任務上使用基於梯度下降的學習規則進行微調,所以可以以一種顯式的方式去學習一個具備這種規則的模型。

    這種待學習的規則可以理解為一組對任務變化敏感的模型參數,當參數沿着任務的損失梯度方向變化時,可以使得任務損失得到較大的改善。

  • 原理圖如下:

    MAML原理圖

    • \(\theta\) 是已經優化過的模型參數表示。
    • \(\theta\) 沿着新任務損失梯度方向變化時,會使得任務損失大幅改善,從而得到對於新任務的最佳模型參數 \(\theta^{\star}\)
  • 算法描述:

    MAML算法

    • 模型由函數 \(f_{\theta}\) 表示,該函數由參數 \(\theta\) 決定。

    • 整個算法分為兩個循環:

      • 兩者共享模型參數 \(\theta\)
      • 兩者的梯度更新的學習率分別由超參數 \(\alpha\)\(\beta\) 表示
      • 內循環計算各子任務的損失 \(\mathcal{L_{T_{i}}}\) 和進行一至多次梯度更新后的參數 \(\theta^{'}_{i}\)
      • 外循環根據內循環的優化參數在新任務上重新計算損失,並計算其對初始參數的梯度,然后對初始參數進行SGD梯度更新。
      • 重復內外循環,就可以得到元學習模型對於任務分布$ \mathcal{p(T)}$的最佳參數
    • 注意

      • 擁有“最佳參數”的模型在處理新任務時,由於具備了先驗知識,所以只需進行微調就能產生較好的效果。
      • 外循環又稱之為元優化(meta-optimization)
      • 為了適應不同的任務,內循環中的模型參數會演化成 \(\theta^{\prime}\)。而外循環中模型參數需要等到內循環中的所有任務的模型參數都優化后再進行更新。
      • 由於存在一個嵌套關系,外層的梯度更新依賴內層的梯度,因此就會出現二階導數(梯度的梯度)的計算,需要使用到海森向量積(Hessian-Vector Product)
      • 在論文中,作者提出了一種近似算法,利用一階梯度近似代替二階梯度,形成FOMAML(First-Order MAML)算法,具體公式推導過程,見MAML講解-李弘毅
  • 算法擴展:

    • 監督學習(Supervised Learning):算法中的公式(2)和公式(3)分別指代下面的兩個損失函數。

      • 分類(Classification)任務的損失函數采用交叉熵(cross entropy)

        \[\mathcal{L_{T_i}(f_\phi)}=\sum_{x^{(j)},y^{(j)}\sim \mathcal{T_i}}y^{(j)}\log f_{\phi}(x^{(j)})+(1-y^{(j)})\log(1-f_{\phi}(x^{(j)})) \]

      • 回歸(Regression)任務的損失函數采用均方差(mean-squared error)

        \[\mathcal{L_{T_i}(f_\phi)}=\sum_{x^{(j)},y^{(j)}\sim \mathcal{T_i}}\begin{Vmatrix} f_{\phi}(x^{(j)})-y^{(j)}\end{Vmatrix}_2^2 \]

    • 強化學習(Reinforcement Learning):算法中的公式(4)指代下面的損失函數。

      • 強化學習損失函數

        • 強化學習過程基於馬爾可夫決策過程(Markov Decision Porcess)。
        • 具體細節還未深入了解,待補充……

        \[\mathcal{L_{T_i}(f_\phi)}=-\mathbb{E}_\mathcal{x_t,a_t\sim f_\phi,q_{T_i}}[\sum_{t=1}^H R_i(x_t,a_t)] \]

五、實驗

實驗代碼:

  • 回歸(正弦曲線)

    • 通過將MAML算法模型與預訓練模型比較,分別提供K=5和K=10個樣本數據,進行回歸擬合。可以看到:

      • 在沒有提供任何數據點的情況下,MAML由於已經學習到了正弦波的周期結構,所以能夠對曲線進行一定的評估;
      • 對於預訓練模型,由於輸出與已學習到的先驗知識沖突,導致模型無法找到一個合適的表示形式,從而無法通過少量的樣本進行擬合推斷。
    • 比較MAML和預訓練模型的學習曲線可以得出:

      • MAML算法通過少量次數的梯度更新就能實現較低的錯誤率,沒有對少量的數據點過擬合,達到收斂。
      • 預訓練模型則由於缺乏泛化能力,對與少量數據點,很容易過擬合。
  • 分類

    • 通過將MAML模型以及簡化后的FOMAML模型與用於Few-Shot Learning 分類的主流模型在Omiglot和MiniImagenet數據集上比較,可以發現:

      • MAML無視數據集差異、數據點多少以及網絡結構差異,都有優異的表現。

      • FOMAML模型的表現與MAML的表現非常接近,但是兩者的計算消耗卻不同,FOMAML的計算復雜度要明顯低於MAML,這一點也是值得進一步研究的問題。

        對此,作者推測在大多數情況下,損失函數的二階導數非常接近零,因而對模型表現沒有產生太大的影響。

        On First-Order Meta-Learning Algorithms一文中,作者用泰勒公式,對導數進行了展開分析,揭露了深層次的原因。

  • 強化學習

六、總結

  • 提出了一種不引入任何學習參數(實際上增加了學習率\(\alpha 和 \beta\))的通過梯度下降學習模型參數的元學習方法。
  • MAML可以以與任何適合於基於梯度的訓練的模型表示,以及任何可微分的目標(包括分類、回歸和強化學習)相結合。
  • MAML只產生一個權值初始化,所以可以使用任意數量的數據和任意數量的梯度步長來執行自適應。
  • MAML可以使用策略梯度和非常有限的經驗來適應RL代理。
  • 重用來自過去任務的知識可能是構建高容量可伸縮模型(如深度神經網絡)的一個關鍵因素,該模型能夠使用小數據集進行快速訓練。
  • 這種元學習技術可以應用於任何問題和任何模型,可以使多任務初始化成為深度學習和強化學習的標准組成部分。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM