基本信息
論文題目:GRAPH ATTENTION NETWORKS
時間:2018
期刊:ICLR
主要動機
探討圖譜(Graph)作為輸入的情況下如何用深度學習完成分類、預測等問題;通過堆疊這種層(層中的頂點會注意鄰居的特征),我們可以給鄰居中的頂點指定不同的權重,不需要任何一種耗時的矩陣操作(比如求逆)或依賴圖結構的先驗知識。
CNN 結構可以有效用於解決網格狀的結構數據,例如圖像分類等。但是現有的許多任務的數據並不能表示為網格狀的結構,而是分布在不規則的區域,如社交網絡、生物網絡等。這樣的數據通常用圖的形式來表示。
目前有些研究通過擴展神經網絡來處理不規則結構圖,包括循環神經網絡(RNN)、圖神經網絡(GNN)及其改進的模型。另一種研究思路是把卷積泛化到圖域中,分為譜方法和非譜方法兩種。
注意力機制的一個好處是可以處理可變大小的輸入,並且關注輸入的最相關部分以做出決策。當注意機制用於計算單個序列的表示時,通常將其稱為自注意或內注意。該機制和 RNN 結合已經廣泛應用於機器閱讀、句子表示和機器翻譯領域。
作者提出了一種基於注意機制的架構,能夠完成圖結構數據的節點分類。該方法的思路是通過注意其鄰居節點,計算圖中的每個節點的隱藏表征,還帶有自注意策略。這種架構有多重性質:
- 運算高效,因為它可以在 “頂點 - 鄰居” 對上並行計算;
- 可以通過對近鄰節點指定任意的權重應用於不同度的圖節點;
- 該模型直接適用於歸納學習問題,其中包括需要將模型泛化為此前為見的圖的任務。
GAT 架構
單個 graph attentional layer 的輸入是一個節點特征向量集合,\(h=\lbrace \vec{h_1},\vec{h_2},\dots,\vec{h_N} \rbrace,\; \vec{h_i}\in R^F\),其中 \(N\) 表示節點的數目,\(F\) 表示每個節點的特征的數目。並生成一個新的節點特征集合 \(h^{'}=\lbrace \vec{h_1^{'}},\vec{h_2^{'}},\dots,\vec{h_N^{'}} \rbrace,\; \vec{h_i^{'}}\in R^{F^{'}}\) 作為輸出,其中 \(F\) 和 \(F^{'}\) 具有不同的維度。
為了獲得足夠的表達能力以將輸入特征變換為更高級別的特征,需要至少一個可學習的線性變換。為此,作為初始步驟,一個共享的線性變換參數矩陣 \(W\in R^{F^{'}\times F}\) 被應用於每一個節點。然后執行 self-attention 處理:
其中,\(a\) 是一個 \(R^{F^{'}}\times R^{F^{'}}\to R\) 的映射,公式(1)表示了節點 \(j\) 的特征對於節點 \(i\) 的重要性。一般來說,self-attention 會將注意力分配到圖中所有的節點上,這種做法顯然會丟失結構信息。為了解決這一問題,作者使用了一種 masked attention 的方法 -- 僅將注意力分配到節點 \(i\) 的鄰居節點集上,即 \(N_i\),其中節點 \(i\) 也包括在 \(N_i\) 中。為了使系數在不同節點之間易於比較,我們使用 softmax 函數在 \(j\) 的所有選擇中對它們進行標准化:
注意力機制 \(a\) 是一個單層前饋神經網絡,其中 \(\vec{a}\in R^{2F^{'}}\) 是權重參數,使用 LeakyReLU 作為激活函數。完全展開后,由注意機制計算的系數可以表示為:
歸一化的注意力系數用於計算與它們對應的特征的線性組合,以用作每個節點的最終輸出特征,采用非線性的函數:
為了提高模型的擬合能力,在本文中還引入了多抽頭的 self-attention(如圖 1 右側部分。與《Attention is All You Need》一致),即同時使用多個 \(W^k\) 計算 self-attention,然后將各個計算得到的結果進行合並(連接或求和):
模型比較
上一節描述的圖注意力層直接解決了之前在圖結構上使用神經網絡建模的方法的幾個問題:
- 計算高效:self-attention 層的操作可以在所有的邊上並行,輸出特征的計算可以在所有頂點上並行。沒有耗時的特征值分解。單層的 GAT 的時間復雜度為 \(O(|V|FF^{'}+|E|F^{'})\) 。盡管 multi-head 注意力將存儲和參數要求乘以系數 K,但是單個 head 的計算完全獨立且可以並行化。
- 與 GCN 相反,我們的模型允許(隱式)為同一鄰域的節點分配不同的重要性,從而實現模型表示能力的飛躍;
- 對於圖中的所有邊,attention 機制是共享的。因此 GAT 也是一種局部模型。也就是說,在使用 GAT 時,我們無需訪問整個圖,而只需要訪問所關注節點的鄰節點即可。這一特點的作用主要有:(1)可以處理有向圖(若 \(j\to i\) 不存在,僅需忽略 \(\alpha_{ij}\) 即可);(2)可以被直接用於進行歸納學習。
- 最新的歸納學習方法(GraphSAGE 2017)通過從每個節點的鄰居中抽取固定數量的節點,從而保證其計算的一致性。這意味着,在執行推斷時,我們無法訪問所有的鄰居。然而,本文所提出的模型是建立在所有鄰節點上的,而且無需假設任何節點順序。
我們能夠生成一個利用稀疏矩陣運算的 GAT 層版本,將存儲復雜性降低到節點和邊緣數量的線性,並在較大的圖形數據集上實現 GAT 模型。然而,我們使用的張量操作框架僅支持秩 - 2 張量的稀疏矩陣乘法,這限制了當前實現的層的批處理能力(特別是對於具有多個圖的數據集),適當地解決這一限制是未來工作的重要方向。根據現有圖形結構的規律性,在稀疏場景中,GPU 相比於 CPU 可能無法提供主要的性能優勢。
實驗評估
歸納學習(Inductive Learning):先從訓練樣本中學習到一定的模式,然后利用其對測試樣本進行預測(即首先從特殊到一般,然后再從一般到特殊),這類模型如常見的貝葉斯模型。
演繹學習(Transductive Learning):先觀察特定的訓練樣本,然后對特定的測試樣本做出預測(從特殊到特殊),這類模型如 k 近鄰、SVM 等。
在演繹學習中使用三個標准的引證網絡數據集——Cora、Citeseer 與 Pubmed。在這些數據集中,節點對應於文檔,邊(無向的)對應於引用關系。節點特征對應於文檔的 Bag of Words
表示。每個節點擁有一個類別標簽(在分類時使用 softmax 激活函數)。每個數據集的詳細信息如下表所示:
演繹學習的實驗結果如下表所示,可以看到,GAT 模型的效果要基本優於其他模型:
對於歸納學習,本文使用了一個蛋白質關聯數據集(protein-protein interaction, PPI),在其中,每張圖對應於人類的不同組織。此時,使用 20 張圖進行訓練,2 張圖進行驗證,2 張圖用於測試。每個節點可能的標簽數為 121 個,而且,每個節點可以同時擁有多個標簽(在分類時使用 sigmoid 激活函數),其實驗結果如下: