DGL學習(三): 消息傳遞教程


在本節中,我們將不同級別的消息傳遞API與PageRank一起使用。 在DGL中,消息傳遞和功能轉換是用戶定義的函數(UDF)。

 

PageRank 算法:

在PageRank的每次迭代中,每個節點(網頁)首先將其PageRank值均勻地分散到其下游節點。 每個節點的新PageRank值是通過匯總從其鄰居收到的PageRank值來計算的,然后通過阻尼因子(damping factor)進行調整:

 生成一個隨機圖, 兩點之間有邊的概率為 P:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl

N = 100
P = 0.1
DAMP = 0.8
g = nx.erdos_renyi_graph(N, P) g = dgl.DGLGraph(g)
src = list(range(1,51));dst = [0]*50 # 使用list批量添加
g.add_edges(src, dst)
print(g.number_of_edges()) print(g.number_of_nodes()) nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])
plt.show() 

 

 

 

在pagerank 中, 初始化每個節點初始值為 1/N, 將節點的出度作為節點的特征。

## pv 算法初始值
g.ndata['pv'] = torch.ones(N) / N
g.ndata['deg'] = g.out_degrees(g.nodes()).float()

定義消息函數,該函數將每個節點的PageRank值除以其出度,然后將結果作為消息傳遞給其鄰居。

在DGL中,消息函數是針對邊的,表示為Edge UDF。 Edge UDF接受單個參數edges。 它具有三個成員src,dst和data,用於訪問源節點特征,目標節點特征和邊特征。實現pv算法僅需從src中取特征。

def pagerank_message_func(edges):
    return {'pv': edges.src['pv'] / edges.src['deg']}

定義reduce函數,該函數從其mailbox中聚合消息和刪除消息,並計算其新的PageRank值。

reduce函數是針對節點的,表示為 Node UDF。 Node UDF接受單個參數nodes,nodes具有兩個成員mailbox和data。 data包含節點特征,mailbox包含所有傳入消息特征,這些功能沿第二維堆疊(dim = 1參數)。

可以結合下圖進行理解:

 

 

def pagerank_reduce_func(nodes):
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {'pv' : pv}

注冊消息函數和規約函數, 之后DGL調用它。 pagerank_naive是page_rank的簡單實現。

# 注冊消息函數和歸約函數,稍后DGL將調用它。
g.register_message_func(pagerank_message_func)
g.register_reduce_func(pagerank_reduce_func)

def pagerank_naive(g):
    # Phase #1: send out messages along all edges.
    for u, v in zip(*g.edges()):
        g.send((u, v))
    # Phase #2: receive messages to compute new PageRank values.
    for v in g.nodes():
        g.recv(v)

# 迭代10輪
for k in range(10):
    pagerank_naive(g)

print(g.ndata['pv'])
tensor([0.0446, 0.0107, 0.0087, 0.0102, 0.0085, 0.0130, 0.0091, 0.0059, 0.0079,
        0.0088, 0.0082, 0.0087, 0.0098, 0.0087, 0.0100, 0.0092, 0.0065, 0.0168,
        0.0064, 0.0106, 0.0098, 0.0117, 0.0077, 0.0113, 0.0111, 0.0100, 0.0077,
        0.0051, 0.0084, 0.0070, 0.0048, 0.0163, 0.0102, 0.0084, 0.0098, 0.0127,
        0.0101, 0.0091, 0.0091, 0.0083, 0.0088, 0.0095, 0.0132, 0.0106, 0.0057,
        0.0099, 0.0068, 0.0106, 0.0098, 0.0068, 0.0140, 0.0087, 0.0083, 0.0120,
        0.0107, 0.0109, 0.0072, 0.0090, 0.0069, 0.0124, 0.0094, 0.0106, 0.0071,
        0.0093, 0.0070, 0.0059, 0.0068, 0.0162, 0.0082, 0.0129, 0.0063, 0.0134,
        0.0116, 0.0095, 0.0107, 0.0147, 0.0085, 0.0099, 0.0084, 0.0069, 0.0112,
        0.0120, 0.0076, 0.0105, 0.0125, 0.0091, 0.0063, 0.0085, 0.0051, 0.0102,
        0.0116, 0.0070, 0.0120, 0.0094, 0.0156, 0.0159, 0.0096, 0.0125, 0.0065,
        0.0107])
View Code

 

大圖的批處理語義

上圖中的方法需要遍歷所有節點,不適合於大圖,DGL通過允許在一個batch的節點或邊上進行計算來解決此問題。 例如,以下代碼一次性觸發所有多個節點的消息函數和規約函數。

def pagerank_batch(g):
    g.send(g.edges())
    g.recv(g.nodes())
for k in range(10):
    #pagerank_naive(g)
    pagerank_batch(g)
print(g.ndata['pv'])

並行性方面:  由於每個節點接受的輸出參數是不同的,不同長度的張量沒法進行stack。所以DGL按傳入消息的數量對節點進行分組,分組調用reduce函數來解決該問題。

 

使用更高級別的API來提高效率

def pagerank_level2(g):
    g.update_all()

 

使用內置API

一些常用的消息函數和規約函數DGL都包含了,直接調用即可。

import dgl.function as fn

def pagerank_builtin(g):
    g.ndata['pv'] = g.ndata['pv'] / g.ndata['deg']
    g.update_all(message_func=fn.copy_src(src='pv', out='m'),
                 reduce_func=fn.sum(msg='m',out='m_sum'))
    g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['m_sum']

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM