基於圖神經網絡的圖表征學習方法
引言
在此篇文章中我們將學習基於圖神經網絡的圖表征學習方法,圖表征學習要求在輸入節點屬性、邊(和邊的屬性如果有的話)得到一個向量作為圖的表征,基於圖表征進一步的我們可以做圖的預測。基於圖同構網絡(Graph Isomorphism Network, GIN)的圖表征網絡是當前最經典的圖表征學習網絡,我們將以它為例,通過該網絡的實現、項目實踐和理論分析,三個層面來學習基於圖神經網絡的圖表征學習方法。
提出圖同構網絡的論文:How Powerful are Graph Neural Networks?
基於圖同構網絡(GIN)的圖表征網絡的實現
基於圖同構網絡的圖表征學習主要包含以下兩個過程:
- 首先計算得到節點表征;
- 其次對圖上各個節點的表征做圖池化(Graph Pooling),或稱為圖讀出(Graph Readout),得到圖的表征(Graph Representation)。
在這里,我們將采用自頂向下的方式,來學習基於圖同構模型(GIN)的圖表征學習方法。我們首先關注如何基於節點表征計算得到圖的表征,而忽略計算結點表征的方法。
基於圖同構網絡的圖表征模塊(GINGraphRepr Module)
此模塊首先采用GINNodeEmbedding
模塊對圖上每一個節點做節點嵌入(Node Embedding),得到節點表征,然后對節點表征做圖池化得到圖的表征,最后用一層線性變換得到圖的表征(graph representation)。代碼實現如下:
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding
class GINGraphRepr(nn.Module):
def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
"""GIN Graph Pooling Module
Args:
num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了圖表征的維度,dimension of graph representation).
num_layers (int, optional): number of GINConv layers. Defaults to 5.
emb_dim (int, optional): dimension of node embedding. Defaults to 300.
residual (bool, optional): adding residual connection or not. Defaults to False.
drop_ratio (float, optional): dropout rate. Defaults to 0.
JK (str, optional): 可選的值為"last"和"sum"。選"last",只取最后一層的結點的嵌入,選"sum"對各層的結點的嵌入求和。Defaults to "last".
graph_pooling (str, optional): pooling method of node embedding. 可選的值為"sum","mean","max","attention"和"set2set"。 Defaults to "sum".
Out:
graph representation
"""
super(GINGraphPooling, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
self.num_tasks = num_tasks
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
# Pooling function to generate whole-graph embeddings
if graph_pooling == "sum":
self.pool = global_add_pool
elif graph_pooling == "mean":
self.pool = global_mean_pool
elif graph_pooling == "max":
self.pool = global_max_pool
elif graph_pooling == "attention":
self.pool = GlobalAttention(gate_nn=nn.Sequential(
nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
elif graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, processing_steps=2)
else:
raise ValueError("Invalid graph pooling type.")
if graph_pooling == "set2set":
self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
else:
self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
def forward(self, batched_data):
h_node = self.gnn_node(batched_data)
h_graph = self.pool(h_node, batched_data.batch)
output = self.graph_pred_linear(h_graph)
if self.training:
return output
else:
# At inference time, relu is applied to output to ensure positivity
# 因為預測目標的取值范圍就在 (0, 50] 內
return torch.clamp(output, min=0, max=50)
可以看到可選的基於結點表征計算得到圖表征的方法有:
- "sum":
- 對節點表征求和;
- 使用模塊
torch_geometric.nn.glob.global_add_pool
。
- "mean":
- 對節點表征求平均;
- 使用模塊
torch_geometric.nn.glob.global_mean_pool
。
- "max":取節點表征的最大值。
- 對一個batch中所有節點計算節點表征各個維度的最大值;
- 使用模塊
torch_geometric.nn.glob.global_max_pool
。
- "attention":
- 基於Attention對節點表征加權求和;
- 使用模塊 torch_geometric.nn.glob.GlobalAttention;
- 來自論文 “Gated Graph Sequence Neural Networks” 。
- "set2set":
- 另一種基於Attention對節點表征加權求和的方法;
- 使用模塊 torch_geometric.nn.glob.Set2Set;
- 來自論文 “Order Matters: Sequence to sequence for sets”。
PyG中集成的所有的圖池化的方法可見於Global Pooling Layers。
接下來我們將學習節點嵌入的方法。
基於圖同構網絡的節點嵌入模塊(GINNodeEmbedding Module)
此模塊基於多層GINConv
實現結點嵌入的計算。此處我們先忽略GINConv
的實現。此模塊得到的節點屬性輸入為類別型向量,我們首先用AtomEncoder
對其做嵌入得到第0
層節點表征(稍后我們再對AtomEncoder
做分析)。然后我們逐層計算節點表征,從第1
層開始到第num_layers
層,每一層節點表征的計算都以上一層的節點表征h_list[layer]
、邊edge_index
和邊的屬性edge_attr
為輸入。需要注意的是,GINConv
的層數越多,此模塊的感受野(receptive field)越大,結點i
的表征最遠能捕獲到結點i
的距離為num_layers
的鄰接節點的信息。
import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F
# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
"""
Output:
node representations
"""
def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
"""GIN Node Embedding Module"""
super(GINNodeEmbedding, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
# add residual connection or not
self.residual = residual
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.atom_encoder = AtomEncoder(emb_dim)
# List of GNNs
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layers):
self.convs.append(GINConv(emb_dim))
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
def forward(self, batched_data):
x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr
# computing input node embedding
h_list = [self.atom_encoder(x)] # 先將類別型原子屬性轉化為原子表征
for layer in range(self.num_layers):
h = self.convs[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
if self.residual:
h += h_list[layer]
h_list.append(h)
# Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
return node_representation
接下來我們來學習圖同構網絡的關鍵組件GINConv
。
GINConv
--圖同構卷積層
圖同構卷積層的數學定義如下:
PyG中已經實現了此模塊,我們可以通過torch_geometric.nn.GINConv
來使用PyG定義好的圖同構卷積層,然而該實現不支持存在邊屬性的圖。在這里我們自己自定義一個支持邊屬性的GINConv
模塊。
由於輸入的邊屬性為類別型,因此我們需要先將類別型邊屬性轉換為邊表征。我們定義的GINConv
模塊遵循“消息傳遞、消息聚合、消息更新”這一過程。
- 這一過程隨着
self.propagate
的調用開始執行,該函數接收edge_index
,x
,edge_attr
此三個函數。edge_index
是形狀為2,num_edges
的張量(tensor)。 - 在消息傳遞過程中,此張量首先被按行拆分為
x_i
和x_j
張量,x_j
表示了消息傳遞的源節點,x_i
表示了消息傳遞的目標節點。 - 接着
message
函數被調用,此函數定義了從源節點傳入到目標節點的消息,在這里要傳遞的消息是源節點表征與邊表征之和的relu
。我們在super(GINConv, self).__init__(aggr = "add")
中定義了消息聚合方式為add
,那么傳入給任一個目標節點的所有消息被求和得到aggr_out
,它是目標節點的中間過程的信息。 - 接着執行消息更新過程,我們的類
GINConv
繼承了MessagePassing
類,因此update
函數被調用。然而我們希望對節點做消息更新中加入目標節點自身的消息,因此在update
函數中我們只簡單返回輸入的aggr_out
。 - 然后在
forward
函數中我們執行out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
實現消息的更新。
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder
### GIN convolution along the graph structure
class GINConv(MessagePassing):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''
super(GINConv, self).__init__(aggr = "add")
self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
self.eps = nn.Parameter(torch.Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim = emb_dim)
def forward(self, x, edge_index, edge_attr):
edge_embedding = self.bond_encoder(edge_attr) # 先將類別型邊屬性轉換為邊表征
out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
return out
def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)
def update(self, aggr_out):
return aggr_out
理論分析
動機(Motivation)
論文原話:“The design of new GNNs is mostly based on empirical intuition, heuristics, and experimental trial-and-error. There is little understanding of the properties and limitations of GNNs, and formal analysis of GNNs' representational capacity is limited.” 新的圖神經網絡的設計大多基於經驗性的直覺、啟發式的方法和實驗性的試錯。人們對圖神經網絡的特性和局限性了解甚少,對圖神經網絡的表征能力學習的正式分析也很有限。
貢獻與結論
- (理論上)圖神經網絡在區分圖結構方面最高只能達到與WL Test一樣的能力。
- 確定了鄰接點聚合方法和圖池化方法的應具備的條件,在這些條件下,所產生的圖神經網絡能達到與WL Test一樣的能力。
- 確定了過去流行的圖神經網絡變體(如GCN和GraphSAGE)無法區分的圖結構,並描述了這種基於圖神經網絡的模型能夠捕獲的圖結構類型。
- 開發了一個簡單的神經結構--圖形同構網絡(GIN),並證明其分辨/表示能力與WL Test相當。
背景:Weisfeiler-Lehman Test (WL Test)
圖同構性測試
兩個圖是同構的,意思是兩個圖擁有一樣的拓撲結構,也就是說,我們可以通過重新標記節點從一個圖中得到另外一個圖。Weisfeiler-Lehman 圖的同構性測試算法,簡稱WL Test,是一種用於測試兩個圖是否同構的算法。
WL Test 的一維形式,類似於圖神經網絡中的鄰接節點聚合。WL Test 1)迭代地聚合節點及其鄰接節點的標簽,然后2)將聚合的標簽散列成唯一的新標簽,該過程形式化為下方的公示。在迭代過程中,發現兩個圖之間的節點的標簽不同時,就可以確定這兩個圖是非同構的。需要注意的是節點標簽可能的取值只能是有限個數。
在上方的公示中,\(L^{h}_{u}\)表示節點\(u\)的第\(h\)次迭代的標簽,第\(0\)次迭代的標簽為節點原始標簽。
WL測試不能保證對所有圖都有效,特別是對於具有高度對稱性的圖,如鏈式圖、完全圖、環圖和星圖,它會判斷錯誤。
Weisfeiler-Lehman Graph Kernels 方法提出用WL子樹核衡量圖之間相似性。該方法使用WL Test不同迭代中的節點標簽計數作為圖的表征向量,它具有與WL Test相同的判別能力。直觀地說,在WL Test的第\(k\)次迭代中,一個節點的標簽代表了以該節點為根的高度為\(k\)的子樹結構。
Weisfeiler-Leman Test 算法舉例說明:給定兩個圖\(G\)和\(G^{\prime}\),每個節點擁有標簽(實際中,一些圖沒有節點標簽,我們可以以節點的度作為標簽)。
Weisfeiler-Leman Test 算法通過重復執行以下給節點打標簽的過程來實現圖是否同構的判斷:
- 聚合自身與鄰接節點的標簽得到一串字符串,自身標簽與鄰接節點的標簽中間用
,
分隔,鄰接節點的標簽按升序排序。排序的原因在於要保證單射性,即保證獲得的結果不因鄰接節點的順序改變而改變。
- 標簽散列,即標簽壓縮,將較長的字符串映射到一個簡短的標簽。
- 給節點重新打上標簽。
每重復一次以上的過程,就完成一次節點自身標簽與鄰接節點標簽的聚合。
當出現兩個圖相同節點標簽的出現次數不一致時,即可判斷兩個圖不相似。如果上述的步驟重復一定的次數后,沒有發現有相同節點標簽的出現次數不一致的情況,那么我們無法判斷兩個圖是否同構。
當兩個節點的\(h\)層的標簽一樣時,表示分別以這兩個節點為根節點的WL子樹是一致的。WL子樹與普通子樹不同,WL子樹包含重復的節點。下圖展示了一棵以1節點為根節點高為2的WL子樹。
圖相似性評估
此方法來自於Weisfeiler-Lehman Graph Kernels。
WL Test 算法的一點局限性是,它只能判斷兩個圖的相似性,無法衡量圖之間的相似性。要衡量兩個圖的相似性,我們用WL Subtree Kernel方法。該方法的思想是用WL Test算法得到節點的多層的標簽,然后我們可以分別統計圖中各類標簽出現的次數,存於一個向量,這個向量可以作為圖的表征。兩個圖的這樣的向量的內積,即可作為這兩個圖的相似性的估計。

圖同構網絡模型的構建
能實現判斷圖同構性的圖神經網絡需要滿足,只在兩個節點自身標簽一樣且它們的鄰接節點一樣時,圖神經網絡將這兩個節點映射到相同的表征,即映射是單射性的。可重復集合(Multisets)指的是元素可重復的集合,元素在集合中沒有順序關系。 一個節點的所有鄰接節點是一個可重復集合,一個節點可以有重復的鄰接節點,鄰接節點沒有順序關系。因此GIN模型中生成節點表征的方法遵循WL Test算法更新節點標簽的過程。
在生成節點的表征后仍需要執行圖池化(或稱為圖讀出)操作得到圖表征,最簡單的圖讀出操作是做求和。由於每一層的節點表征都可能是重要的,因此在圖同構網絡中,不同層的節點表征在求和后被拼接,其數學定義如下,
采用拼接而不是相加的原因在於不同層節點的表征屬於不同的特征空間。未做嚴格的證明,這樣得到的圖的表示與WL Subtree Kernel得到的圖的表征是等價的。
結語
在此篇文章中,我們學習了基於圖同構網絡(GIN)的圖表征網絡,為了得到圖表征首先需要做節點表征,然后做圖讀出。GIN中節點表征的計算遵循WL Test算法中節點標簽的更新方法,因此它的上界是WL Test算法。在圖讀出中,我們對所有的節點表征(加權,如果用Attention的話)求和,這會造成節點分布信息的丟失。
作業
- 請畫出下方圖片中的6號、3號和5號節點的從1層到3層到WL子樹。
參考資料
-
提出GlobalAttention的論文: “Gated Graph Sequence Neural Networks”
-
提出Set2Set的論文:“Order Matters: Sequence to sequence for sets”
-
PyG中集成的所有的圖池化的方法:Global Pooling Layers
-
Weisfeiler-Lehman Test: Brendan L Douglas. The weisfeiler-lehman method and graph isomorphism testing. arXiv preprint arXiv:1101.5211, 2011.