圖神經網絡-環境配置與PyG庫


環境配置與PyG中圖與圖數據集的表示和使用

一、引言

PyTorch Geometric (PyG)是面向幾何深度學習的PyTorch的擴展庫,幾何深度學習指的是應用於圖和其他不規則、非結構化數據的深度學習。基於PyG庫,我們可以輕松地根據數據生成一個圖對象,然后很方便的使用它;我們也可以容易地為一個圖數據集構造一個數據集類,然后很方便的將它用於神經網絡。

通過此節的實踐內容,我們將

  1. 首先學習程序運行環境的配置
  2. 接着學習PyG中圖數據的表示及其使用,即學習PyG中Data類。
  3. 最后學習PyG中圖數據集的表示及其使用,即學習PyG中Dataset類。

二、環境配置

  1. 使用nvidia-smi命令查詢顯卡驅動是否正確安裝

image-20210515204452045

  1. 安裝正確版本的pytorch和cudatoolkit,此處安裝1.8.1版本的pytorch和11.1版本的cudatoolkit

    1. conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia
    2. 確認是否正確安裝,正確的安裝應出現下方的結果
    $ python -c "import torch; print(torch.__version__)"
    # 1.8.1
    $ python -c "import torch; print(torch.version.cuda)"
    # 11.1
    
  2. 安裝正確版本的PyG

    pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
    pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
    pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
    pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
    pip install torch-geometric
    

其他版本的安裝方法以及安裝過程中出現的大部分問題的解決方案可以在Installation of of PyTorch Geometric 頁面找到。

三、Data類——PyG中圖的表示及其使用

Data對象的創建

Data類的官方文檔為torch_geometric.data.Data

通過構造函數

Data類的構造函數

class Data(object):

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, **kwargs):
    r"""
    Args:
        x (Tensor, optional): 節點屬性矩陣,大小為`[num_nodes, num_node_features]`
        edge_index (LongTensor, optional): 邊索引矩陣,大小為`[2, num_edges]`,第0行為尾節點,第1行為頭節點,頭指向尾
        edge_attr (Tensor, optional): 邊屬性矩陣,大小為`[num_edges, num_edge_features]`
        y (Tensor, optional): 節點或圖的標簽,任意大小(,其實也可以是邊的標簽)
	
    """
    self.x = x
    self.edge_index = edge_index
    self.edge_attr = edge_attr
    self.y = y

    for key, item in kwargs.items():
        if key == 'num_nodes':
            self.__num_nodes__ = item
        else:
            self[key] = item

edge_index的每一列定義一條邊,其中第一行為邊起始節點的索引,第二行為邊結束節點的索引。這種表示方法被稱為COO格式(coordinate format),通常用於表示稀疏矩陣。PyG不是用稠密矩陣\(\mathbf{A} \in \{ 0, 1 \}^{|\mathcal{V}| \times |\mathcal{V}|}\)來持有鄰接矩陣的信息,而是用僅存儲鄰接矩陣\(\mathbf{A}\)中非\(0\)元素的稀疏矩陣來表示圖。

通常,一個圖至少包含x, edge_index, edge_attr, y, num_nodes5個屬性,當圖包含其他屬性時,我們可以通過指定額外的參數使Data對象包含其他的屬性

graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, num_nodes=num_nodes, other_attr=other_attr)

dict對象為Data對象

我們也可以將一個dict對象轉換為一個Data對象

graph_dict = {
    'x': x,
    'edge_index': edge_index,
    'edge_attr': edge_attr,
    'y': y,
    'num_nodes': num_nodes,
    'other_attr': other_attr
}
graph_data = Data.from_dict(graph_dict)

from_dict是一個類方法:

@classmethod
def from_dict(cls, dictionary):
    r"""Creates a data object from a python dictionary."""
    data = cls()
    for key, item in dictionary.items():
        data[key] = item

    return data

注意graph_dict中屬性值的類型與大小的要求與Data類的構造函數的要求相同。

Data對象轉換成其他類型數據

我們可以將Data對象轉換為dict對象:

