Graph Fusion GCN (Graph Convolutional Networks)



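Much real-world data is naturally a graph, and graphs are everywhere: social networks, knowledge graphs, protein structures, and so on. This article introduces one branch of GNNs (Graph Neural Networks): GCN (Graph Convolutional Networks).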


PyTorch implementation of GCN

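Although the mathematics behind GCN is somewhat hard to follow, the implementation is very simple. One point worth noting is that the adjacency matrix is usually sparse, so using sparse operations for the matrix multiplication is more efficient. First, the graph convolution layer: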

import torch
import torch.nn as nn


class GraphConvolution(nn.Module):
    """GCN layer"""

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)  # dense product X W
        output = torch.spmm(adj, support)       # adj: sparse normalized adjacency, computes A_hat (X W)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
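The sparse product in forward only changes how the multiplication is carried out. The result should therefore match a dense computation up to floating-point error. Below is a minimal sketch that checks this on a toy graph. It assumes a 4-node path graph and random features; neither comes from the article. It compares the layer's core step A(XW) computed with a sparse adjacency against the same step with a dense one:

```python
import torch

torch.manual_seed(0)

# Hypothetical toy graph: a 4-node path 0-1-2-3 as a dense adjacency matrix
adj_dense = torch.tensor([[0., 1., 0., 0.],
                          [1., 0., 1., 0.],
                          [0., 1., 0., 1.],
                          [0., 0., 1., 0.]])
adj_sparse = adj_dense.to_sparse()   # COO tensor storing only the non-zero entries

x = torch.randn(4, 3)                # 4 nodes, 3 input features
weight = torch.randn(3, 2)           # maps to 2 output features

support = torch.mm(x, weight)                  # X W, as in GraphConvolution.forward
out_sparse = torch.spmm(adj_sparse, support)   # sparse x dense product
out_dense = torch.mm(adj_dense, support)       # same product with a dense matrix

print(adj_sparse.coalesce().values().numel())  # 6 stored entries instead of 16
print(torch.allclose(out_sparse, out_dense))   # True
```

On a real citation graph, the gap between the stored entries and n² is much larger than in this toy case. That gap is where the efficiency mentioned above comes from.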
To build the GCN itself, we only need to stack graph convolution layers. Here is a two-layer GCN:
import torch.nn as nn
import torch.nn.functional as F


class GCN(nn.Module):
    """a simple two layer GCN (uses the GraphConvolution layer defined above)"""
    def __init__(self, nfeat, nhid, nclass):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)

    def forward(self, input, adj):
        h1 = F.relu(self.gc1(input, adj))
        logits = self.gc2(h1, adj)
        return logits

The activation function here is ReLU. Next, we use this network for a semi-supervised node classification task on a graph.
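Written as a formula, this two-layer model matches the form given in the GCN paper by Kipf and Welling. There, A is the adjacency matrix, X the node features, and W^(0), W^(1) correspond to the weights of gc1 and gc2. The layer biases are omitted, as in the paper. The softmax does not appear in forward because nn.CrossEntropyLoss applies it to the logits internally. The normalization below is the paper's symmetric form; the official implementation uses row normalization instead, as discussed further on.

```latex
\tilde{A} = A + I_N, \qquad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij}, \qquad
\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}

Z = \operatorname{softmax}\!\left( \hat{A} \, \operatorname{ReLU}\!\left( \hat{A} X W^{(0)} \right) W^{(1)} \right)
```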

Loading the data only takes a call to load_data from the official implementation:

# https://github.com/tkipf/pygcn/blob/master/pygcn/utils.py
adj, features, labels, idx_train, idx_val, idx_test = load_data(path="./data/cora/")

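Two points are worth noting. First, paper citations naturally form a directed graph. Before training, however, the graph should be converted into an undirected one, so that every citation is treated as a link in both directions. This has a large effect on the training results: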

# build symmetric adjacency matrix
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
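Where A^T > A, the expression above adds A^T and subtracts A, which replaces A_ij by A_ji. Everywhere else it keeps A_ij. In other words, it takes the element-wise maximum of A and A^T. Below is a minimal check with SciPy on a hypothetical 3-node directed graph in which paper 0 cites paper 1 and paper 1 cites paper 2:

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical directed citation graph: 0 -> 1 and 1 -> 2
adj = sp.coo_matrix((np.ones(2), ([0, 1], [1, 2])), shape=(3, 3), dtype=np.float32)

# Where A^T > A, add A^T and subtract A, i.e. replace A_ij by A_ji
sym = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

print(sym.toarray())
# [[0. 1. 0.]
#  [1. 0. 1.]
#  [0. 1. 0.]]
print(np.array_equal(sym.toarray(), adj.T.maximum(adj).toarray()))  # True
```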

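Second, the official implementation normalizes the adjacency matrix with plain row (mean) normalization, D^-1 (A+I). The symmetric normalization can be used instead: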

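import numpy as np
import scipy.sparse as sp


def normalize_adj(adj):
    """compute D^-0.5 * (A+I) * D^-0.5, where D is the degree matrix of A+I"""
    adj = adj + sp.eye(adj.shape[0])   # add self-loops without modifying the caller's matrix
    degree = np.array(adj.sum(1))
    d_hat = sp.diags(np.power(degree, -0.5).flatten())
    norm_adj = d_hat.dot(adj).dot(d_hat)
    return norm_adj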
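The two normalizations can be compared side by side on a hypothetical 3-node path graph. The sketch implements row normalization D^-1 (A+I), the "mean" variant used by the official implementation, next to the symmetric one. Each row of the first sums to 1, while the second stays symmetric. The last lines show one possible way to convert the SciPy result into the torch sparse tensor that torch.spmm in GraphConvolution expects. That conversion step is an assumption for illustration, not code from the article:

```python
import numpy as np
import scipy.sparse as sp
import torch

# Hypothetical undirected path graph 0-1-2
adj = sp.coo_matrix(np.array([[0, 1, 0],
                              [1, 0, 1],
                              [0, 1, 0]], dtype=np.float32))
a_tilde = adj + sp.eye(adj.shape[0], dtype=np.float32)  # add self-loops: A + I
degree = np.asarray(a_tilde.sum(1)).flatten()           # degrees of A + I: [2, 3, 2]

# Row ("mean") normalization: D^-1 (A+I), each node averages over its neighbours
row_norm = sp.diags(1.0 / degree).dot(a_tilde)

# Symmetric normalization: D^-0.5 (A+I) D^-0.5
d_inv_sqrt = sp.diags(np.power(degree, -0.5))
sym_norm = d_inv_sqrt.dot(a_tilde).dot(d_inv_sqrt)

print(np.asarray(row_norm.sum(1)).flatten())                 # every row sums to 1
print(np.allclose(sym_norm.toarray(), sym_norm.toarray().T))  # True: still symmetric

# Convert the SciPy matrix to a torch COO tensor for torch.spmm
coo = sym_norm.tocoo().astype(np.float32)
indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
values = torch.from_numpy(coo.data)
adj_t = torch.sparse_coo_tensor(indices, values, coo.shape)
```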

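Only the 140 labeled nodes of the graph are used to train the GCN. In each epoch the model computes outputs for all nodes in one full-batch pass. The loss is then taken over the labeled training nodes only: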

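    import torch

    # model, features, adj, labels and idx_* come from the steps above;
    # hidden size, lr and weight_decay are the defaults of tkipf/pygcn
    model = GCN(nfeat=features.shape[1], nhid=16, nclass=int(labels.max()) + 1)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    epochs = 200

    def accuracy(logits, labels):
        return (logits.argmax(dim=1) == labels).float().mean()

    @torch.no_grad()
    def test(idx):
        model.eval()
        logits = model(features, adj)
        return accuracy(logits[idx], labels[idx])

    loss_history = []
    val_acc_history = []
    for epoch in range(epochs):
        model.train()
        logits = model(features, adj)
        loss = criterion(logits[idx_train], labels[idx_train])
        train_acc = accuracy(logits[idx_train], labels[idx_train])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        val_acc = test(idx_val)
        loss_history.append(loss.item())
        val_acc_history.append(val_acc.item())
        print("Epoch {:03d}: Loss {:.4f}, TrainAcc {:.4f}, ValAcc {:.4f}".format(
            epoch, loss.item(), train_acc.item(), val_acc.item()))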
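The loop above needs the Cora files. Below is a self-contained sketch of the same full-batch, semi-supervised scheme on synthetic data instead. Everything in it is an assumption made for illustration: two hypothetical 10-node communities, weakly class-dependent random features, and only 4 labeled nodes. It also swaps in a dense two-layer GCN written with plain torch matrix products. Only lr=0.01 and weight_decay=5e-4 follow the pygcn defaults:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Synthetic graph: nodes 0-9 form community 0, nodes 10-19 form community 1
n = 20
labels = torch.tensor([0] * 10 + [1] * 10)
same = labels.unsqueeze(0) == labels.unsqueeze(1)            # same-community pairs
adj = (same & (torch.rand(n, n) < 0.5)).float()              # random intra-community edges
adj = torch.maximum(adj, adj.t())                            # make the graph undirected
adj.fill_diagonal_(1.0)                                      # self-loops: A + I
deg_inv_sqrt = adj.sum(1).pow(-0.5)
a_hat = deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]  # D^-1/2 (A+I) D^-1/2

features = torch.randn(n, 8) + 0.5 * labels.float().unsqueeze(1)
idx_train = torch.tensor([0, 1, 10, 11])                     # only 4 labeled nodes
test_mask = torch.ones(n, dtype=torch.bool)
test_mask[idx_train] = False                                 # evaluate on the other 16


class DenseGCN(nn.Module):
    """Toy two-layer GCN using a dense normalized adjacency, no bias."""
    def __init__(self, nfeat, nhid, nclass):
        super().__init__()
        self.w1 = nn.Linear(nfeat, nhid, bias=False)
        self.w2 = nn.Linear(nhid, nclass, bias=False)

    def forward(self, x, a_hat):
        h = F.relu(a_hat @ self.w1(x))   # first graph convolution: A_hat X W0
        return a_hat @ self.w2(h)        # second graph convolution returns logits


model = DenseGCN(8, 16, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

losses = []
for epoch in range(200):
    model.train()
    logits = model(features, a_hat)                   # outputs for all nodes
    loss = criterion(logits[idx_train], labels[idx_train])  # loss on labeled nodes only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

model.eval()
with torch.no_grad():
    pred = model(features, a_hat).argmax(dim=1)
test_acc = (pred[test_mask] == labels[test_mask]).float().mean().item()
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}, test acc {test_acc:.2f}")
```

The printed accuracy depends on the random graph and features, so it says nothing about Cora. The sketch only shows that labels on a few nodes are enough to drive training, while predictions are made for every node.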

After only 200 epochs of training, the model reaches about 80% classification accuracy on the test set, which shows how effective GCN is.


Fusing BN and Conv layers

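At inference time, a BatchNorm layer that follows a convolution can be merged into that convolution. Implementing this fusion in PyTorch uses the following parameters.

nn.Conv2d parameters:

  • filter weights, W: conv.weight;
  • bias, b: conv.bias;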

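nn.BatchNorm2d parameters:

  • mean, μ: bn.running_mean;
  • variance, σ²: bn.running_var;
  • epsilon (for numerical stability), ε: bn.eps;
  • scale, γ: bn.weight;
  • shift, β: bn.bias;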
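In eval mode, BatchNorm uses its running statistics and therefore acts as a fixed per-channel affine map. This is why it can be folded into the preceding convolution. Applying it channel by channel to the convolution output W * x + b, with the symbols listed above, gives the fused weight and bias. Here W is flattened to shape (out_channels, -1) before the diagonal scaling:

```latex
\hat{y} = \gamma \, \frac{(W * x + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

W_{\text{fused}} = \operatorname{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right) W,
\qquad
b_{\text{fused}} = \operatorname{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right) b
  + \beta - \frac{\gamma \, \mu}{\sqrt{\sigma^2 + \epsilon}}
```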

The full implementation is as follows (Google Colab, https://colab.research.google.com/drive/1mRyq_LlJW4u_rArzzhEe_T6tmEWoNN1K):

    import torch
    import torchvision

    def fuse(conv, bn):
        # new convolution with the same geometry, always with a bias
        fused = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True
        )

        # setting weights: scale each output channel by gamma / sqrt(var + eps)
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fused.weight.copy_(torch.mm(w_bn, w_conv).view(fused.weight.size()))

        # setting bias: the conv bias must be scaled by the same factor
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
            torch.sqrt(bn.running_var + bn.eps)
        )
        fused.bias.copy_(torch.mv(w_bn, b_conv) + b_bn)

        return fused

    # Testing
    # fuse() writes into parameters with copy_(), so gradient tracking must be off
    torch.set_grad_enabled(False)
    x = torch.randn(16, 3, 256, 256)
    # torchvision >= 0.13 takes `weights=`; older versions use pretrained=True
    resnet18 = torchvision.models.resnet18(weights="IMAGENET1K_V1")
    # eval(): BN now uses its running statistics instead of batch statistics
    resnet18.eval()
    model = torch.nn.Sequential(
        resnet18.conv1,
        resnet18.bn1
    )
    f1 = model.forward(x)
    fused = fuse(model[0], model[1])
    f2 = fused.forward(x)
    d = (f1 - f2).abs().max().item()   # max absolute error; a plain mean could hide errors that cancel
    print("error:", d)
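PyTorch also ships a helper for this step, torch.nn.utils.fusion.fuse_conv_bn_eval. The ResNet-18 test above never exercises the convolution-bias path, because conv1 is built with bias=False. The sketch below therefore fuses a hypothetical 3-to-8-channel convolution that does have a bias. It gives the BN layer random running statistics and affine parameters so the check is not trivial, then compares the fused layer against Conv followed by BN:

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)

# Hypothetical conv layer with its own bias, followed by BN
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)

# Non-trivial running statistics and affine parameters for BN
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

conv.eval()
bn.eval()   # fuse_conv_bn_eval only accepts modules in eval mode

fused = fuse_conv_bn_eval(conv, bn)   # returns a new Conv2d; conv and bn are unchanged

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))
    out = fused(x)
err = (reference - out).abs().max().item()
print("max abs error:", err)   # expected to be on the order of float32 rounding
```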


References:

    1. Semi-Supervised Classification with Graph Convolutional Networks https://arxiv.org/abs/1609.02907
    2. How to do Deep Learning on Graphs with Graph Convolutional Networks https://towardsdatascience.com/how-to-do-deep-learning-on-graphs-with-graph-convolutional-networks-7d2250723780
    3. Graph Convolutional Networks http://tkipf.github.io/graph-convolutional-networks
    4. Graph Convolutional Networks in PyTorch https://github.com/tkipf/pygcn
    5. A review of classic spectral graph convolution work: from ChebNet to GCN (in Chinese) https://www.jianshu.com/p/2fd5a2454781
    6. An introduction to the Cora graph dataset, processed with Python for GCN tasks (in Chinese) https://blog.csdn.net/yeziand01/article/details/93374216
    7. Speeding up model with fusing batch normalization and convolution http://learnml.today/speeding-up-model-with-fusing-batch-normalization-and-convolution-3

