【說在前面】本人博客新手一枚,象牙塔的老白,職業場的小白。以下內容僅為個人見解,歡迎批評指正,不喜勿噴![握手][握手]
【再啰嗦一下】本文銜接上兩個隨筆:人工智能中小樣本問題相關的系列模型演變及學習筆記(一):元學習、小樣本學習
【再啰嗦一下】本文銜接上兩個隨筆:人工智能中小樣本問題相關的系列模型演變及學習筆記(二):生成對抗網絡 GAN
【再啰嗦一下】本文銜接上兩個隨筆:人工智能中小樣本問題相關的系列模型演變及學習筆記(三):遷移學習
一、知識蒸餾綜述
知識蒸餾被廣泛的用於模型壓縮和遷移學習當中。
本文主要參考:模型壓縮中知識蒸餾技術原理及其發展現狀和展望
1. 基本概念
知識蒸餾可以將一個網絡的知識轉移到另一個網絡,兩個網絡可以是同構或者異構。做法是先訓練一個teacher網絡,然后使用這個teacher網絡的輸出和數據的真實標簽去訓練student網絡。
- 可以用來將網絡從大網絡轉化成一個小網絡,並保留接近於大網絡的性能。
- 可以將多個網絡的學到的知識轉移到一個網絡中,使得單個網絡的性能接近emsemble的結果。
2. 知識蒸餾的主要算法
知識蒸餾是對模型的能力進行遷移,根據遷移的方法不同可以簡單分為基於目標驅動的算法、基於特征匹配的算法兩個大的方向。
2.1 知識蒸餾基本框架
Hinton最早在文章“Distilling the knowledge in a neural network”中提出了知識蒸餾的概念,即knowledge distilling,對后續的許多算法都產生了影響,其框架示意圖如下:
從上圖中可以看出,包括一個teacher model和一個student model,teacher model需要預先訓練好,使用的就是標准分類softmax損失,但是它的輸出使用帶溫度參數T的softmax函數進行映射,如下:
當T=1時,就是softmax本身。當T>1,稱之為soft softmax,T越大,因為輸入 zk 產生的概率 f(zk) 差異就會越小。之所以要這么做,其背后的思想是:當訓練好一個模型之后,模型為所有的誤標簽都分配了很小的概率。然而實際上對於不同的錯誤標簽,其被分配的概率仍然可能存在數個量級的懸殊差距。這個差距,在softmax中直接就被忽略了,但這其實是一部分有用的信息。
訓練的時候小模型有兩個損失:一個是與真實標簽的softmax損失,一個是與teacher model的蒸餾損失,定義為KL散度。
當teacher model和student model各自的預測概率為pi,qi時,其蒸餾損失部分梯度傳播如下:
可以看出形式非常的簡單,梯度為兩者預測概率之差,這就是最簡單的知識蒸餾框架。
2.2 優化目標驅動的知識蒸餾框架
Hinton等人提出的框架是在模型最后的預測端,讓student模型學習到與teacher模型的知識,這可以稱之為直接使用優化目標進行驅動的框架,類似的還有ProjectionNet。
PrjojectNet同時訓練一個大模型和一個小模型,兩者的輸入都是樣本,其中大模型就是普通的CNN網絡,而小模型會對輸入首先進行特征投影。每一個投影矩陣P都對應了一個映射,由一個d-bit長的向量表示,其中每一個bit為0或者1,這是一個更加稀疏的表達。特征用這種方法簡化后自然就可以使用更加輕量的網絡的結構進行訓練。那么怎么完成這個過程呢?文中使用的是locality sensitive hashing(LSH)算法,這是一種聚類任務中常用的降維的算法。
優化目標包含了3部分,分別是大模型的損失,投影損失,以及大模型和小模型的預測損失,全部使用交叉熵,各自定義如下:
基於優化目標驅動的方法其思想是非常直觀,就是結果導向型,中間怎么實現的不關心,對它進行改進的一個有趣方向是GAN的運用。
2.3 特征匹配的知識蒸餾框架
結果導向型的知識蒸餾框架的具體細節是難以控制的,會讓訓練變得不穩定且緩慢。一種更直觀的方式是將teacher模型和student模型的特征進行約束,從而保證student模型確實繼承了teacher模型的知識,其中一個典型代表就是FitNets,FitNets將比較淺而寬的Teacher模型的知識遷移到更窄更深的Student模型上,框架如下:
FitNets背后的思想是,用網絡的中間層的特征進行匹配,不僅僅是在輸出端。它的訓練包含了兩個階段:
(1)第一階段就是根據Teacher模型的損失來指導預訓練Student模型。記Teacher網絡的某一中間層的權值Wt為Whint,意為指導的意思。Student網絡的某一中間層的權值Ws為Wguided,即被指導的意思,在訓練之初Student網絡進行隨機初始化。需要學習一個映射函數Wr使得Wguided的維度匹配Whint,得到Ws',並最小化兩者網絡輸出的MSE差異作為損失,如下:
(2)第二個訓練階段,就是對整個網絡進行知識蒸餾訓練,與上述Hinton等人提出的策略一致。不過FitNet直接將特征值進行了匹配,先驗約束太強,有的框架對激活值進行了歸一化。
基於特征空間進行匹配的方法其實是知識蒸餾的主流,類似的方法非常多,包括注意力機制的使用、類似於風格遷移算法的特征匹配等。
3. 知識蒸餾算法的展望
知識蒸餾還有非常多有意思的研究方向,這里我們介紹其中幾個。
3.1 不壓縮模型
機器學習模型要解決的問題如下,其中y是預測值,x是輸入,L是優化目標,θ1是優化參數。
因為深度學習模型沒有解析解,往往無法得到最優解,我們經常會通過添加一些正則項來促使模型達到更好的性能。
Born Again Neural Networks框架思想是通過增加同樣的模型架構,並且重新進行優化,以增加一個模型為例,要解決的問題如下:
具體的流程就是:
(1)訓練一個教師模型使其收斂到較好的局部值。
(2)對與教師模型結構相同的學生模型進行初始化,其優化目標包含兩部分,一部分是要匹配教師模型的輸出分布,比如采用KL散度。另一部分就是與教師模型訓練時同樣的目標,即數據集的預測真值。
然后通過下面這樣的流程,一步一步往下傳,所以被形象地命名為“born again”。
類似的框架還有Net2Net,network morphism等。
3.2 去掉 teacher 模型
一般知識蒸餾框架都需要包括一個Teacher模型和一個Student模型,而Deep mutual learning則沒有Teacher模型,它通過多個小模型進行協同訓練,框架示意圖如下。
Deep mutual learning在訓練的過程中讓兩個學生網絡相互學習,每一個網絡都有兩個損失。一個是任務本身的損失,另外一個就是KL散度。由於KL散度是非對稱的,所以兩個網絡的散度會不同。
相比單獨訓練,每一個模型可以取得更高的精度。值得注意的是,就算是兩個結構完全一樣的模型,也會學習到不同的特征表達。
3.3 與其他框架的結合
在進行知識蒸餾時,我們通常假設teacher模型有更好的性能,而student模型是一個壓縮版的模型,這不就是模型壓縮嗎?與模型剪枝,量化前后的模型對比是一樣的。所以知識蒸餾也被用於與相關技術進行結合,apprentice框架是一個代表。
網絡結構如上圖所示,Teacher模型是一個全精度模型,Apprentice模型是一個低精度模型。
4. 知識蒸餾在智能推薦中的應用
如果您對智能推薦感興趣,歡迎瀏覽我的另一篇博客:智能推薦算法演變及學習筆記 、CTR預估模型演變及學習筆記
本文主要參考:知識蒸餾在推薦系統中的應用
1. 基本概念
深度學習模型正在變得越來越復雜,網絡深度越來越深,模型參數量也在變得越來越多。而這會帶來一個現實應用的問題:將這種復雜模型推上線,模型響應速度太慢,當流量大的時候撐不住。
知識蒸餾就是目前一種比較流行的解決此類問題的技術方向。復雜笨重但是效果好的 Teacher 模型不上線,就單純是個導師角色,真正上戰場擋搶撐流量的是靈活輕巧的 Student 小模型。
在智能推薦中已經提到,一般有三個級聯的過程:召回、粗排和精排。
- 召回環節從海量物品庫里快速篩選部分用戶可能感興趣的物品,傳給粗排模塊。
- 粗排環節通常采取使用少量特征的簡單排序模型,對召回物料進行初步排序,並做截斷,進一步將物品集合縮小到合理數量,向后傳遞給精排模塊。
- 精排環節采用利用較多特征的復雜模型,對少量物品進行精准排序。
以上環節都可以采用知識蒸餾技術來優化性能和效果,這里的性能指的線上服務響應速度快,效果指的推薦質量好。
2. 精排環節采用知識蒸餾
精排環節注重精准排序,所以采用盡量多特征復雜模型,以期待獲得優質的個性化推薦結果。這也意味着復雜模型的在線服務響應變慢。
(1)在離線訓練的時候,可以訓練一個復雜精排模型作為 Teacher,一個結構較簡單的 DNN 排序模型作為 Student。
- 因為 Student 結構簡單,所以模型表達能力弱,於是,我們可以在 Student 訓練的時候,除了采用常規的 Ground Truth 訓練數據外,Teacher 也輔助 Student 的訓練,將 Teacher 復雜模型學到的一些知識遷移給 Student,增強其模型表達能力,以此加強其推薦效果。
(2)在模型上線服務的時候,並不用那個大 Teacher,而是使用小的 Student 作為線上服務精排模型,進行在線推理。
- 因為 Student 結構較為簡單,所以在線推理速度會大大快於復雜模型。
3. 精排環節蒸餾方法
(1)阿里媽媽在論文 "Rocket Launching: A Universal and Efficient Framework for Training Well-performing Light Net" 提出。
在精排環節采用知識蒸餾,主要采用 Teacher 和 Student 聯合訓練 ( Joint Learning ) 的方法。所謂聯合訓練,指的是在離線訓練 Student 模型的時候,增加復雜 Teacher 模型來輔助 Student,兩者同時進行訓練,是一種訓練過程中的輔導。
從網絡結構來說,Teacher 和 Student 模型共享底層特征 Embedding 層,Teacher 網絡具有層深更深、神經元更多的 MLP 隱層,而 Student 則由較少層深及神經元個數的 MLP 隱層構成,兩者的 MLP 部分參數各自私有。
(2)愛奇藝在排序階段提出了雙 DNN 排序模型,可以看作是在阿里的 rocket launching 模型基礎上的進一步改進。
為了進一步增強 student 的泛化能力,要求 student 的隱層 MLP 的激活也要學習 Teacher 對應隱層的響應,這點同樣可以通過在 student 的損失函數中加子項來實現。但是這會帶來一個問題,就是在 MLP 隱層復雜度方面,Student 和 Teacher 是相當的。那么,Teacher 相比 student,模型復雜在哪里呢?
這引出了第二點不同:雙 DNN 排序模型的 Teacher 在特征 Embedding 層和 MLP 層之間,可以比較靈活加入各種不同方法的特征組合功能。通過這種方式,體現 Teacher 模型的較強的模型表達和泛化能力。
4. 召回 / 粗排環節采用知識蒸餾
召回或者粗排環節,作為精排的前置環節,需要在准確性和速度方面找到一個平衡點,在保證一定推薦精准性的前提下,對物品進行粗篩,減小精排環節壓力。這兩個環節並不追求最高的推薦精度。畢竟在這兩個環節,如果准確性不足可以靠返回物品數量多來彌補。而模型小,速度快則是模型召回及粗排的重要目標之一。
- 用復雜的精排模型作為 Teacher,召回或粗排模型作為小的 Student,比如 FM 或者雙塔 DNN 模型等。
- 通過 Student 模型模擬精排模型的排序結果,可以使得前置兩個環節的優化目標和推薦任務的最終優化目標保持一致。
5. 召回/粗排環節蒸餾方法
作者給出了一些可能的處理方式,目前業內還沒定論。
(1)設想一:召回蒸餾的兩階段方法
(2)設想二:logits方法
(3)設想三:Without-Logits 方案
(4)設想四:Point Wise 蒸餾:Point Wise Loss 將學習問題簡化為單 Item 打分問題。
(5)設想五:Pair Wise 蒸餾:Pair Wise Loss 對能夠保持序關系的訓練數據對建模。
(6)設想六:List Wise 蒸餾:List Wise Loss 則對整個排序列表順序關系建模。
(7)設想七:聯合訓練召回、粗排及精排模型的設想
二、增量學習:補充介紹
主要關注的是災難性遺忘,平衡新知識與舊知識之間的關系。即如何在學習新知識的情況下不忘記舊知識。
引用Robipolikar對增量學習算法的定義,即一個增量學習算法應同時具有以下特點:
- 可以從新數據中學習新知識
- 以前已經處理過的數據不需要重復處理
- 每次只有一個訓練觀測樣本被看到和學習
- 學習新知識的同時能保持以前學習到的大部分知識
- 一旦學習完成后訓練觀測樣本被丟棄
- 學習系統沒有關於整個訓練樣本的先驗知識
在概念上,增量學習與遷移學習最大的區別就是對待舊知識的處理:
- 增量學習在學習新知識的同時需要盡可能保持舊知識,不管它們類別相關還是不相關的。
- 遷移學習只是借助舊知識來學習新知識,學習完成后只關注在新知識上的性能,不再考慮在舊知識上的性能。
關於這部分內容,未來有看到好的資料,再來分享。
如果您對異常檢測感興趣,歡迎瀏覽我的另一篇博客:異常檢測算法演變及學習筆記
如果您對智能推薦感興趣,歡迎瀏覽我的另一篇博客:智能推薦算法演變及學習筆記 、CTR預估模型演變及學習筆記
如果您對知識圖譜感興趣,歡迎瀏覽我的另一篇博客:行業知識圖譜的構建及應用、基於圖模型的智能推薦算法學習筆記
如果您對時間序列分析感興趣,歡迎瀏覽我的另一篇博客:時間序列分析中預測類問題下的建模方案 、深度學習中的序列模型演變及學習筆記
如果您對數據挖掘感興趣,歡迎瀏覽我的另一篇博客:數據挖掘比賽/項目全流程介紹 、機器學習中的聚類算法演變及學習筆記
如果您對人工智能算法感興趣,歡迎瀏覽我的另一篇博客:人工智能新手入門學習路線和學習資源合集(含AI綜述/python/機器學習/深度學習/tensorflow)、人工智能領域常用的開源框架和庫(含機器學習/深度學習/強化學習/知識圖譜/圖神經網絡)
如果你是計算機專業的應屆畢業生,歡迎瀏覽我的另外一篇博客:如果你是一個計算機領域的應屆生,你如何准備求職面試?
如果你是計算機專業的本科生,歡迎瀏覽我的另外一篇博客:如果你是一個計算機領域的本科生,你可以選擇學習什么?
如果你是計算機專業的研究生,歡迎瀏覽我的另外一篇博客:如果你是一個計算機領域的研究生,你可以選擇學習什么?
如果你對金融科技感興趣,歡迎瀏覽我的另一篇博客:如果你想了解金融科技,不妨先了解金融科技有哪些可能?
之后博主將持續分享各大算法的學習思路和學習筆記:hello world: 我的博客寫作思路