def to_dict(self):
    return {key: item for key, item in self}

或轉換為namedtuple

def to_namedtuple(self):
    keys = self.keys
    DataTuple = collections.namedtuple('DataTuple', keys)
    return DataTuple(*[self[key] for key in keys])

獲取Data對象屬性

x = graph_data['x']

設置Data對象屬性

graph_data['x'] = x

獲取Data對象包含的屬性的關鍵字

graph_data.keys()

對邊排序並移除重復的邊

graph_data.coalesce()

Data對象的其他性質

我們通過觀察PyG中內置的一個圖來查看Data對象的性質:

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
data = dataset[0]  # Get the first graph object.
print(data)
print('==============================================================')

# 獲取圖的一些信息
print(f'Number of nodes: {data.num_nodes}') # 節點數量
print(f'Number of edges: {data.num_edges}') # 邊數量
print(f'Number of node features: {data.num_node_features}') # 節點屬性的維度
print(f'Number of node features: {data.num_features}') # 同樣是節點屬性的維度
print(f'Number of edge features: {data.num_edge_features}') # 邊屬性的維度
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}') # 平均節點度
print(f'if edge indices are ordered and do not contain duplicate entries.: {data.is_coalesced()}') # 是否邊是有序的同時不含有重復的邊
print(f'Number of training nodes: {data.train_mask.sum()}') # 用作訓練集的節點
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}') # 用作訓練集的節點的數量
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}') # 此圖是否包含孤立的節點
print(f'Contains self-loops: {data.contains_self_loops()}')  # 此圖是否包含自環的邊
print(f'Is undirected: {data.is_undirected()}')  # 此圖是否是無向圖

四、Dataset類——PyG中圖數據集的表示及其使用

PyG內置了大量常用的基准數據集,接下來我們以PyG內置的Planetoid數據集為例,來學習PyG中圖數據集的表示及使用

Planetoid數據集類的官方文檔為torch_geometric.datasets.Planetoid

生成數據集對象並分析數據集

如下方代碼所示,在PyG中生成一個數據集是簡單直接的。在第一次生成PyG內置的數據集時,程序首先下載原始文件,然后將原始文件處理成包含Data對象的Dataset對象並保存到文件。

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/dataset/Cora', name='Cora')
# Cora()

len(dataset)
# 1

dataset.num_classes
# 7

dataset.num_node_features
# 1433

分析數據集中樣本

可以看到該數據集只有一個圖,包含7個分類任務,節點的屬性為1433維度。

data = dataset[0]
# Data(edge_index=[2, 10556], test_mask=[2708],
#         train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

data.is_undirected()
# True

data.train_mask.sum().item()
# 140

data.val_mask.sum().item()
# 500

data.test_mask.sum().item()
# 1000

現在我們看到該數據集包含的唯一的圖,有2708個節點,節點特征為1433維,有10556條邊,有140個用作訓練集的節點,有500個用作驗證集的節點,有1000個用作測試集的節點。PyG內置的其他數據集,請小伙伴一一試驗,以觀察不同數據集的不同。

數據集的使用

假設我們定義好了一個圖神經網絡模型,其名為Net。在下方的代碼中,我們展示了節點分類圖數據集在訓練過程中的使用。

model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

結語

通過此實踐環節,我們學習了程序運行環境的配置PyG中Data對象的生成與使用、以及PyG中Dataset對象的表示和使用。此節內容是圖神經網絡實踐的基礎,所涉及的內容是最常用、最基礎的,在后面的內容中我們還將學到復雜Data類的構建,和復雜Dataset類的構建。

作業

  • 請通過繼承Data類實現一個類專門用於表示“機構-作者-論文”的網絡。該網絡包含“機構“、”作者“和”論文”三類節點,以及“作者-機構“和“作者-論文“兩類邊。對要實現的類的要求:1)用不同的屬性存儲不同節點的屬性;2)用不同的屬性存儲不同的邊(邊沒有屬性);3)逐一實現獲取不同節點數量的方法。

參考資料


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM