圖神經網絡 PyTorch Geometric 入門教程


簡介

Graph Neural Networks 簡稱 GNN,稱為圖神經網絡,是深度學習中近年來一個比較受關注的領域。近年來 GNN 在學術界受到的關注越來越多,與之相關的論文數量呈上升趨勢,GNN 通過對信息的傳遞,轉換和聚合實現特征的提取,類似於傳統的 CNN,只是 CNN 只能處理規則的輸入,如圖片等輸入的高、寬和通道數都是固定的,而 GNN 可以處理不規則的輸入,如點雲等。 可查看【GNN】萬字長文帶你入門 GCN

而 PyTorch Geometric Library (簡稱 PyG) 是一個基於 PyTorch 的圖神經網絡庫,地址是:https://github.com/rusty1s/pytorch_geometric。它包含了很多 GNN 相關論文中的方法實現和常用數據集,並且提供了簡單易用的接口來生成圖,因此對於復現論文來說也是相當方便。用法大多數和 PyTorch 很相近,因此熟悉 PyTorch 的同學使用這個庫可以很快上手。

torch_geometric.data.Data

節點和節點之間的邊構成了圖。所以在 PyG 中,如果你要構建圖,那么需要兩個要素:節點和邊。PyG 提供了torch_geometric.data.Data (下面簡稱Data) 用於構建圖,包括 5 個屬性,每一個屬性都不是必須的,可以為空。

  • x: 用於存儲每個節點的特征,形狀是[num_nodes, num_node_features]
  • edge_index: 用於存儲節點之間的邊,形狀是 [2, num_edges]
  • pos: 存儲節點的坐標,形狀是[num_nodes, num_dimensions]
  • y: 存儲樣本標簽。如果是每個節點都有標簽,那么形狀是[num_nodes, *];如果是整張圖只有一個標簽,那么形狀是[1, *]
  • edge_attr: 存儲邊的特征。形狀是[num_edges, num_edge_features]

實際上,Data對象不僅僅限制於這些屬性,我們可以通過data.face來擴展Data,以張量保存三維網格中三角形的連接性。

需要注意的的是,在Data里包含了樣本的 label,這意味和 PyTorch 稍有不同。在PyTorch中,我們重寫Dataset__getitem__(),根據 index 返回對應的樣本和 label。在 PyG 中,我們使用的不是這種寫法,而是在get()函數中根據 index 返回torch_geometric.data.Data類型的數據,在Data里包含了數據和 label。

下面一個例子是未加權無向圖 ( 未加權指邊上沒有權值 ),包括 3 個節點和 4 條邊。


由於是無向圖,因此有 4 條邊:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)。每個節點都有自己的特征。上面這個圖可以使用`torch_geometric.data.Data`來表示如下:
import torch
from torch_geometric.data import Data
# 由於是無向圖,因此有 4 條邊:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# 節點的特征                           
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)

注意edge_index中邊的存儲方式,有兩個list,第 1 個list是邊的起始點,第 2 個list是邊的目標節點。注意與下面的存儲方式的區別。

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous())

這種情況edge_index需要先轉置然后使用contiguous()方法。關於contiguous()函數的作用,查看 PyTorch中的contiguous

最后再復習一遍,Data中最基本的 4 個屬性是xedge_indexposy,我們一般都需要這 4 個參數。

有了Data,我們可以創建自己的Dataset,讀取並返回Data了。

Dataset 與 DataLoader

PyG 的 Dataset繼承自torch.utils.data.Dataset,自帶了很多圖數據集,我們以TUDataset為例,通過以下代碼就可以加載數據集,root參數設置數據下載的位置。通過索引可以訪問每一個數據。

from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
data = dataset[0]

在一個圖中,由edge_indexedge_attr可以決定所有節點的鄰接矩陣。PyG 通過創建稀疏的對角鄰接矩陣,並在節點維度中連接特征矩陣和 label 矩陣,實現了在 mini-batch 的並行化。PyG 允許在一個 mini-batch 中的每個Data (圖) 使用不同數量的節點和邊。


# 自定義 Dataset

盡管 PyG 已經包含許多有用的數據集,我們也可以通過繼承torch_geometric.data.Dataset使用自己的數據集。提供 2 種不同的Dataset

  • InMemoryDataset:使用這個Dataset會一次性把數據全部加載到內存中。
  • Dataset: 使用這個Dataset每次加載一個數據到內存中,比較常用。

我們需要在自定義的Dataset的初始化方法中傳入數據存放的路徑,然后 PyG 會在這個路徑下再划分 2 個文件夾:

  • raw_dir: 存放原始數據的路徑,一般是 csv、mat 等格式
  • processed_dir: 存放處理后的數據,一般是 pt 格式 ( 由我們重寫process()方法實現)。

在 PyTorch 中,是沒有這兩個文件夾的。下面來說明一下這兩個文件夾在 PyG 中的實際意義和處理邏輯。

torch_geometric.data.Dataset繼承自torch.utils.data.Dataset,在初始化方法 __init__()中,會調用_download()方法和_process()方法。

def __init__(self, root=None, transform=None, pre_transform=None,
			 pre_filter=None):
	super(Dataset, self).__init__()

	if isinstance(root, str):
		root = osp.expanduser(osp.normpath(root))

	self.root = root
	self.transform = transform
	self.pre_transform = pre_transform
	self.pre_filter = pre_filter
	self.__indices__ = None

	# 執行 self._download() 方法
	if 'download' in self.__class__.__dict__.keys():
		self._download()
    # 執行 self._process() 方法
	if 'process' in self.__class__.__dict__.keys():
		self._process()

_download()方法如下,首先檢查self.raw_paths列表中的文件是否存在;如果存在,則返回;如果不存在,則調用self.download()方法下載文件。

def _download(self):
	if files_exist(self.raw_paths):  # pragma: no cover
		return

	makedirs(self.raw_dir)
	self.download()

_process()方法如下,首先在self.processed_dir中有pre_transform,那么判斷這個pre_transform和傳進來的pre_transform是否一致,如果不一致,那么警告提示用戶先刪除self.processed_dir文件夾。pre_filter同理。

然后檢查self.processed_paths列表中的文件是否存在;如果存在,則返回;如果不存在,則調用self.process()生成文件。

def _process(self):
	f = osp.join(self.processed_dir, 'pre_transform.pt')
	if osp.exists(f) and torch.load(f) != __repr__(self.pre_transform):
		warnings.warn(
			'The `pre_transform` argument differs from the one used in '
			'the pre-processed version of this dataset. If you really '
			'want to make use of another pre-processing technique, make '
			'sure to delete `{}` first.'.format(self.processed_dir))
	f = osp.join(self.processed_dir, 'pre_filter.pt')
	if osp.exists(f) and torch.load(f) != __repr__(self.pre_filter):
		warnings.warn(
			'The `pre_filter` argument differs from the one used in the '
			'pre-processed version of this dataset. If you really want to '
			'make use of another pre-fitering technique, make sure to '
			'delete `{}` first.'.format(self.processed_dir))

	if files_exist(self.processed_paths):  # pragma: no cover
		return

	print('Processing...')

	makedirs(self.processed_dir)
	self.process()

	path = osp.join(self.processed_dir, 'pre_transform.pt')
	torch.save(__repr__(self.pre_transform), path)
	path = osp.join(self.processed_dir, 'pre_filter.pt')
	torch.save(__repr__(self.pre_filter), path)

	print('Done!')

一般來說不用實現downloand()方法

如果你直接把處理好的 pt 文件放在了self.processed_dir中,那么也不用實現process()方法。

在 Pytorch 的dataset中,我們需要實現__getitem__()方法,根據index返回樣本和標簽。在這里torch_geometric.data.Dataset中,重寫了__getitem__()方法,其中調用了get()方法獲取數據。

def __getitem__(self, idx):
	if isinstance(idx, int):
		data = self.get(self.indices()[idx])
		data = data if self.transform is None else self.transform(data)
		return data
	else:
		return self.index_select(idx)

我們需要實現的是get()方法,根據index返回torch_geometric.data.Data類型的數據。

process()方法存在的意義是原始的格式可能是 csv 或者 mat,在process()函數里可以轉化為 pt 格式的文件,這樣在get()方法中就可以直接使用torch.load()函數讀取 pt 格式的文件,返回的是torch_geometric.data.Data類型的數據,而不用在get()方法做數據轉換操作 (把其他格式的數據轉換為 torch_geometric.data.Data類型的數據)。當然我們也可以提前把數據轉換為 torch_geometric.data.Data類型,使用 pt 格式保存在self.processed_dir中。

DataLoader

通過torch_geometric.data.DataLoader可以方便地使用 mini-batch。

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
	# 對每一個 mini-batch 進行操作
	...

torch_geometric.data.Batch繼承自torch_geometric.data.Data,並且多了一個屬性:batchbatch是一個列向量,它將每個元素映射到每個 mini-batch 中的相應圖:

batch $=\left[\begin{array}{cccccccc}0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1\end{array}\right]^{\top}$

我們可以使用它分別為每個圖的節點維度計算平均的節點特征:

from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    data
    #data: Batch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

    x = scatter_mean(data.x, data.batch, dim=0)
    # x.size(): torch.Size([32, 21])

關於 batching 的流程細節,你可以點擊這里查看。關於scatter 方法的說明,你可以查看torch-scatter說明文檔

Transforms

transforms在計算機視覺領域是一種很常見的數據增強。PyG 有自己的transforms,輸出是Data類型,輸出也是Data類型。可以使用torch_geometric.transforms.Compose封裝一系列的transforms。我們以 ShapeNet 數據集 (包含 17000 個 point clouds,每個 point 分類為 16 個類別的其中一個) 為例,我們可以使用transforms從 point clouds 生成最近鄰圖:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

還可以通過transform在一定范圍內隨機平移每個點,增加坐標上的擾動,做數據增強:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6),
                    transform=T.RandomTranslate(0.01))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

模型訓練

這里只是展示一個簡單的 GCN 模型構造和訓練過程,沒有用到DatasetDataLoader

我們將使用一個簡單的 GCN 層,並在 Cora 數據集上實驗。有關 GCN 的更多內容,請查看這篇博客

我們首先加載數據集:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')

然后定義 2 層的 GCN:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

然后訓練 200 個 epochs:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

最后在測試集上驗證了模型的准確率:

model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))


如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章。

我的文章會首發在公眾號上,歡迎掃碼關注我的公眾號張賢同學



免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM