論文筆記:Adaptive Consistency Regularization for Semi-Supervised Transfer Learning
Paper: Adaptive Consistency Regularization for Semi-Supervised Transfer Learning, CVPR 2021
Code: https://github.com/SHI-Labs/Semi-Supervised-Transfer-Learning
Try to accomplish
以往的半監督學習所用的模型大多都是隨機初始化的,因為如果使用了預訓練模型,半監督就很難帶來很多性能提升,然而那些只是簡單地將兩種學習方法結合,這篇論文里提出了一個新的將遷移學習和半監督學習結合的模型,在保留了預訓練模型的知識的基礎上,也讓半監督學習發揮出相應的性能提升。
學習目標
整個模型的結構圖如下:
主要就是2條網絡分支以及4個子模塊,2個網絡分支分別為 source 模型和 target 模型,source 模型就是預訓練模型,在后續學習中不會變化,而 target 模型就是要學習的模型。4個子模塊即圖上4個框起來的1個 loss 和3個 consistency,接下來一個個進行解讀。
首先明確模型的目標是什么,半監督學習希望能夠將大量的無監督數據能夠幫助有監督學習的過程,顯然一個基本的要求就是有監督的部分首先能夠學好,無監督數據作為正則項來對模型的偏好進行調整,所以整體的目標如下所示:
這里把模型分為了兩部分,前半部分是特征提取器 \(F\),模型參數為 \(\theta\),后半部分是任務相關模型 \(G\),比如分類問題就可以是一層全連接層,參數為 \(\phi\)。
模型目標整體上也分為兩部分,一個是有監督部分,這里以分類問題為例,可以是最小化有監督數據的 cross entropy,后半部分是一個關於特征提取器的正則項,無監督數據主要就是對這個正則項貢獻。
接下來會主要對圖上的4個模塊進行說明。
Supervised Loss
有監督部分其實沒什么太多說的內容,就是簡單地將 batch 中有標簽的數據計算下交叉熵(這篇論文主要針對的是分類問題)。需要說明的一點是關於模型的初始化,因為本篇論文主要的一個想法就是想將遷移學習和半監督學習結合起來,所以模型的初始化也沿用遷移學習的思路,特征提取器 \(F\) 直接就用預訓練模型的特征提取器的參數。
關於任務相關模型 \(G\) 的參數並非就是隨機初始化了事,而是采用了一個 Imprint 的做法,這篇論文我沒有看,大概看了下代碼,看起來是將有監督數據全部過一下預訓練模型,然后將同一類別的輸出取了一個平均,有點類似於“類別中心”的概念,然后將這個“類別中心”賦給 \(G\) 作為其初始化參數。
除此之外還需要說明的一點是,模型圖中上面的那條分支(source feature extractor 和 source classifier)存的就是預訓練模型參數,這條分支的參數是一直不更新的,只有下面的 Target 模型的參數會學習。
Adaptive Knowledge Consistency (AKC)
這是本文設計的第一個正則項,其思想為,源模型的F是具有泛性的(遷移學習能work的原因),為了保留這個泛性,目標模型的F提取的特征應該與源模型的特征要相同或盡可能相近,這里可以利用包括有監督無監督數據一塊丟進來訓練,有點類似於知識蒸餾。
簡單點說,就是讓 source 模型和 target 模型的特征提取器提取的特征要一致。 就像模型圖上的,AKC模塊的輸入是兩個模型提取出來的 feature。
但如果要求提取完全一致直接就讓 target 模型的 \(F\) 的參數等於 source 模型的不就好了嗎?其實並非要完全一致,因為預訓練模型提取的特征未必就完全適用於當前這個任務,那么能否給出一個標准,什么樣的特征才是適用於當前這個任務的?
這里給出了一個方案,就是當前任務的數據(這里不需要用到標簽,無監督數據也可以參與進來),經過了 source 模型后,其輸出的數據分布 logits 區分度很高,即某個類別的概率遠高於其他類別的概率,我們就可以認為當前這個特征提取的比較好。
怎么量化這個區分度,這里就用 logits 的熵,熵越高就越接近均勻分布,這個樣本的置信度不高,所以就不要要求 target 模型的特征和 source 一致,反之,如果熵小,就說明區分度可能比較高,那么就讓 target 模型和 source 模型的特征一致。故而可以設計一個熵門控函數。AKC的正則項如下:
這里用帶權 KL 散度來計算AKC正則項,從公式中可以看到 \(x\) 既可以是有監督數據也可以是無監督數據,權重 \(w\) 是一個關於熵的門控函數,如下:
只有當 source 模型預測結果的熵小於某個閾值時,\(w\) 才會為1。這就是模型圖中 AKC 模塊的3條輸入來源。
Adaptive Representation Consistency (ARC)
接下來是第二個正則項,ARC的出發點是無監督數據包含了數據結構信息,應該利用上這部分信息去指導有監督數據訓練,即盡可能地拉近無監督數據和有監督數據的數據分布,首先現引入描述兩個分布的一個經典度量 Maximum Mean Discrepancies (MMD):
即給定兩個數據集,計算出兩個數據集中的數據對應的數據分布之間的差異性,本篇論文中就是要盡可能最小化有監督數據的特征集和無監督數據的特征集的 MMD。
但存在一個問題,早期由於模型還沒訓練好,所以無標簽數據的分布是不准確的。這時候也不應該要求兩個集合保持一致,否則可能就影響了有監督數據的訓練。
這里采用與AKC類似的思路,只對置信度高的樣本,即 target 模型如果對數據預測結果的熵夠低,我們才認為這個特征值得參與進ARC的正則項計算,才進行計算MMD,這里的樣本即包含有監督也包含無監督。
所以開辟兩個高置信度的有監督數據特征集以及無監督數據特征集:
考慮到Mini-batch中的樣本量不足以確定一個分布,這里開辟了有監督緩沖區和無監督緩沖區來存放最近選擇的高置信度樣本。每次都會將batch中計算的高置信度有監督和無監督樣本格子添加到相應的緩沖區中,然后從緩沖區中抽取最新的 k 個構成有監督特征集以及無監督特征集,進行計算 MMD 得到 ARC 正則項:
Semi-Supervised Consistency
模型圖中的最后一個模塊,這里直接就是應用了別人提出的各種半監督的一致性損失,例如 MixMatch、FixMatch、Pseudo-labeling 等等,因為本論文提出的 AKC 和 ARC 只是兩個正則項,是可以和別的半監督方法一起使用的,故而最終的學習目標(損失函數)為:
Results
以下是在3個分類數據集上和不同半監督方法的對比:
可以看到兩個正則項和一些經典的半監督方法比都能旗鼓相當,而將其與這些半監督方法一塊使用時,效果更是提升了不少,證明這兩個正則項還是有一定效果的。