DPGN: Distribution Propagation Graph Network for Few-shot Learning


論文:《DPGN: Distribution Propagation Graph Network for Few-shot Learning》,CVPR2020

代碼:https://github.com/megvii-research/DPGN

一、概述

在給定少量標注數據(support集)的情況下,Few-shot learning旨在對未標注數據(query 集)進行預測。

有很多方法可以用於Few-shot learning任務,比如:

  • 微調(Fine-tuning)方法,但容易過擬合
  • 元學習(Meta-Learning)方法,但通常隱式利用樣本全局關系
  • 圖網絡(Graph Networks)方法,但只考慮了樣本對關系,忽略了重要的分布關系

如上圖所示,該論文提出了DPGN(Distribution Propagation Graph Network)模型,通過未標注數據和已標注數據之間的相似度分布,引導標簽信息在圖中更好地傳播。該模型包含點圖(Point Graph, PG)和分布圖(Distribution Graph, DP)兩個完全圖,分別用於建模每個樣本的實例級別表示和分布級別表示。具體的含義可以看方法部分。

總的來說,論文的創新點有三點:

  1. DPGN是第一個顯式利用分布進行標簽傳播的圖網絡Few-shot learning方法。
  2. 提出了雙完全圖架構,結合了實例級別和分布級別的關系。
  3. 在四個Few-shot learning數據集上進行了實驗,在分類任務上提升了5%12%的性能,並在半監督任務中提升了7%13%的性能。

二、方法

首先介紹Few-shot learning的問題定義,然后詳細介紹DPGN模型的細節。

1 問題定義

每個Few-shot learning任務都有一個support集\(\mathcal{S}\)和一個query集\(\mathcal{Q}\),二者都屬於訓練集\(\mathbb{D}^{train}\)\(\mathcal{S}=\{(x_1,y_1),\dots,(x_{N\times K},y_{N\times K})\}\)含有\(N\)個類別,每個類別有\(K\)個樣本(也就是\(N\)-way \(K\)-shot),\(\mathcal{Q}=\{(x_{N\times K + 1},y_{N\times K + 1}) \dots, (x_{N\times K + \bar{T}},y_{N\times K + \bar{T}})\}\)含有\(\bar{T}\)個樣本。在訓練階段,support集和query集的標簽都是已知的。在測試階段,模型需要根據測試集中的support集預測測試集中query集的標簽。

2 DPGN

上圖展示了DPGN模型的主要過程,該模型包含\(l\)層,每層包含一個點圖(PG) \(G_l^p=(V_l^p, E_l^p)\)和一個分布圖(DG) \(G_l^d=(V_l^d, E_l^d)\)。每一層的表示計算順序基本構成一個環,即\(E_l^p \rightarrow V_l^d \rightarrow E_l^d \rightarrow V_l^p \rightarrow E_{l+1}^p\)

為了進一步說明,節點集合\(V_l^p, V_l^d\)分別表示為\(V_l^p=\{v_{l,i}^p\}\)\(V_l^d=\{v_{l,i}^d\}\),邊集合\(E_l^p, E_l^d\)分別表示為\(E_l^p=\{e_{l,ij}^p\}\)\(E_l^d=\{e_{l,ij}^d\}\),其中\(i,j=1,\cdots,T\)\(T=N\times K + \bar{T}\)

\(v_{0,i}^p\)被初始化為特征提取器的輸出:

\[v_{0,i}^p = f_{emb}(x_i) \in \mathbb{R}^m \]

2.1 點到分布聚合

2.1.1 點相似度

PG中的每條邊都表示實例(點)之間的相似度,也就是樣本之間的相似度。

\(l=0\)時,PG的邊定義為:

\[e_{0,ij}^p=f_{e_0^p}((v_{0,i}^p - v_{0,j}^p)^2) \in \mathbb{R} \]

其中\(f_{e_0^p}:\mathbb{R}^m \rightarrow \mathbb{R}\)用於將向量映射為標量,論文使用兩個Conv-BN-ReLU塊實現。

\(l \gt 0\)時,PG的邊更新規則如下:

\[e_{l,ij}^p=f_{e_l^p}((v_{l-1,i}^p - v_{l-1,j}^p)^2) \cdot e_{l-1,ij}^p \in \mathbb{R} \]

在實際應用中還要對\(e_{l,ij}^p\)進行歸一化。

2.1.2 P2D聚合

生成了PG中的邊后,下一步就是生成DG中的節點表示。方法如上圖所示,DG中每個節點都是維度為\(N\times K\)的特征向量,其中第\(j\)維表示該實例\(x_i\)與實例\(x_j\)的關系,\(N\times K\)就是support集大小。

\(l=0\)時,DP的節點定義為:

\[v_{0,i}^d = \begin{cases} ||_{j=1}^{NK} \delta(y_i, y_j) \quad \text{if} \ x_i \ \text{is labeled} \\ [\frac{1}{NK},\cdots, \frac{1}{NK}] \quad \text{otherwise} \end{cases} \in \mathbb{R}^{NK} \]

其中\(||\)表示連接操作,\(\delta\)輸出0或1表示標簽\(y_i\)\(y_j\)是否相等。

\(l \gt 0\)時,DG的節點更新規則如下:

\[v_{0,i}^d = P2D(||_{j=1}^{NK} e_{l,ij}^p, v_{l-1,i}^p) \]

其中,\(P2D: (\mathbb{R}^{NK}, \mathbb{R}^{NK}) \rightarrow \mathbb{R}^{NK}\)是聚合網絡,論文使用全連接層加ReLU層實現。

2.2 分布到點聚合

2.2.1 分布相似度

DG中每條邊表示實例分布特征的相似度,也就是樣本在分布空間的相似度。

\(l=0\)時,DG的邊定義為:

\[e_{0,ij}^d = f_{e_0^d}((v_{0,i}^d - v_{0,j}^d)^2) \in \mathbb{R} \]

其中,\(f_{e_0^d}: \mathbb{R}^{NK} \rightarrow \mathbb{R}\)用於將向量映射為標量,論文使用兩個Conv-BN-ReLU塊實現。

\(l \gt 0\)時,DG中邊更新規則如下:

\[e_{l,ij}^d = f_{e_l^d}((v_{l,i}^d - v_{l,j}^d)^2) \cdot e_{l-1,ij}^d \in \mathbb{R} \]

同樣需要對\(e_{l,ij}^d\)進行正則化。

2.2.2 D2P聚合

接下來就是利用DG中的邊特征,也就是樣本的分布相似度,生成PG中的節點特征:

\[v_{l,i}^p = D2P(\sum_{j=1}^T(e_{l,ij}^p \cdot v_{l-1,j}^p), v_{l-1,i}^p) \in \mathbb{R}^m \]

其中,\(D2P: (\mathbb{R}^m, \mathbb{R}^m) \rightarrow \mathbb{R}^m\),論文使用兩個Conv-BN-ReLU塊實現。

3 訓練

為了進行節點分類,只需要將最后一層的邊特征輸入softmax函數即可:

\[P(\hat{y_i}|x_i) = \text{Softmax}(\sum_{j=1}^{NK}e_{l,ij}^p \cdot one\_hot(y_j)) \]

其中,\(P(\hat{y_i}|x_i)\)就是樣本\(x_i\)的預測概率分布,\(y_j\)是support集中第\(j\)個樣本的標簽,\(e_{l,ij}^p\)表示DPGN最后一層PG中的邊特征。

3.1 點損失

點損失就是對節點進行分類的交叉熵損失:

\[\mathcal{L}_l^p = \mathcal{L}_{CE}(P(\hat{y_i}|x_i),y_i) \]

其中,\(\mathcal{L}_{CE}\)是交叉熵函數,\(y_i\)\(x_i\)的標簽。

3.2 分布損失

分布損失實際上是在DG層面做節點分類:

\[\mathcal{L}_l^d = \mathcal{L}_{CE}(\text{Softmax}(\sum_{j=1}^{NK}e_{l,ij}^d \cdot one\_hot(y_j)),y_i) \]

模型最終的損失函數由每一層的兩部分損失得到:

\[\mathcal{L} = \sum_{l=1}^{\hat{l}}(\lambda_p \mathcal{L}_l^p + \lambda_d \mathcal{L}_l^d) \]

其中\(\hat{l}\)表示DPGN總的層數,\(\lambda_p,\lambda_d\)是權重參數。

三、實驗

論文使用了四個Few-shot learning數據集

下面展示一個數據集的實驗結果,其他數據集結果可以參照原論文


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM