Graph-GraphSage


MPNN很好地概括了空域卷積的過程,但定義在這個框架下的所有模型都有一個共同的缺陷:

1. 卷積操作針對的對象是整張圖,也就意味着要將所有結點放入內存/顯存中,才能進行卷積操作。但對實際場景中的大規模圖而言,整個圖上的卷積操作並不現實。GraphSage[2]提出的動機之一就是解決這個問題。從該方法的名字我們也能看出,區別於傳統的全圖卷積,GraphSage利用采樣(Sample)部分結點的方式進行學習。當然,即使不需要整張圖同時卷積,GraphSage仍然需要聚合鄰居結點的信息,即論文中定義的aggregate的操作。這種操作類似於MPNN中的消息傳遞過程。

2. 之前的方法是固有的直推式得到node embeddeding,不能泛化,無法有效適應動態圖中新增節點的特性, 往往需要從頭訓練或至少局部重訓練。

3 缺乏權值共享(Deepwalk, LINE, node2vec)。節點的embedding直接是一個N*d的矩陣, 互相之間沒有共享學習參數。

4.輸入維度固定為|V|。無論是基於skip-gram的淺層模型還是基於autoencoder的深層模型,輸入的維度都是點集的大小。訓練過程依賴點集信息的固定網絡結構限制了模型泛化到動態圖的能力,無法為新加入節點生成embedding。


GraphSage 的模型

 

文中不是對每個頂點都訓練一個單獨的embeddding向量,而是訓練了一組aggregator functions,這些函數學習如何從一個頂點的局部鄰居聚合特征信息(見圖1)。每個聚合函數從一個頂點的不同的hops或者說不同的搜索深度聚合信息。測試或是推斷的時候,使用訓練好的系統,通過學習到的聚合函數來對完全未見過的頂點生成embedding。

 一種鄰居特征聚集的方法:

1.鄰居采樣(sample neighborhood)。因為每個節點的度是不一致的,為了計算高效, 為每個節點采樣固定數量的鄰居。

2.鄰居特征聚集(aggregate feature information from neighbors)。通過聚集采樣到的鄰居特征,更新當前節點的特征。網絡第k層聚集到的 鄰居即為BFS過程第k層的鄰居

3. 訓練。既可以用獲得的embedding預測節點的上下文信息(context),也可以利用embedding做有監督訓練。

采樣方法:

  1. 在圖中隨機采樣若干個結點,結點數為傳統任務中的batch_size。對於每個結點,隨機選擇固定數目的鄰居結點(receptive field)(這里鄰居不一定是一階鄰居,也可以是二階鄰居)構成進行卷積操作的圖。
  2. 將鄰居結點的信息通過aggregate函數聚合起來更新剛才采樣的結點。
  3. 計算采樣結點處的損失。如果是無監督任務,我們希望圖上鄰居結點的編碼相似;如果是監督任務,即可根據具體結點的任務標簽計算損失

 arrregate函數的選擇等:

1.在實踐中,每個節點的receptive field設置為固定大小,且使用了均勻采樣方法簡化鄰居選擇過程。

2.作者設計了四種不同的聚集策略,分別是Mean、GCN、LSTM、MaxPooling。

就aggretor是GCN, 圖卷積聚集W(PX),W為參數矩陣,P為鄰接矩陣的對稱歸一化矩陣,X為節點特征矩陣。

 

GraphSage的狀態更新公式如下(下式的aggregator 是GCN, 也可以是maxpooling arregator mean aggregator):

 

\mathbf{h}_{v}^{l+1}=\sigma(\mathbf{W}^{l+1}\cdot aggregate(\mathbf{h}_v^l,\{\mathbf{h}_u^l\}),{\forall}u{\in}ne[v])

 

 

 

 

 

 算法:

第四行表示節點vvv的任意相鄰節點的聚合信息的集合;

$h_{N(v)}^{k}$表示從節點$v$的相鄰節點獲取的信息。

AGGGEGATE表示聚合函數;

第5行,將從相鄰節點獲取的信息和和這個節點自身的信息,進行拼接;

可以發現,對一個新加入的節點a,只需要知道其自身特征和相鄰節點,就可以得到其向量表示。不必重新訓練得到其他所有節點的向量表示。當然也可以選擇重新訓練。但是需要保存所有節點深度為k的表示,

這個算法直觀的想法是,每次迭代,或者說每個深度,節點從相鄰節點獲得信息。隨着迭代次數的增多,節點增量地從圖中的更遠處獲得更多的信息。

graphSAGE並不是使用全部的相鄰節點,而是做了固定尺寸的采樣

既然新增的節點,一定會改變原有節點的表示,那么為什么一定要得到每個節點的一個固定的表示呢?何不直接學習一種節點的表示方法。去學習一個節點的信息是怎么通過其鄰居節點的特征聚合而來的。 學習到了這樣的“聚合函數”,而我們本身就已知各個節點的特征和鄰居關系,我們就可以很方便地得到一個新節點的表示了

GraphSAGE的核心:GraphSAGE不是試圖學習一個圖上所有node的embedding,而是學習一個為每個node產生embedding的映射

訓練相關:

1.為了將算法1擴展到minibatch環境上,給定一組輸入頂點,先采樣采出需要的鄰居集合(直到深度K),然后運行內部循環(算法1的第三行)。
2.出於對計算效率的考慮,對每個頂點采樣一定數量的鄰居頂點作為待聚合信息的頂點。設需要的鄰居數量,即采樣數量為S,若頂點鄰居數少於S,則采用有放回的抽樣方法,直到采樣出SSS個頂點。若頂點鄰居數大於SSS,則采用無放回的抽樣。(即采用有放回的重采樣/負采樣方法達到S)

3.文中在較大的數據集上實驗。因此,統一采樣一個固定大小的鄰域集,以保持每個batch的計算占用空間是固定的(即 graphSAGE並不是使用全部的相鄰節點,而是做了固定size的采樣)。

 

 

 

4.這里需要注意的是,每一層的node的表示都是由上一層生成的,跟本層的其他節點無關,這也是一種基於層的采樣方式。
5.在圖中的“1層”,節點v聚合了“0層”的兩個鄰居的信息,v的鄰居u也是聚合了“0層”的兩個鄰居的信息。到了“2層”,可以看到節點v通過“1層”的節點u,擴展到了“0層”的二階鄰居節點。因此,在聚合時,聚合K次,就可以擴展到K階鄰居。

6.鄰居的定義:那就是如何選擇一個節點的鄰居以及多遠的鄰居。這里作者的做法是設置一個定值,每次選擇鄰居的時候就是從周圍的直接鄰居(一階鄰居)中均勻地采樣固定個數個鄰居。

那我就有一個疑問了?每次都只是從其一階鄰居聚合信息,為何作者說:隨着迭代,可以聚合越來越遠距離的信息呢?
后來我想了想,發現確實是這樣的。雖然在聚合時僅僅聚合了一個節點鄰居的信息,但該節點的鄰居,也聚合了其鄰居的信息,這樣,在下一次聚合時,該節點就會接收到其鄰居的鄰居的信息,也就是聚合到了二階鄰居的信息了。還是拿出我的看家本領——用圖說話:

在GraphSAGE的實踐中,作者發現,K不必取很大的值,當K=2時,效果就灰常好了,也就是只用擴展到2階鄰居即可。至於鄰居的個數,文中提到S1×S2<=500,即兩次擴展的鄰居數之際小於500,大約每次只需要擴展20來個鄰居即可。
這也是合情合理,例如在現實生活中,對你影響最大就是親朋好友,這些屬於一階鄰居,然后可能你偶爾從他們口中聽說一些他們的同事、朋友的一些故事,這些會對你產生一定的影響,這些人就屬於二階鄰居。但是到了三階,可能基本對你不會產生什么影響了,例如你聽你同學說他同學聽說她同學的什么事跡,是不是很繞口,繞口就對了,因為你基本不會聽到這樣的故事,你所接觸到的、聽到的、看到的,基本都在“二階”的范圍之內。
7.沒有這種采樣,單個batch的內存和預期運行時是不可預測的,在最壞的情況下是O(∣V∣)O(|\mathcal{V}|)O(∣V∣)。
8.實驗發現,K不必取很大的值,當K=2時,效果就很好了。至於鄰居的個數,文中提到S1⋅S2≤500,即兩次擴展的鄰居數之際小於500,大約每次只需要擴展20來個鄰居時獲得較高的性能。
論文里說固定長度的隨機游走其實就是隨機選擇了固定數量的鄰居.

聚合函數選取

在圖中頂點的鄰居是無序的,所以希望構造出的聚合函數是對稱的(即也就是對它輸入的各種排列,函數的輸出結果不變),同時具有較高的表達能力。 聚合函數的對稱性(symmetry property)確保了神經網絡模型可以被訓練且可以應用於任意順序的頂點鄰居特征集合上。

Mean aggregator

mean aggregator將目標頂點和鄰居頂點的第k−1層向量拼接起來,然后對向量的每個維度進行求均值的操作,將得到的結果做一次非線性變換產生目標頂點的第k層表示向量。
文中用下面的式子替換算法1中的4行和5行得到GCN的inductive變形:

1. 均值聚合近似等價在transducttive GCN框架[Semi-supervised classification with graph convolutional networks. In ICLR, 2016]中的卷積傳播規則

 

 

 

2.文中稱這個修改后的基於均值的聚合器是convolutional的,這個卷積聚合器和文中的其他聚合器的重要不同在於它沒有算法1中第5行的CONCAT操作——卷積聚合器沒有將頂點前一層的表示

和聚合的鄰居向量.
3.
拼接操作可以看作一個是GraphSAGE算法在不同的搜索深度或層之間的簡單的skip connection[Identity mappings in deep residual networks]的形式,它使得模型獲得了巨大的提升

4.舉個簡單例子,比如一個節點的3個鄰居的embedding分別為[1,2,3,4],[2,3,4,5],[3,4,5,6]按照每一維分別求均值就得到了聚合后的鄰居embedding為[2,3,4,5]

Mean aggregator

Pooling聚合器,它既是對稱的,又是可訓練的。Pooling aggregator 先對目標頂點的鄰居頂點的embedding向量進行一次非線性變換,之后進行一次pooling操作(max pooling or mean pooling),將得到結果與目標頂點的表示向量拼接,最后再經過一次非線性變換得到目標頂點的第k層表示向量.

1.max表示element-wise最大值操作,取每個特征的最大值
2.$\sigma$是非線性激活函數
3.所有相鄰節點的向量共享權重,先經過一個非線性全連接層,然后做max-pooling
4.按維度應用 max/mean pooling,可以捕獲鄰居集上在某一個維度的突出的/綜合的表現。

參數學習

基於圖的損失函數傾向於使得相鄰的頂點有相似的表示,但這會使相互遠離的頂點的表示差異變大:

 

 

1.$Z_{u}$為節點$u$通過GraphSAGE生成的embedding

 

2.節點v是節點u隨機游走到達的鄰居

 

3.$\sigma$是sigmoid

 

4.Q是負樣本的數目

 

5.embedding 之間的相似度通過向量點積計算

文中輸入到損失函數的表示$z_{u}$是從包含一個頂點局部鄰居的特征生成出來的,

而不像之前的那些方法(如DeepWalk),對每個頂點訓練一個獨一無二的embedding,

然后簡單進行一個embedding查找操作得到.

基於圖的有監督損失---------------可使用交叉熵或者范數計算

通過前向傳播得到節點u的embedding $z_{u}$,然后梯度下降(實現使用Adam優化器) 進行反向傳播優化參數

$W^{k}$和聚合函數內的參數。

新節點embedding的生成

這個$W^{k}$就是dynamic embedding的核心,因為保存下來了從節點原始的高維特征生成低維embedding的方式。現在,如果想得到一個點的embedding,只需要輸入節點的特征向量,經過卷積(利用已經訓練好的$W^{k}$以及特定聚合函數聚合neighbor的屬性信息),就產生了節點的embedding。

GraphSAGE的核心:GraphSAGE不是試圖學習一個圖上所有node的embedding,而是學習一個為每個node產生embedding的映射

 

為什么GCN是transductive,為啥要把所有節點放在一起訓練?
不一定要把所有節點放在一起訓練,一個個節點放進去訓練也是可以的。無非是如果想得到所有節點的embedding,那么GCN可以把整個graph丟進去,直接得到embedding,還可以直接進行節點分類、邊的預測等任務。

其實,通過GraphSAGE得到的節點的embedding,在增加了新的節點之后,舊的節點也需要更新,這個是無法避免的,因為,新增加點意味着環境變了,那之前的節點的表示自然也應該有所調整。只不過,對於老節點,可能新增一個節點對其影響微乎其微,所以可以暫且使用原來的embedding,但如果新增了很多,極大地改變的原有的graph結構,那么就只能全部更新一次了。從這個角度去想的話,似乎GraphSAGE也不是什么“神仙”方法,只不過生成新節點embedding的過程,實施起來相比於GCN更加靈活方便了。在學習到了各種的聚合函數之后,其實就不用去計算所有節點的embedding,而是需要去考察哪些節點,就現場去計算,這種方法的遷移能力也很強,在一個graph上學得了節點的聚合方法,到另一個新的類似的graph上就可以直接使用了.

核心思想:

去學習一個節點的信息是怎么通過其鄰居節點的特征聚合而來的。學習到了這樣的“聚合函數”,而我們本身就已知各個節點的特征和鄰居關系,我們就可以很方便地得到一個新節點的表示了。
GCN等transductive的方法,學到的是每個節點的一個唯一確定的embedding;而GraphSAGE方法學到的node embedding,是根據node的鄰居關系的變化而變化的,也就是說,即使是舊的node,如果建立了一些新的link,那么其對應的embedding也會變化,而且也很方便地學到。

 

 

 

 

 

 

 

 

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM