For whole-graph classification, there is a good Zhihu article: 【圖分類】10分鐘就學會的圖分類教程,基於pytorch和dgl ("Graph classification: learn graph classification in 10 minutes, based on PyTorch and DGL"). The code below also comes from that article.
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import MiniGCDataset
from dgl.nn.pytorch import GraphConv
from sklearn.metrics import accuracy_score

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)        # first graph convolution layer
        self.conv2 = GraphConv(hidden_dim, hidden_dim)    # second graph convolution layer
        self.classify = nn.Linear(hidden_dim, n_classes)  # classification head

    def forward(self, g):
        """g is the batched graph; N is the total number of nodes in the
        batched graph, n is the number of graphs in the batch."""
        # For convenience, use node degrees as the initial node features.
        # For an undirected graph, in-degree equals out-degree.
        h = g.in_degrees().view(-1, 1).float()  # [N, 1]
        # Apply the graph convolutions with ReLU activations
        h = F.relu(self.conv1(g, h))  # [N, hidden_dim]
        h = F.relu(self.conv2(g, h))  # [N, hidden_dim]
        g.ndata['h'] = h  # attach the features to the graph's nodes
        # Mean-pool the node representations to get a graph representation
        hg = dgl.mean_nodes(g, 'h')  # [n, hidden_dim]
        return self.classify(hg)     # [n, n_classes]

def collate(samples):
    # `samples` is a list of (graph, label) pairs,
    # e.g. [(graph1, label1), (graph2, label2), ...]
    # zip(*samples) unzips it into [(graph1, graph2, ...), (label1, label2, ...)]
    graphs, labels = map(list, zip(*samples))
    # dgl.batch treats a batch of graphs as one large graph
    # made up of many mutually disconnected components
    return dgl.batch(graphs), torch.tensor(labels, dtype=torch.long)

# Create the training and test sets
trainset = MiniGCDataset(2000, 10, 20)  # 2000 graphs, each with 10 to 20 nodes
testset = MiniGCDataset(1000, 10, 20)
# Use PyTorch's DataLoader together with the collate function defined above
data_loader = DataLoader(trainset, batch_size=64, shuffle=True, collate_fn=collate)

# GPU used in the original article; fall back to CPU if CUDA is unavailable
DEVICE = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# Build the model
model = Classifier(1, 256, trainset.num_classes)
model.to(DEVICE)
# Cross-entropy classification loss
loss_func = nn.CrossEntropyLoss()
# Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
model.train()
epoch_losses = []
for epoch in range(100):
    epoch_loss = 0
    for iter, (batchg, label) in enumerate(data_loader):
        batchg, label = batchg.to(DEVICE), label.to(DEVICE)
        prediction = model(batchg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

# Testing
test_loader = DataLoader(testset, batch_size=64, shuffle=False, collate_fn=collate)
model.eval()
test_pred, test_label = [], []
with torch.no_grad():
    for it, (batchg, label) in enumerate(test_loader):
        batchg, label = batchg.to(DEVICE), label.to(DEVICE)
        pred = torch.softmax(model(batchg), 1)
        pred = torch.max(pred, 1)[1].view(-1)
        test_pred += pred.detach().cpu().numpy().tolist()
        test_label += label.cpu().numpy().tolist()
print("Test accuracy: ", accuracy_score(test_label, test_pred))
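The key trick above is dgl.batch: a batch of graphs is merged into one large graph whose connected components are the original graphs, so a single forward pass handles the whole batch, and readout functions such as dgl.mean_nodes pool per component. Below is a minimal sketch of what the batching does; the two toy graphs are my own illustration, not from the article.

import dgl
import torch

g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))  # 3 nodes, 2 edges
g2 = dgl.graph((torch.tensor([0]), torch.tensor([1])))        # 2 nodes, 1 edge
bg = dgl.batch([g1, g2])

print(bg.num_nodes())        # 5: g2's node IDs are relabeled to 3, 4
print(bg.batch_num_nodes())  # tensor([3, 2]): per-graph node counts
print(bg.in_degrees())       # degrees computed on the disconnected union

# Readout respects the per-graph boundaries:
bg.ndata['h'] = torch.ones(bg.num_nodes(), 4)
print(dgl.mean_nodes(bg, 'h').shape)  # torch.Size([2, 4]): one row per graph

Because the batched graph records the per-graph node counts, dgl.mean_nodes returns one row per original graph, which is exactly the [n, hidden_dim] shape the classifier head expects.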
Output of the training and test script above: