消息傳遞圖神經網絡
一、引言
在開篇中我們已經介紹,為節點生成節點表征是圖計算任務成功的關鍵,我們要采用圖神經網絡實現節點表征學習。在此小節,我們將學習基於神經網絡的生成節點表征的范式——消息傳遞范式。消息傳遞范式是一種聚合鄰接節點信息來更新中心節點信息的范式,它將卷積算子推廣到了不規則數據領域,實現了圖與神經網絡的連接。此范式包含三個步驟:(1)鄰接節點信息變換、(2)鄰接節點信息聚合到中心節點、(3)聚合信息變換。因其簡單且強大的特性,它廣泛地被人們所采用。此外,我們還將學習如何基於消息傳遞范式構建圖神經網絡。
二、 消息傳遞范式介紹
用\(\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F\)表示\((k-1)\)層中節點\(i\)的節點特征,\(\mathbf{e}_{j,i} \in \mathbb{R}^D\) 表示從節點\(j\)到節點\(i\)的邊的特征,消息傳遞圖神經網絡可以描述為
其中\(\square\)表示可微分的、具有排列不變性(函數輸出結果與輸入參數的排列無關)的函數。具有排列不變性的函數有,和函數、均值函數和最大值函數。\(\gamma\)和\(\phi\)表示可微分的函數,如MLPs(多層感知器)。此處內容來源於CREATING MESSAGE PASSING NETWORKS。
神經網絡的生成節點表征的操作可稱為節點嵌入(Node Embedding),節點表征也可以稱為節點嵌入。為了統一此次組隊學習中的表述,我們規定節點嵌入只代指神經網絡生成節點表征的操作。
下方圖片展示了基於消息傳遞范式的生成節點表征的過程:
- 在圖的最右側,B節點的鄰接節點(A,C)的信息傳遞給了B,經過信息變換得到了B的嵌入,C、D節點同。
- 在圖的中右側,A節點的鄰接節點(B,C,D)的之前得到的節點嵌入傳遞給了節點A;在圖的中左側,聚合得到的信息經過信息變換得到了A節點新的嵌入。
- 重復多次,我們可以得到每一個節點的經過多次信息變換的嵌入。這樣的經過多次信息聚合與變換的節點嵌入就可以作為節點的表征,可以用於節點的分類。
三、Pytorch Geometric中的MessagePassing
基類
Pytorch Geometric(PyG)提供了MessagePassing
基類,它實現了消息傳播的自動處理,繼承該基類可使我們方便地構造消息傳遞圖神經網絡,我們只需定義函數\(\phi\),即message()
函數,和函數\(\gamma\),即update()
函數,以及使用的消息聚合方案,即aggr="add"
、aggr="mean"
或aggr="max"
。這些是在以下方法的幫助下完成的:
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
:aggr
:定義要使用的聚合方案("add"、"mean "或 "max");flow
:定義消息傳遞的流向("source_to_target "或 "target_to_source");node_dim
:定義沿着哪個軸線傳播。
MessagePassing.propagate(edge_index, size=None, **kwargs)
:- 開始傳播消息的起始調用。它以
edge_index
(邊的端點的索引)和flow
(消息的流向)以及一些額外的數據為參數。 - 請注意,
propagate()
不僅限於在形狀為[N, N]
的對稱鄰接矩陣中交換消息,還可以通過傳遞size=(N, M)
作為額外參數。例如,在二部圖的形狀為[N, M]
的一般稀疏分配矩陣中交換消息。 - 如果設置
size=None
,則假定鄰接矩陣是對稱的。 - 對於有兩個獨立的節點集合和索引集合的二部圖,並且每個集合都持有自己的信息,我們可以傳遞一個元組參數,即
x=(x_N, x_M)
,來標記信息的區分。
- 開始傳播消息的起始調用。它以
MessagePassing.message(...)
:- 首先確定要給節點\(i\)傳遞消息的邊的集合,如果
flow="source_to_target"
,則是\((j,i) \in \mathcal{E}\)的邊的集合; - 如果
flow="target_to_source"
,則是\((i,j) \in \mathcal{E}\)的邊的集合。 - 接着為各條邊創建要傳遞給節點\(i\)的消息,即實現\(\phi\)函數。
MessagePassing.message(...)
函數接受最初傳遞給MessagePassing.propagate(edge_index, size=None, **kwargs)
函數的所有參數。- 此外,傳遞給
propagate()
的張量可以被映射到各自的節點\(i\)和\(j\)上,只需在變量名后面加上_i
或_j
。我們把\(i\)稱為消息傳遞的目標中心節點,把\(j\)稱為鄰接節點。
- 首先確定要給節點\(i\)傳遞消息的邊的集合,如果
MessagePassing.aggregate(...)
:- 將從源節點傳遞過來的消息聚合在目標節點上,一般可選的聚合方式有
sum
,mean
和max
。
- 將從源節點傳遞過來的消息聚合在目標節點上,一般可選的聚合方式有
MessagePassing.message_and_aggregate(...)
:- 在一些場景里,鄰接節點信息變換和鄰接節點信息聚合這兩項操作可以融合在一起,那么我們可以在此函數里定義這兩項操作,從而讓程序運行更加高效。
MessagePassing.update(aggr_out, ...)
:- 為每個節點\(i \in \mathcal{V}\)更新節點表征,即實現\(\gamma\)函數。該函數以聚合函數的輸出為第一個參數,並接收所有傳遞給
propagate()
函數的參數。
- 為每個節點\(i \in \mathcal{V}\)更新節點表征,即實現\(\gamma\)函數。該函數以聚合函數的輸出為第一個參數,並接收所有傳遞給
以上內容來源於The “MessagePassing” Base Class。
四、繼承MessagePassing
類的GCNConv
GCNConv的數學定義為
其中,相鄰節點的特征首先通過權重矩陣\(\mathbf{\Theta}\)進行轉換,然后按端點的度進行歸一化處理,最后進行加總。這個公式可以分為以下幾個步驟:
- 向鄰接矩陣添加自環邊。
- 線性轉換節點特征矩陣。
- 計算歸一化系數。
- 歸一化\(j\)中的節點特征。
- 將相鄰節點特征相加("求和 "聚合)。
步驟1-3通常是在消息傳遞發生之前計算的。步驟4-5可以使用MessagePassing
基類輕松處理。該層的全部實現如下所示。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息從源節點傳播到目標節點
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
GCNConv
繼承了MessagePassing
並以"求和"作為領域節點信息聚合方式。該層的所有邏輯都發生在其forward()
方法中。在這里,我們首先使用torch_geometric.utils.add_self_loops()
函數向我們的邊索引添加自循環邊(步驟1),以及通過調用torch.nn.Linear
實例對節點特征進行線性變換(步驟2)。
歸一化系數是由每個節點的節點度得出的,它被轉換為每個邊的節點度。結果被保存在形狀[num_edges,]
的張量norm
中(步驟3)。
在message()
函數中,我們需要通過norm
對相鄰節點特征x_j
進行歸一化處理。這里,x_j
包含每條邊的源節點特征,即每個中心節點的鄰接。
這就是創建一個簡單的x傳遞層的全部內容。我們可以把這個層作為深度架構的構建塊。我們可以很方便地初始化和調用它:
conv = GCNConv(16, 32)
x = conv(x, edge_index)
以上內容來源於Implementing the GCN Layer。
五、propagate
函數
propagate
函數源碼:
def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
r"""開始消息傳播的初始調用。
Args:
edge_index (Tensor or SparseTensor): 定義了消息傳播流。
當flow="source_to_target"時,節點`edge_index[0]`的信息將被發送到節點`edge_index[1]`,
反之當flow="target_to_source"時,節點`edge_index[1]`的信息將被發送到節點`edge_index[0]`
kwargs: 圖其他屬性或額外的數據。
"""
edge_index
是propagate
函數必須的參數。在我們的message
函數中希望接受到哪些數據(或圖的屬性或額外的數據),就要在propagate
函數的調用中傳遞哪些參數。
六、覆寫message
函數
在第四部分例子中,我們覆寫的message
函數接收兩個參數x_j
和norm
,而propagate
函數被傳遞三個參數edge_index, x=x, norm=norm
。由於x
是Data
類的屬性,且message
函數接收x_j
參數而不是x
參數,所以在propagate
函數被調用,message
函數被執行之前,一項額外的操作被執行,該項操作根據edge_index
參數從x
中分離出x_j
。事實上,在message
函數里,當參數是Data
類的屬性時,我們可以在參數名后面拼接_i
或_j
來指定要接收源節點的屬性或是目標節點的屬性。類似的,如果我們希望在message
函數中額外再接受源節點的度,那么我們做如下的修改(假設節點的度為deg
,它是Data
對象的屬性):
class GCNConv(MessagePassing):
def forward(self, x, edge_index):
# ....
return self.propagate(edge_index, x=x, norm=norm, d=d)
def message(self, x_j, norm, d_i):
# x_j has shape [E, out_channels]
return norm.view(-1, 1) * x_j * d_i # 這里不管正確性
七、覆寫aggregate
函數
我們在前面的例子中增加如下的aggregate
函數,通過觀察運行結果我們發現,我們覆寫的aggregate
函數被調用,同時在super(GCNConv, self).__init__(aggr='add')
中傳遞給aggr
參數的值被存儲到了self.aggr
屬性中。
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
def forward(self, x, edge_index):
# ....
return self.propagate(edge_index, x=x, norm=norm, d=d)
def aggregate(self, inputs, index, ptr, dim_size):
print(self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
八、覆寫message_and_aggregate
函數
在一些例子中,消息傳遞與消息聚合可以融合在一起,這種情況我們通過覆寫message_and_aggregate
函數來實現:
from torch_sparse import SparseTensor
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
def forward(self, x, edge_index):
# ....
adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
# 此處傳的不再是edge_idex,而是SparseTensor類型的Adjancency Matrix
return self.propagate(adjmat, x=x, norm=norm, d=d)
def message(self, x_j, norm, d_i):
# x_j has shape [E, out_channels]
return norm.view(-1, 1) * x_j * d_i # 這里不管正確性
def aggregate(self, inputs, index, ptr, dim_size):
print(self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
def message_and_aggregate(self, adj_t, x, norm):
print('`message_and_aggregate` is called')
運行程序后我們可以看到雖然我們同時覆寫了message
函數和aggregate
函數,然而只有message_and_aggregate
函數被執行。
九、覆寫update
函數
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
def update(self, inputs: Tensor) -> Tensor:
return inputs
update
函數接收聚合的輸出作為第一個參數,並接收傳遞給propagate
的任何參數。
十、結語
消息傳遞范式是實現圖神經網絡的一種通用范式。消息傳遞范式遵循“消息傳播->消息聚合->消息更新”這一過程,實現將鄰接節點的信息聚合到中心節點上。在PyG中,MessagePassing
是所有基於消息傳遞范式的圖神經網絡的基類。MessagePassing
類大大方便了我們圖神經網絡的構建,但由於其高度封裝性,它也向我們隱藏了很多的細節。
通過此篇文章的學習,我們打開了MessagePassing
類的黑箱子,介紹了繼承MessagePassing
類構造自己的圖神經網絡類的規范。我們再次強調,要掌握如何基於MessagePassing
類構建自己的圖神經網絡類,我們不能僅停留於理論理解層面,我們需要通過逐行代碼調試,來觀察代碼運行流程。
作業
- 請總結
MessagePassing
類的運行流程以及繼承MessagePassing
類的規范。 - 請繼承
MessagePassing
類來自定義以下的圖神經網絡類,並進行測試:- 第一個類,覆寫
message
函數,要求該函數接收消息傳遞源節點屬性x
、目標節點度d
。 - 第二個類,在第一個類的基礎上,再覆寫
aggregate
函數,要求不能調用super
類的aggregate
函數,並且不能直接復制super
類的aggregate
函數內容。 - 第三個類,在第二個類的基礎上,再覆寫
update
函數,要求對節點信息做一層線性變換。 - 第四個類,在第三個類的基礎上,再覆寫
message_and_aggregate
函數,要求在這一個函數中實現前面message
函數和aggregate
函數的功能。
- 第一個類,覆寫