【筆記】元學習專題視頻(台大·李宏毅)學習記錄


一、概述

  • Meta Learning = Learn to learn

    讓機器去學習如何進行學習:使用一系列的任務來訓練模型,模型根據在這些任務上汲取的經驗,成為了一個強大的學習者,能夠更快的學習新任務。

  • Meta Learning VS Lifelong Learning

    • 終身學習:着眼於用同一個模型去學習不同的任務。
    • 元學習:不同任務使用不同的模型,元學習者積累經驗后,在新任務上訓練的更快更好。
  • Meta Learning VS Machine Learning

    • 機器學習:核心是通過人為設計的學習算法(Learning Algorithm),利用訓練數據訓練得到一個函數f,這個函數可以用於新數據的預測分類。

    • 元學習:讓機器自己學習找出最優的學習算法。根據提供的訓練數據找到一個可以找到函數f的函數F的能力。

二、元學習的實現框架

  1. 定義一系列的學習算法

    不同的網絡結構、參數初始化策略、參數更新策略決定了不同學習算法。

  2. 定義學習算法函數F的評價標准

    綜合考慮學習算法F針對不同任務產生的函數f在進行測試時得到的損失。

  3. 選取最好的學習算法F*=argminL(F)

    最佳學習算法一般可以通過梯度下降方法來確定。

三、元學習的訓練數據

  • 機器學習

    機器學習的訓練數據和測試數據來自同一分布的數據集。

  • 元學習:

    • 元學習的訓練數據是由一個個的訓練任務構成的,一個訓練任務對應一個傳統的機器學習的應用實例。

      • 需要大批次數據的訓練任務顯然難以進行元學習訓練,因此常規的元學習的訓練任務一般是Few Shot Learning類型的任務,即通過少量數據就能構建一個任務,進行快速的學習與訓練。
      • 考慮到運算性能,現階段的元學習經常是與Few Shot Learning綁定在一起。
    • 訓練數據分為訓練任務集和測試任務集。

    • 任務集中的每一個任務的訓練數據即傳統的機器學習應用實例中的訓練數據集和測試數據集,不過為了區分訓練任務(Training Set)和測試任務(Testing Test) ,這里將它們命名為支持集(Support Set)和查詢集(Query Set).

四、元學習的Benchmarks

  • Omniglot數據集

    • 組成

      • 整個數據集由1623個符號(Characters)組成;

      • 每個符號有20個樣例(Examples),每個樣例由不同的人書寫.

    • 使用:結合Few-shot Learning中的N-ways K-shot分類問題

      • 對於每一個訓練任務和測試任務,樣本數據分為N個類,每個類提供K個樣本。
      • 整個字符集分為訓練字符集(Training Set or Support Set)和測試字符集(Testing Set or Query Set)
      • 訓練任務:從訓練字符集中抽取N個類的字符,每種字符抽取K個樣本,組成一個訓練任務的訓練數據
      • 測試任務:從測試字符集中抽取N個類的字符,每種字符抽取K個樣本,組成一個測試任務的訓練數據
  • MAML

    Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017: 1126-1135.

    • 損失函數 (Loss Function):\(L(\Phi)= \sum_{n=1}^N{l^n(\hat{\theta}^n)}\)

      • \(\hat{\theta}^n\):第n個任務中學習到的模型參數,取決於參數\(\Phi\)
      • \(l^n(\hat{\theta}^n)\):第n個任務在其測試集上得到的損失。
    • 損失函數最小化:使用梯度下降(Gradient Descent)

      \[\Phi\leftarrow\Phi-\eta\nabla_\Phi{L(\Phi)} \]

    • 只考慮一次訓練之后對初始化參數的梯度更新。

      • 只取進行一次梯度更新后的參數作為當前任務的最佳參數。
      • 上式求出的是元學習模型的通用參數,下式求出的是每個任務的最佳參數。
      • \(L(\Phi)\)\(\hat{\theta}\)用於元學習模型的參數更新。
      • 既能加快模型的適應速度,在一定程度上還能減輕過擬合。、

      \[\hat{\theta}=\Phi-\epsilon\nabla_\Phi{l(\Phi)} \]

    • 整體執行流程:

      1. 將每一個訓練任務和測試任務的模型參數初始化:\(\Phi_0\)
      2. 對每一個任務執行一次梯度更新得到新的模型參數:\(\hat{\theta}\)
      3. 綜合考慮所有訓練任務在\(\hat{\theta}\)下的損失:\(L(\Phi)\)
      4. \(L(\Phi)\)執行梯度更新,得到最優的元學習模型的參數:\(\Phi\)
      5. 將該\(\Phi\)用於測試任務,檢驗更新效果。
    • 二階微分與一階近似(數學推導):

      • 訓練過程的參數更新公式如下:

      \[\Phi\leftarrow\Phi-\eta\nabla_\Phi{L(\Phi)} \\ L(\Phi)= \sum_{n=1}^N{l^n(\hat{\theta}^n)} \\ \hat{\theta}=\Phi-\epsilon\nabla_\Phi{l(\Phi)} \]

      • $ \nabla_\Phi{L(\Phi)} $的計算

        \[\nabla_\Phi{L(\Phi)}=\nabla_\Phi{\sum_{n=1}^{N}l^n(\hat{\theta}^n)}=\sum_{n=1}^{N}\nabla_\Phi{l^n(\hat{\theta}^n)} \\ \]

        • 其中\(\nabla_\Phi{l^n(\hat{\theta}^n)}\)為:

          \[\nabla_\Phi{l(\hat{\theta})}=\left| \begin{matrix} \partial l(\hat{\theta})/\partial \Phi_1\\ \partial l(\hat{\theta})/\partial \Phi_2\\ \vdots\\ \partial l(\hat{\theta})/\partial \Phi_i\\ \end{matrix} \right| \]

          \(\Phi_i\)表示模型的各個參數(Weight),\(\Phi_i\)決定當前任務的\(\hat{\theta}\)的第j個參數\(\hat{\theta}_j\),從而影響\(l(\hat{\theta})\)

          • 根據三者之間的關系:\(\Phi_i \rightarrow \hat{\theta}_j \rightarrow l(\hat{\theta})\),有:

            \[\frac{\partial l(\hat{\theta})}{\partial \Phi_i}=\sum_j\sum_i {\frac{\partial l(\hat{\theta})}{\partial \hat{\theta}_j}\frac{\partial \hat{\theta}_j}{\partial \hat{\Phi}_i}} \]

          • 又因為根據參數更新公式(3),取\(\hat{theta}\)的第j維為例,有:

            \[\hat{\theta}_j=\Phi_j-\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j} \]

          • \(\hat{\theta}_j\)\(\Phi_j\)的偏導,有:

            \[\frac{\partial \hat{\theta}_j}{\partial \hat{\Phi}_i}= \begin{cases} -\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j \partial \Phi_i},i \neq j\\ 1-\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j \partial \Phi_i},i = j \end{cases} \]

            將該式代回到\(\frac{\partial l(\hat{\theta})}{\partial \Phi_i}\)中即可求出\(\nabla_\Phi{L(\Phi)}\)

            但實際上該式存在二次微分的計算,會極大的影響運算效率。

          • 作者用一次微分來近似代替二次微分的結果:

            \[\frac{\partial \hat{\theta}_j}{\partial \hat{\Phi}_i}= \begin{cases} -\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j \partial \Phi_i} \approx{0} ,i \neq j\\ 1-\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j \partial \Phi_i} \approx{1},i = j \end{cases} \]

          • 所以

            \[\frac{\partial l(\hat{\theta})}{\partial \Phi_i}=\sum_j\sum_i {\frac{\partial l(\hat{\theta})}{\partial \hat{\theta}_j}\frac{\partial \hat{\theta}_j}{\partial \hat{\Phi}_i}} \approx \frac{\partial l(\hat{\theta})}{\partial \hat{\theta}_i}\\ \nabla_\Phi{l(\hat{\theta})}=\left| \begin{matrix} \partial l(\hat{\theta})/\partial \Phi_1\\ \partial l(\hat{\theta})/\partial \Phi_2\\ \vdots\\ \partial l(\hat{\theta})/\partial \Phi_i\\ \end{matrix} \right|=\left| \begin{matrix} \partial l(\hat{\theta})/\partial \hat{\theta}_1\\ \partial l(\hat{\theta})/\partial \hat{\theta}_2\\ \vdots\\ \partial l(\hat{\theta})/\partial \hat{\theta}_i\\ \end{matrix} \right|=\nabla_\hat{\theta}{l(\hat{\theta})} \]

      • 所以$ \nabla_\Phi{L(\Phi)} $可以化為:

        \[\nabla_\Phi{L(\Phi)}=\nabla_\Phi{\sum_{n=1}^{N}l^n(\hat{\theta}^n)}=\sum_{n=1}^{N}\nabla_\Phi{l^n(\hat{\theta}^n)}=\sum_{n=1}^{N}\nabla_{\hat{\theta}^n}{l^n(\hat{\theta}^n)} \]

        通過將二階微分近似為一階微分,提升運算效率的同時對模型預測的准確率沒有太大的影響。

  • Reptile

    Nichol A, Achiam J, Schulman J. On first-order meta-learning algorithms[J]. arXiv preprint arXiv:1803.02999, 2018.

    • 基本思想

      • 基於MAML進行改善,對參數更新次數不加限制。

    • Reptile VS Pretraining VS MAML


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM