無監督-DEEP GRAPH INFOMAX
標簽:圖神經網絡、無監督
動機
- 在真實世界中,圖的標簽是較少的,而現在圖神經的高性能主要依賴於有標簽的真是數據集
- 在無監督中,隨機游走犧牲了圖結構信息和強調的是鄰域信息,並且性能高度依賴於超參數的選擇
貢獻
- 在無監督學習上,首次結合互信息提出了一個圖節點表示學習方法-DGI
- 該方法不依賴隨機游走目標,並且使用與直推式學習和歸納學習
- DGI 依賴於最大限度地擴大圖增強表示和目前提取到的圖信息之間的互信息
思想
符號定義
節點特征集合: \(X \in \mathbb{R}^{N \times F}\),鄰接矩陣: \(A \in \mathbb{R}^{N \times N}\), 編碼器: \(\varepsilon~~~~ \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{N \times F'}\),鑒別器(discriminator): \(D~~ ~~ \mathbb{R}^{F} \times \mathbb{R}^{F} \rightarrow \mathbb{R}\),腐蝕函數(corruption function): \(C~~~\mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{M \times F} \times \mathbb{R}^{M \times M}\),節點的表示 (patch representations): \(\overrightarrow{h_i}\) , 圖表示: \(\overrightarrow{s}\) .
核心
本質上利用大化局部互信息訓練一個模型 (編碼器) \(\varepsilon~~~~ \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{N \times F'}\) ,其損失函數 (1) 所示,負例的來源是是通過 corruption function 得到
框架
對於一個原圖 \(G(X, A)\),首先利用 corruption function 得到一個破壞后的圖 \(G'(\widetilde{X}, \widetilde{A})\),拿這兩個圖利用編碼器 \(\varepsilon\) 進行編碼, \(H = \varepsilon(X, A) = \{\overrightarrow{h_1},\overrightarrow{h_2},...,\overrightarrow{h_N}\} ~~~~~ H = \varepsilon(\widetilde{X}, \widetilde{A}) = \{\overrightarrow{\widetilde{h_1}},\overrightarrow{\widetilde{h_2}},...,\overrightarrow{\widetilde{h_N}}\}\) , 對於原圖得到每個節點的表示利用一個讀出函數 (readout function) 得到整個圖的表示 \(\overrightarrow{s} = R(H)\) ,最后利用目標函數更新參數
步驟
- 用 corruption function 進行采樣負樣例得到 \((\widetilde{X}, \widetilde{A}) \approx C(X,A)\)
- 將原圖(正例)喂給編碼器獲得節點的表示 patch representations \(\overrightarrow{h_i}\), \(H = \varepsilon(X, A) = \{\overrightarrow{h_1},\overrightarrow{h_2},...,\overrightarrow{h_N}\}\)
- 將破壞后的圖(負例)喂給編碼器獲得節點的表示 patch representations \(\overrightarrow{\widetilde{h_i}}\), \(H = \varepsilon(\widetilde{X}, \widetilde{A}) = \{\overrightarrow{\widetilde{h_1}},\overrightarrow{\widetilde{h_2}},...,\overrightarrow{\widetilde{h_N}}\}\)
- 通過讀出函數 (readout function) 傳遞輸入圖的patch representations 來得到總的圖的表示 \(\overrightarrow{s} = R(H)\)
- 通過應用梯度下降最大化 (1) 來更新 \(\varepsilon、R、D\) 的參數
損失函數
實驗
直推式學習 (Transductive Learn)
GCN 傳播規則: \(\varepsilon(X, A) = \sigma(\hat{D}^{-\frac{1}{2}}\hat{A} \hat{D}^{-\frac{1}{2}}X\Theta)\)
其中, \(\hat{A} = A + I_N\) 代表加上自環的鄰接矩陣, \(\hat{D}\) 代表相應的度矩陣,滿足 \(\hat{D_{ii} = \sum_{j}\hat{A_{ij}}}\) 對於非線性激活函數 \(\sigma\) ,選擇 PReLU(parametric ReLU)。\(\Theta \in R^{F \times F'}\) 是應用於每個節點的可學習線性變換。
對於 corruption function C ,直接采用 \(\widetilde{A} = A\),但是 \(\widetilde{X}\) 是由原本的特征矩陣 \(X\) 經過隨機變換得到的。也就是說,損壞的圖(corrupted graph)由與原始圖完全相同的節點組成,但它們位於圖中的不同位置,因此將得到不同的鄰近表示。
歸納式學習 (Inductive Learn)
對於歸納學習,不再在編碼器中使用 GCN 更新規則(因為學習的濾波器依賴於固定的和已知的鄰接矩陣);相反,我們應用平均池( mean-pooling)傳播規則,GraphSAGE-GCN:\(MP(X,A) = \hat{D}^{-1}\hat{A}X\Theta\)