GCN可以認為由兩步組成:
對於每個節點 $u$：
1) 匯總其鄰居 $v$ 的表示 $h_v$，產生中間表示 $\hat h_u$（本例中採用求和，即 $\hat h_u = \sum_{v \in \mathcal{N}(u)} h_v$）；
2) 使用 $W_u$ 對 $\hat h_u$ 做線性投影，再經過非線性變換 $f$，即 $h_u = f(W_u \hat h_u)$。
首先定義message函數和reduce函數。
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

## Define the message function and the reduce function
# Message function: copy each source node's 'h' feature onto its outgoing
# edges as the message 'm'. ``fn.copy_u`` is the replacement for the
# deprecated ``fn.copy_src`` (removed in DGL 1.0); the semantics are the same.
gcn_msg = fn.copy_u(u='h', out='m')
# Reduce function: sum all incoming 'm' messages at each destination node
# and store the result in that node's 'h' field.
gcn_reduce = fn.sum(msg='m', out='h')
接著定義 GCN 層（GCNLayer）與由兩層組成的 GCN 模型（Net）。
## Define GCNLayer
class GCNLayer(nn.Module):
    """One graph-convolution layer.

    Aggregates neighbor features with the module-level ``gcn_msg`` /
    ``gcn_reduce`` message-passing functions, then applies a linear
    projection.

    Args:
        in_feats: Size of each input node feature vector.
        out_feats: Size of each output node feature vector.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        """Propagate ``feature`` over ``g`` and project the aggregated result.

        Args:
            g: The DGL graph to run message passing on.
            feature: Node feature tensor; assumed to be shaped
                (num_nodes, in_feats) — TODO confirm against callers.

        Returns:
            The aggregated node features passed through ``self.linear``.
        """
        # Creating a local scope so that all the stored ndata and edata
        # (such as the `'h'` ndata below) are automatically popped out
        # when the scope exits.
        with g.local_scope():
            g.ndata['h'] = feature
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata['h']
            return self.linear(h)


class Net(nn.Module):
    """Two-layer GCN: GCNLayer -> ReLU -> GCNLayer.

    The defaults reproduce the original fixed 1433 -> 16 -> 7 configuration
    (presumably the Cora citation dataset — confirm); pass other sizes to
    reuse the model on different data.

    Args:
        in_feats: Input node feature size.
        hidden_feats: Output size of the first layer.
        out_feats: Output size of the second (final) layer.
    """

    def __init__(self, in_feats=1433, hidden_feats=16, out_feats=7):
        super().__init__()
        self.layer1 = GCNLayer(in_feats, hidden_feats)
        self.layer2 = GCNLayer(hidden_feats, out_feats)

    def forward(self, g, features):
        """Apply both layers, with a ReLU after the first.

        Returns the second layer's output with no activation applied.
        """
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)
        return x


net = Net()
print(net)