論文:《Memory-based Graph Networks》,ICLR2020
代碼:https://github.com/amirkhas/GraphMemoryNet
概述
圖神經網絡(GNNs)是一類深度模型,可處理任意拓撲結構的數據。比如社交網絡、知識圖譜、分子結構等。GNNs通常被用來根據節點的交互關系學習節點的向量表示,典型的模型有gated GNN(Li et al., 2015)、MPNN(Giler et al., 2017)、GCN(Kipf & Welling, 2016)和GAT(Velikovi et al., 2018)。GNNs方法通常優於傳統的隨機游走、矩陣分解、核方法和概率圖模型。
但是,這些模型無法學習到層次表示,因為它們沒有利用圖的組合性質。DiffPool (Ying et al., 2018)、TopKPool (Gao & Ji, 2019)、SAGPool (Lee et al., 2019)等模型引入參數化的圖池化層,通過堆疊交錯層和池化層來學習層次圖表示。但這些模型的計算效率不高,因為它們需要在每個池化層后進行消息傳遞計算。
本論文介紹了一個能夠同時進行圖表示學習和節點聚類的記憶層,該記憶層由多組(multi-head)記憶鍵和卷積運算組成。記憶鍵被視為聚類中心,而卷積運算用來聚合多組結果。記憶層的輸入叫做query,是前一層輸出的節點表示,記憶層的輸出是聚類后的節點表示。這種記憶層不顯式依賴節點的連接信息,因此不存在過度平滑問題(Xu et al., 2018),同時也改進了效率和性能。
作者在論文中提出了兩種基於記憶層的網絡,分別叫做memory-based GNN(MemGNN)和graph memory network(GMN)。其中MemGNN就是首先使用GNN學習節點的初始表示然后堆疊記憶層學習層次表示;GMN則不依賴GNN,因此也不需要消息傳遞的計算。
相關工作
方法
下面開始講記憶層究竟是什么,以及由此而來的兩種網絡架構,即GMN和MemGNN。
記憶層
第\(l\)層的記憶層可以表示為\(\mathcal{M}^{(l)}:\mathbb{R}^{n_l \times d_l} \longmapsto \mathbb{R}^{n_{l+1} \times d_{l+1}}\),記憶層輸入\(n_l\)個維度為\(d_l\)的查詢向量,生成\(n_{l+1}\)個維度為\(d_{l+1}\)的查詢向量(下個記憶層的查詢向量)。因為要自底向上學習圖層次表示,要保證\(n_{l+1} \lt n_l\)。
上圖就是記憶層的示意圖,假設其中有\(|h|\)組記憶鍵。現在來看看記憶層是怎么實現聚類的。首先,假設第\(l\)層記憶層的輸入為\(\mathbf{Q}^{(l)} \in \mathbb{R}^{n_l \times d_l}\),一組記憶鍵\(\mathbf{K}^{(l)} \in \mathbb{R}^{n_{l+1} \times d_l}\)可以看作是\(\mathbf{Q}^{(l)}\)的聚類中心。為了衡量\(\mathbf{Q}^{(l)}\)和\(\mathbf{K}^{(l)}\)每個分量之間的相似度,作者借鑒Xie et al., 2016的工作,使用t分布作為核函數。因此查詢\(q_i\)和記憶鍵\(k_j\)的正則化的相似度定義為:
\(C_{i,j}\)就是將節點\(i\)分配到類簇\(j\)的概率,或者說\(q_i\)和\(k_j\)之間的注意力權重。\(\tau\)是t分布的自由度。前面我們說到,記憶鍵總共有\(|h|\)組,因此實際上上述聚類要計算\(|h|\)次,得到結果為\([\mathbf{C}_0^{(l)} \dots \mathbf{C}_{|h|}^{(l)}] \in \mathbb{R}^{|h| \times n_{l+1} \times n_l}\)。為了將\(h\)組結果聚合為一組結果,作者將三個維度分別看作深度、高度和寬度,然后使用一個\(1 \times 1\)的卷積進行聚合:
其中,\(\Gamma_{\phi}\)是\(1 \times 1\)的卷積,\(\mathbf{C}^{(l)}\)就是聚合后的分配矩陣。
之后,值(value)矩陣\(\mathbf{V}^{(l)} \in \mathbb{R}^{n_{l+1} \times d_l}\)由下式定義:
由於\(\mathbf{V}^{(l)}\)元素維度和\(\mathbf{Q}^{(l)}\)元素維度相同,作者認為這就表示在相同空間對節點聚類,之后還要經過一個單層前向網絡將\(\mathbf{V}^{(l)}\)投影為新的查詢:
其中\(\sigma\)是LeankyReLU激活函數。\(\mathbf{Q}^{(l+1)}\)將作為下一個記憶層的查詢。
對於圖分類任務,我們可以通過堆疊記憶層最終獲得整個圖的向量表示,然后用全連接層進行分類:
其中,\(\mathbf{Q}^{(0)}=f_q(g)\)是將圖\(g\)輸入網絡\(f_g\)得到的初始查詢表示,也就是初始節點向量。根據\(f_q\)的不同,作者引出了兩種模型,即GMN和MemGNN。
GMN架構
GMN將圖中節點表示視為排列不變(permutation-invariant)集,也就是不考慮它們之間的空間關系,因此也不需要使用到圖神經網絡中的消息傳遞機制。但是,圖中節點畢竟是存在拓撲關系的,完全不考慮是行不通的,因此作者考慮的是把節點的拓撲關系編碼到節點的初始表示中。更具體地說,作者使用帶重啟的隨機游走(RWR)(Pan et al., 2004)來計算拓撲嵌入,然后按行對它們進行排序,以強制節點嵌入保持順序不變。得到包含拓撲信息的節點表示\(\mathbf{X} \in \mathbb{R}^{n \times d_{in}}\)后,初始的查詢表示通過兩層前向網絡計算得到:
其中\(\mathbf{W}_0 \in \mathbb{R}^{n\times d_{in}}\)和\(\mathbf{W}_1 \in \mathbb{R}^{2d_{in}\times d_{0}}\)是參數,\(\mathbf{S} \in \mathbb{R}^{n\times n}\)是圖擴散矩陣,\(\Vert\)表示拼接操作,\(\sigma\)是LeakyReLU激活函數。
MemGNN架構
MemGNN直接使用圖神經網絡計算初始查詢:
其中,\(G_{\theta}\)是任意的圖神經網絡。作者在實現時使用了GAT模型的改進版e-GAT,也就是在計算注意力權重時考慮了邊特征。注意力權重計算公式為:
其中\(h_i^{(l)}, h_{i \rightarrow j}^{(l)}\)分別是節點表示和邊表示,\(\mathbf{W}_n, \mathbf{W}_e\)分別是節點權重和邊權重,\(\mathbf{W}\)是前向網絡參數,\(\sigma\)是LeakyReLU激活函數。
模型訓練
模型的損失包含兩部分,有監督損失和無監督損失。有監督損失\(\mathcal{L}_{sup}\)來自圖分類或者圖回歸損失。無監督損失用於鼓勵模型學習利於聚類的表示,由\(\mathbf{C}^{(l)}\)和輔助分布\(\mathbf{P}^{(l)}\)之間的KL散度定義:
其中輔助分布\(\mathbf{P}^{(l)}\)的計算和Xie et al., 2016一樣,
因此模型最終的損失定義為
為了使訓練更穩定,\(\mathcal{L}_{sup}\)產生的的梯度每個batch進行反向傳播,而\(\mathcal{L}_{KL}^{(l)}\)產生的梯度每個epoch反向傳播一次,可以通過反復調整\(\lambda\)的取值為0或1實現。這是因為快速地調整聚類中心,也就是記憶鍵,可能會導致訓練不穩定。
實驗
論文主要關注圖分類和圖回歸任務,使用了5個圖分類數據集和2個圖回歸數據集:
主要實驗結果如下面幾幅圖所示: