圖神經網絡-消息傳遞圖神經網絡


消息傳遞圖神經網絡

一、引言

在開篇中我們已經介紹,為節點生成節點表征是圖計算任務成功的關鍵,我們要采用圖神經網絡實現節點表征學習。在此小節,我們將學習基於神經網絡的生成節點表征的范式——消息傳遞范式消息傳遞范式是一種聚合鄰接節點信息來更新中心節點信息的范式,它將卷積算子推廣到了不規則數據領域,實現了圖與神經網絡的連接。此范式包含三個步驟:(1)鄰接節點信息變換、(2)鄰接節點信息聚合到中心節點、(3)聚合信息變換。因其簡單且強大的特性,它廣泛地被人們所采用。此外,我們還將學習如何基於消息傳遞范式構建圖神經網絡

二、 消息傳遞范式介紹

\(\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F\)表示\((k-1)\)層中節點\(i\)的節點特征,\(\mathbf{e}_{j,i} \in \mathbb{R}^D\) 表示從節點\(j\)到節點\(i\)的邊的特征,消息傳遞圖神經網絡可以描述為

\[\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), \]

其中\(\square\)表示可微分的、具有排列不變性(函數輸出結果與輸入參數的排列無關)的函數。具有排列不變性的函數有,和函數、均值函數和最大值函數。\(\gamma\)\(\phi\)表示可微分的函數,如MLPs(多層感知器)。此處內容來源於CREATING MESSAGE PASSING NETWORKS

神經網絡的生成節點表征的操作可稱為節點嵌入(Node Embedding),節點表征也可以稱為節點嵌入。為了統一此次組隊學習中的表述,我們規定節點嵌入只代指神經網絡生成節點表征的操作

下方圖片展示了基於消息傳遞范式的生成節點表征的過程

  1. 在圖的最右側,B節點的鄰接節點(A,C)的信息傳遞給了B,經過信息變換得到了B的嵌入,C、D節點同。
  2. 在圖的中右側,A節點的鄰接節點(B,C,D)的之前得到的節點嵌入傳遞給了節點A;在圖的中左側,聚合得到的信息經過信息變換得到了A節點新的嵌入。
  3. 重復多次,我們可以得到每一個節點的經過多次信息變換的嵌入。這樣的經過多次信息聚合與變換的節點嵌入就可以作為節點的表征,可以用於節點的分類。

節點嵌入(Node Embedding)

三、Pytorch Geometric中的MessagePassing基類

Pytorch Geometric(PyG)提供了MessagePassing基類,它實現了消息傳播的自動處理,繼承該基類可使我們方便地構造消息傳遞圖神經網絡,我們只需定義函數\(\phi\),即message()函數,和函數\(\gamma\),即update()函數,以及使用的消息聚合方案,即aggr="add"aggr="mean"aggr="max"。這些是在以下方法的幫助下完成的:

  • MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
    • aggr:定義要使用的聚合方案("add"、"mean "或 "max");
    • flow:定義消息傳遞的流向("source_to_target "或 "target_to_source");
    • node_dim:定義沿着哪個軸線傳播。
  • MessagePassing.propagate(edge_index, size=None, **kwargs)
    • 開始傳播消息的起始調用。它以edge_index(邊的端點的索引)和flow(消息的流向)以及一些額外的數據為參數。
    • 請注意,propagate()不僅限於在形狀為[N, N]的對稱鄰接矩陣中交換消息,還可以通過傳遞size=(N, M)作為額外參數。例如,在二部圖的形狀為[N, M]的一般稀疏分配矩陣中交換消息。
    • 如果設置size=None,則假定鄰接矩陣是對稱的。
    • 對於有兩個獨立的節點集合和索引集合的二部圖,並且每個集合都持有自己的信息,我們可以傳遞一個元組參數,即x=(x_N, x_M),來標記信息的區分。
  • MessagePassing.message(...)
    • 首先確定要給節點\(i\)傳遞消息的邊的集合,如果flow="source_to_target",則是\((j,i) \in \mathcal{E}\)的邊的集合;
    • 如果flow="target_to_source",則是\((i,j) \in \mathcal{E}\)的邊的集合。
    • 接着為各條邊創建要傳遞給節點\(i\)的消息,即實現\(\phi\)函數。
    • MessagePassing.message(...)函數接受最初傳遞給MessagePassing.propagate(edge_index, size=None, **kwargs)函數的所有參數。
    • 此外,傳遞給propagate()的張量可以被映射到各自的節點\(i\)\(j\)上,只需在變量名后面加上_i_j。我們把\(i\)稱為消息傳遞的目標中心節點,把\(j\)稱為鄰接節點。
  • MessagePassing.aggregate(...)
    • 將從源節點傳遞過來的消息聚合在目標節點上,一般可選的聚合方式有sum, meanmax
  • MessagePassing.message_and_aggregate(...)
    • 在一些場景里,鄰接節點信息變換和鄰接節點信息聚合這兩項操作可以融合在一起,那么我們可以在此函數里定義這兩項操作,從而讓程序運行更加高效。
  • MessagePassing.update(aggr_out, ...):
    • 為每個節點\(i \in \mathcal{V}\)更新節點表征,即實現\(\gamma\)函數。該函數以聚合函數的輸出為第一個參數,並接收所有傳遞給propagate()函數的參數。

以上內容來源於The “MessagePassing” Base Class

四、繼承MessagePassing類的GCNConv

GCNConv的數學定義為

\[\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), \]

其中,相鄰節點的特征首先通過權重矩陣\(\mathbf{\Theta}\)進行轉換,然后按端點的度進行歸一化處理,最后進行加總。這個公式可以分為以下幾個步驟:

  1. 向鄰接矩陣添加自環邊。
  2. 線性轉換節點特征矩陣。
  3. 計算歸一化系數。
  4. 歸一化\(j\)中的節點特征。
  5. 將相鄰節點特征相加("求和 "聚合)。

步驟1-3通常是在消息傳遞發生之前計算的。步驟4-5可以使用MessagePassing基類輕松處理。該層的全部實現如下所示。

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' 表示消息從源節點傳播到目標節點
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

GCNConv繼承了MessagePassing並以"求和"作為領域節點信息聚合方式。該層的所有邏輯都發生在其forward()方法中。在這里,我們首先使用torch_geometric.utils.add_self_loops()函數向我們的邊索引添加自循環邊(步驟1),以及通過調用torch.nn.Linear實例對節點特征進行線性變換(步驟2)。

歸一化系數是由每個節點的節點度得出的,它被轉換為每個邊的節點度。結果被保存在形狀[num_edges,]的張量norm中(步驟3)。

message()函數中,我們需要通過norm對相鄰節點特征x_j進行歸一化處理。這里,x_j包含每條邊的源節點特征,即每個中心節點的鄰接。

這就是創建一個簡單的x傳遞層的全部內容。我們可以把這個層作為深度架構的構建塊。我們可以很方便地初始化和調用它:

conv = GCNConv(16, 32)
x = conv(x, edge_index)

以上內容來源於Implementing the GCN Layer

五、propagate函數

propagate函數源碼:

def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
    r"""開始消息傳播的初始調用。
    Args:
        edge_index (Tensor or SparseTensor): 定義了消息傳播流。
        	當flow="source_to_target"時,節點`edge_index[0]`的信息將被發送到節點`edge_index[1]`,
        	反之當flow="target_to_source"時,節點`edge_index[1]`的信息將被發送到節點`edge_index[0]`
        kwargs: 圖其他屬性或額外的數據。
    """

edge_indexpropagate函數必須的參數。在我們的message函數中希望接受到哪些數據(或圖的屬性或額外的數據),就要在propagate函數的調用中傳遞哪些參數。

六、覆寫message函數

在第四部分例子中,我們覆寫的message函數接收兩個參數x_jnorm,而propagate函數被傳遞三個參數edge_index, x=x, norm=norm。由於xData類的屬性,且message函數接收x_j參數而不是x參數,所以在propagate函數被調用,message函數被執行之前,一項額外的操作被執行,該項操作根據edge_index參數從x中分離出x_j。事實上,在message函數里,當參數是Data類的屬性時,我們可以在參數名后面拼接_i_j來指定要接收源節點的屬性或是目標節點的屬性。類似的,如果我們希望在message函數中額外再接受源節點的度,那么我們做如下的修改(假設節點的度為deg,它是Data對象的屬性):

class GCNConv(MessagePassing):
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)
    def message(self, x_j, norm, d_i):
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j * d_i # 這里不管正確性
        

七、覆寫aggregate函數

我們在前面的例子中增加如下的aggregate函數,通過觀察運行結果我們發現,我們覆寫的aggregate函數被調用,同時在super(GCNConv, self).__init__(aggr='add')中傳遞給aggr參數的值被存儲到了self.aggr屬性中。

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)

    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        

八、覆寫message_and_aggregate函數

在一些例子中,消息傳遞與消息聚合可以融合在一起,這種情況我們通過覆寫message_and_aggregate函數來實現:

from torch_sparse import SparseTensor

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
        # 此處傳的不再是edge_idex,而是SparseTensor類型的Adjancency Matrix
        return self.propagate(adjmat, x=x, norm=norm, d=d)
    
    def message(self, x_j, norm, d_i):
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j * d_i # 這里不管正確性
    
    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
    
    def message_and_aggregate(self, adj_t, x, norm):
        print('`message_and_aggregate` is called')

運行程序后我們可以看到雖然我們同時覆寫了message函數和aggregate函數,然而只有message_and_aggregate函數被執行。

九、覆寫update函數

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')

    def update(self, inputs: Tensor) -> Tensor:
        return inputs

update函數接收聚合的輸出作為第一個參數,並接收傳遞給propagate的任何參數。

十、結語

消息傳遞范式是實現圖神經網絡的一種通用范式。消息傳遞范式遵循“消息傳播->消息聚合->消息更新”這一過程,實現將鄰接節點的信息聚合到中心節點上。在PyG中,MessagePassing是所有基於消息傳遞范式的圖神經網絡的基類。MessagePassing類大大方便了我們圖神經網絡的構建,但由於其高度封裝性,它也向我們隱藏了很多的細節。

通過此篇文章的學習,我們打開了MessagePassing類的黑箱子,介紹了繼承MessagePassing類構造自己的圖神經網絡類的規范。我們再次強調,要掌握如何基於MessagePassing類構建自己的圖神經網絡類,我們不能僅停留於理論理解層面,我們需要通過逐行代碼調試,來觀察代碼運行流程。

作業

  1. 請總結MessagePassing類的運行流程以及繼承MessagePassing類的規范。
  2. 請繼承MessagePassing類來自定義以下的圖神經網絡類,並進行測試:
    1. 第一個類,覆寫message函數,要求該函數接收消息傳遞源節點屬性x、目標節點度d
    2. 第二個類,在第一個類的基礎上,再覆寫aggregate函數,要求不能調用super類的aggregate函數,並且不能直接復制super類的aggregate函數內容。
    3. 第三個類,在第二個類的基礎上,再覆寫update函數,要求對節點信息做一層線性變換。
    4. 第四個類,在第三個類的基礎上,再覆寫message_and_aggregate函數,要求在這一個函數中實現前面message函數和aggregate函數的功能。

參考資料


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM