Graph Convolutional Networks (GCN): Whole-Graph Classification (with Example and Code)


There is a good Zhihu article on whole-graph classification: 【圖分類】10分鍾就學會的圖分類教程,基於pytorch和dgl (a graph-classification tutorial based on PyTorch and DGL). The code below is adapted from that article.

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import MiniGCDataset
from dgl.nn.pytorch import GraphConv
from sklearn.metrics import accuracy_score


class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)  # first graph-convolution layer
        self.conv2 = GraphConv(hidden_dim, hidden_dim)  # second graph-convolution layer
        self.classify = nn.Linear(hidden_dim, n_classes)   # linear classifier on the graph embedding

    def forward(self, g):
        """g表示批處理后的大圖,N表示大圖的所有節點數量,n表示圖的數量 
        """
        # 為方便,我們用節點的度作為初始節點特征。對於無向圖,入度 = 出度
        h = g.in_degrees().view(-1, 1).float() # [N, 1]
        # 執行圖卷積和激活函數
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        g.ndata['h'] = h    # 將特征賦予到圖的節點
        # 通過平均池化每個節點的表示得到圖表示
        hg = dgl.mean_nodes(g, 'h')   # [n, hidden_dim]
        return self.classify(hg)  # [n, n_classes]
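
# Note: mean pooling is only one choice of readout. DGL also provides sum
# and max readouts with the same signature (a sketch, not used in this post):
#   hg = dgl.sum_nodes(g, 'h')   # sum-pool node features per graph
#   hg = dgl.max_nodes(g, 'h')   # max-pool node features per graph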

def collate(samples):
    # `samples` is a list of (graph, label) pairs,
    # e.g. [(graph1, label1), (graph2, label2), ...]
    # zip(*samples) unzips it into [(graph1, graph2, ...), (label1, label2, ...)]
    graphs, labels = map(list, zip(*samples))
    # dgl.batch merges the graphs into one large graph whose
    # mutually disconnected components are the individual graphs
    return dgl.batch(graphs), torch.tensor(labels, dtype=torch.long)
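
# A quick sanity check of what dgl.batch produces (a sketch assuming a DGL
# version >= 0.5, where the dgl.graph constructor is available; the two toy
# graphs are illustrative and not part of the original post):
g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
g2 = dgl.graph(([0], [1]), num_nodes=2)
bg = dgl.batch([g1, g2])
print(bg.batch_size)         # 2 graphs in the batch
print(bg.batch_num_nodes())  # tensor([3, 2])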


# Create the training and test sets
trainset = MiniGCDataset(2000, 10, 20)  # 2000 graphs, each with between 10 and 20 nodes
testset = MiniGCDataset(1000, 10, 20)
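
# Each dataset item is a (graph, label) pair; MiniGCDataset synthesizes 8
# graph types (cycle, star, wheel, lollipop, hypercube, grid, clique,
# circular ladder), so trainset.num_classes == 8. A quick look at one sample:
g0, label0 = trainset[0]
print(g0.number_of_nodes(), label0)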

# Use PyTorch's DataLoader together with the collate function defined above
data_loader = DataLoader(trainset, batch_size=64, shuffle=True,
                         collate_fn=collate)
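
# Note: recent DGL releases (>= 0.7) also ship dgl.dataloading.GraphDataLoader,
# which applies dgl.batch internally, so the custom collate above becomes
# unnecessary (a sketch, not used in this post):
#   from dgl.dataloading import GraphDataLoader
#   data_loader = GraphDataLoader(trainset, batch_size=64, shuffle=True)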

# Pick a device; fall back to the CPU when no GPU is available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the model (in_dim=1: the in-degree is the only input node feature)
model = Classifier(1, 256, trainset.num_classes)
model.to(DEVICE)

# Cross-entropy loss for classification
loss_func = nn.CrossEntropyLoss()
# Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
model.train()
epoch_losses = []
for epoch in range(100): 
    epoch_loss = 0
    for it, (batchg, label) in enumerate(data_loader):
        batchg, label = batchg.to(DEVICE), label.to(DEVICE)
        prediction = model(batchg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (it + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
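
# Visualize the training curve (a sketch assuming matplotlib is installed;
# not part of the original post):
import matplotlib.pyplot as plt
plt.plot(epoch_losses)
plt.xlabel('epoch')
plt.ylabel('cross-entropy loss')
plt.savefig('train_loss.png')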


# Evaluate on the test set
test_loader = DataLoader(testset, batch_size=64, shuffle=False,
                         collate_fn=collate)
model.eval()
test_pred, test_label = [], []
with torch.no_grad():
    for it, (batchg, label) in enumerate(test_loader):
        batchg, label = batchg.to(DEVICE), label.to(DEVICE)
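        # softmax before argmax is optional: argmax over the raw logits gives the same labels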
        pred = torch.softmax(model(batchg), 1)
        pred = torch.max(pred, 1)[1].view(-1)
        test_pred += pred.detach().cpu().numpy().tolist()
        test_label += label.cpu().numpy().tolist()
print("Test accuracy: ", accuracy_score(test_label, test_pred))

  

Run results: [the output screenshot from the original post is omitted here]