論文:《DPGN: Distribution Propagation Graph Network for Few-shot Learning》,CVPR2020
代碼:https://github.com/megvii-research/DPGN
一、概述
在給定少量標注數據(support集)的情況下,Few-shot learning旨在對未標注數據(query 集)進行預測。
有很多方法可以用於Few-shot learning任務,比如:
- 微調(Fine-tuning)方法,但容易過擬合
- 元學習(Meta-Learning)方法,但通常隱式利用樣本全局關系
- 圖網絡(Graph Networks)方法,但只考慮了樣本對關系,忽略了重要的分布關系
如上圖所示,該論文提出了DPGN(Distribution Propagation Graph Network)模型,通過未標注數據和已標注數據之間的相似度分布,引導標簽信息在圖中更好地傳播。該模型包含點圖(Point Graph, PG)和分布圖(Distribution Graph, DP)兩個完全圖,分別用於建模每個樣本的實例級別表示和分布級別表示。具體的含義可以看方法部分。
總的來說,論文的創新點有三點:
- DPGN是第一個顯式利用分布進行標簽傳播的圖網絡Few-shot learning方法。
- 提出了雙完全圖架構,結合了實例級別和分布級別的關系。
- 在四個Few-shot learning數據集上進行了實驗,在分類任務上提升了5%12%的性能,並在半監督任務中提升了7%13%的性能。
二、方法
首先介紹Few-shot learning的問題定義,然后詳細介紹DPGN模型的細節。
1 問題定義
每個Few-shot learning任務都有一個support集\(\mathcal{S}\)和一個query集\(\mathcal{Q}\),二者都屬於訓練集\(\mathbb{D}^{train}\)。\(\mathcal{S}=\{(x_1,y_1),\dots,(x_{N\times K},y_{N\times K})\}\)含有\(N\)個類別,每個類別有\(K\)個樣本(也就是\(N\)-way \(K\)-shot),\(\mathcal{Q}=\{(x_{N\times K + 1},y_{N\times K + 1}) \dots, (x_{N\times K + \bar{T}},y_{N\times K + \bar{T}})\}\)含有\(\bar{T}\)個樣本。在訓練階段,support集和query集的標簽都是已知的。在測試階段,模型需要根據測試集中的support集預測測試集中query集的標簽。
2 DPGN
上圖展示了DPGN模型的主要過程,該模型包含\(l\)層,每層包含一個點圖(PG) \(G_l^p=(V_l^p, E_l^p)\)和一個分布圖(DG) \(G_l^d=(V_l^d, E_l^d)\)。每一層的表示計算順序基本構成一個環,即\(E_l^p \rightarrow V_l^d \rightarrow E_l^d \rightarrow V_l^p \rightarrow E_{l+1}^p\)。
為了進一步說明,節點集合\(V_l^p, V_l^d\)分別表示為\(V_l^p=\{v_{l,i}^p\}\),\(V_l^d=\{v_{l,i}^d\}\),邊集合\(E_l^p, E_l^d\)分別表示為\(E_l^p=\{e_{l,ij}^p\}\),\(E_l^d=\{e_{l,ij}^d\}\),其中\(i,j=1,\cdots,T\),\(T=N\times K + \bar{T}\)。
\(v_{0,i}^p\)被初始化為特征提取器的輸出:
2.1 點到分布聚合
2.1.1 點相似度
PG中的每條邊都表示實例(點)之間的相似度,也就是樣本之間的相似度。
當\(l=0\)時,PG的邊定義為:
其中\(f_{e_0^p}:\mathbb{R}^m \rightarrow \mathbb{R}\)用於將向量映射為標量,論文使用兩個Conv-BN-ReLU塊實現。
當\(l \gt 0\)時,PG的邊更新規則如下:
在實際應用中還要對\(e_{l,ij}^p\)進行歸一化。
2.1.2 P2D聚合
生成了PG中的邊后,下一步就是生成DG中的節點表示。方法如上圖所示,DG中每個節點都是維度為\(N\times K\)的特征向量,其中第\(j\)維表示該實例\(x_i\)與實例\(x_j\)的關系,\(N\times K\)就是support集大小。
當\(l=0\)時,DP的節點定義為:
其中\(||\)表示連接操作,\(\delta\)輸出0或1表示標簽\(y_i\)和\(y_j\)是否相等。
當\(l \gt 0\)時,DG的節點更新規則如下:
其中,\(P2D: (\mathbb{R}^{NK}, \mathbb{R}^{NK}) \rightarrow \mathbb{R}^{NK}\)是聚合網絡,論文使用全連接層加ReLU層實現。
2.2 分布到點聚合
2.2.1 分布相似度
DG中每條邊表示實例分布特征的相似度,也就是樣本在分布空間的相似度。
當\(l=0\)時,DG的邊定義為:
其中,\(f_{e_0^d}: \mathbb{R}^{NK} \rightarrow \mathbb{R}\)用於將向量映射為標量,論文使用兩個Conv-BN-ReLU塊實現。
當\(l \gt 0\)時,DG中邊更新規則如下:
同樣需要對\(e_{l,ij}^d\)進行正則化。
2.2.2 D2P聚合
接下來就是利用DG中的邊特征,也就是樣本的分布相似度,生成PG中的節點特征:
其中,\(D2P: (\mathbb{R}^m, \mathbb{R}^m) \rightarrow \mathbb{R}^m\),論文使用兩個Conv-BN-ReLU塊實現。
3 訓練
為了進行節點分類,只需要將最后一層的邊特征輸入softmax函數即可:
其中,\(P(\hat{y_i}|x_i)\)就是樣本\(x_i\)的預測概率分布,\(y_j\)是support集中第\(j\)個樣本的標簽,\(e_{l,ij}^p\)表示DPGN最后一層PG中的邊特征。
3.1 點損失
點損失就是對節點進行分類的交叉熵損失:
其中,\(\mathcal{L}_{CE}\)是交叉熵函數,\(y_i\)是\(x_i\)的標簽。
3.2 分布損失
分布損失實際上是在DG層面做節點分類:
模型最終的損失函數由每一層的兩部分損失得到:
其中\(\hat{l}\)表示DPGN總的層數,\(\lambda_p,\lambda_d\)是權重參數。
三、實驗
論文使用了四個Few-shot learning數據集
下面展示一個數據集的實驗結果,其他數據集結果可以參照原論文