論文信息
論文標題:Inductive Representation Learning on Large Graphs
論文作者:William L. Hamilton, Rex Ying
論文來源:2017, NIPS
論文地址:download
論文代碼:download
1 Introduction
創新:基於采樣和聚合的算法。
1.1 Transductive Learning
即直推式學習,已經預先觀察了所有數據,含訓練和測試數據集。 從已經觀察到的數據集中學習,然后預測測試數據集的標簽。 即過程會利用這些不知道數據標簽的測試集數據的模式和其他信息。

def load_inductive_dataset(dataset_name): if dataset_name == "ppi": batch_size = 2
# define loss function
# create the dataset
train_dataset = PPIDataset(mode='train') valid_dataset = PPIDataset(mode='valid') test_dataset = PPIDataset(mode='test') train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size) valid_dataloader = GraphDataLoader(valid_dataset, batch_size=batch_size, shuffle=False) test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False) eval_train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=False) g = train_dataset[0] num_classes = train_dataset.num_labels num_features = g.ndata['feat'].shape[1] else: _args = namedtuple("dt", "dataset") dt = _args(dataset_name) batch_size = 1 dataset = load_data(dt) print("dataset = ",dataset) num_classes = dataset.num_classes g = dataset[0] num_features = g.ndata["feat"].shape[1] train_mask = g.ndata['train_mask'] feat = g.ndata["feat"] feat = scale_feats(feat) g.ndata["feat"] = feat g = g.remove_self_loop() g = g.add_self_loop() train_nid = np.nonzero(train_mask.data.numpy())[0].astype(np.int64) train_g = dgl.node_subgraph(g, train_nid) train_dataloader = [train_g] valid_dataloader = [g] test_dataloader = valid_dataloader eval_train_dataloader = [train_g] return train_dataloader, valid_dataloader, test_dataloader, eval_train_dataloader, num_features, num_classes
GCN 就是一個典型的例子:

def train(epoch): t = time.time() model.train() optimizer.zero_grad() output = model(features, adj) loss_train = F.nll_loss(output[idx_train], labels[idx_train]) acc_train = accuracy(output[idx_train], labels[idx_train]) loss_train.backward() optimizer.step() if not args.fastmode: # Evaluate validation set performance separately,
# deactivates dropout during validation run.
model.eval() output = model(features, adj) loss_val = F.nll_loss(output[idx_val], labels[idx_val]) acc_val = accuracy(output[idx_val], labels[idx_val]) def test(): model.eval() output = model(features, adj) loss_test = F.nll_loss(output[idx_test], labels[idx_test]) acc_test = accuracy(output[idx_test], labels[idx_test])
缺點:一旦有新的節點出現,直推式學習需要重新訓練模型。
1.2 Inductive Learning
即歸納式學習,只能使用已經觀測到的數據(有標簽),對於沒有標簽的節點在訓練過程中只能忽略(不使用結構信息和屬性信息)。

def load_dataset(dataset_name): assert dataset_name in GRAPH_DICT, f"Unknow dataset: {dataset_name}."
if dataset_name.startswith("ogbn"): dataset = GRAPH_DICT[dataset_name](dataset_name) else: dataset = GRAPH_DICT[dataset_name]() if dataset_name == "ogbn-arxiv": graph, labels = dataset[0] num_nodes = graph.num_nodes() split_idx = dataset.get_idx_split() train_idx, val_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] graph = preprocess(graph) if not torch.is_tensor(train_idx): train_idx = torch.as_tensor(train_idx) val_idx = torch.as_tensor(val_idx) test_idx = torch.as_tensor(test_idx) feat = graph.ndata["feat"] feat = scale_feats(feat) graph.ndata["feat"] = feat train_mask = torch.full((num_nodes,), False).index_fill_(0, train_idx, True) val_mask = torch.full((num_nodes,), False).index_fill_(0, val_idx, True) test_mask = torch.full((num_nodes,), False).index_fill_(0, test_idx, True) graph.ndata["label"] = labels.view(-1) graph.ndata["train_mask"], graph.ndata["val_mask"], graph.ndata["test_mask"] = train_mask, val_mask, test_mask else: graph = dataset[0] graph = graph.remove_self_loop() graph = graph.add_self_loop() num_features = graph.ndata["feat"].shape[1] num_classes = dataset.num_classes return graph, (num_features, num_classes)
主要觀點是:節點的嵌入可以通過一個共同的聚合鄰居節點信息的函數得到,在訓練時只要得到這個聚合函數,就可以將其泛化到未知的節點上。
2 GraphSAGE Method
GraphSAGE 的核心思想:不是試圖學習一個圖上所有 Node Embedding,而是學習一個為每個 Node 產生 Embedding 的映射(即產生一個通用的映射函數)。
本文提出的 GraphSAGE(Inductive Method) 可以利用所有圖中存在的結構特征(如:節點度,鄰居信息),去推測未知的節點表示。
舉例如下:
- 先對鄰居隨機采樣,降低計算復雜度(Figure 1 :一跳鄰居采樣數=3,二跳鄰居采樣數=5)
- 生成目標節點 Emebedding:先聚合2跳鄰居特征,生成一跳鄰居 Embedding,再聚合一跳鄰居 Embedding,生成目標節點 Embedding,從而獲得二跳鄰居信息。
- 將 Embedding 作為全連接層的輸入,預測目標節點的標簽。
2.1 Embedding generation algorithm
GraphSAGE 算法如下:
注意:$K$ 控制着跳數,本文這邊取 $K=2$。
舉例:
這里以節點 $1$ 為例,采用均值聚合。
對於節點 $1$ ,它相連的鄰居為 ${3,4,5,6}$。(這里以聚合所有鄰居信息為例)
對於算法中的第 4 步:$h_{\mathcal{N}(1)}^{1} \leftarrow A G G R E G A T E\left(\left\{h_{3}^{0}, h_{4}^{0}, h_{5}^{0}, h_{6}^{0}\right\}\right)$:
$h_{\mathcal{N}(1)}^{1}=A G G R E G A T E\left(\left\{h_{3}^{0}, h_{4}^{0}, h_{5}^{0}, h_{6}^{0}\right\}\right)=\operatorname{Mean}([0.3,0.4],[0.2,0.2],[0.7,0.8],[0.5,0.6]$
對於算法中的第 5 步:$h_{1}^{1} \leftarrow \sigma\left(W^{1} \cdot \operatorname{CONCAT}\left(h_{1}^{0}, h_{\mathcal{N}(1)}^{1}\right)\right)$ :
$\left.h_{1}^{1}=W \cdot \operatorname{CONCAT}\left(h_{1}^{0}, h_{\mathcal{N}(1)}^{1}\right)\right)=W \cdot[0.1,0.2,0.425,0.5]$
改進:聚合部分鄰居
-
- 對於節點 $1$,比如我們要聚合其 $3$ 個鄰居的信息,那就按均勻分布隨機在其鄰居集合中選擇 $3$ 個鄰居節點。(節點不重復)
- 對於節點 $1$,比如我們要聚合其 $6$ 個鄰居的信息,那就先聚合其所有鄰居一次($5$ 個鄰居),然后在按均勻分布隨機在其鄰居集合中選擇 $1$ 個鄰居節點。(節點重復)
注意點:上述提到 $K$ 控制着跳數。
舉例:【$K=2,S_1 =2,S_2 = 3$】
本文實驗說明聚合鄰居數最好滿足: $S_{1} \cdot S_{2} \leq 500$。
基於 minibatch 版本的 GraphSAGE 算法:
舉例:
考慮:
假設:$K=2, S_1=2, S_2=3$,$\mathcal{B}^{2}=\{a\}$
那么:
$\mathcal{B}^{1}=\{a\} \cup \mathcal{N}_{2}(a)=\{a\} \cup\{c, f, j\}$
$\mathcal{B}^{0}=\{a\} \cup\{c, f, j\} \cup \mathcal{N}_{1}(\{c, f, j\})=\{a\} \cup\{c, f, j\} \cup\{d, e, i, h, k, l\}$
考慮:
$\begin{array}{l}\mathcal{B}^{1}=\{a\} \cup \mathcal{N}_{2}(a)=\{a\} \cup\{c, f, j\} \\\mathcal{N}_{1}(c)=\{d, e\} \\h_{\mathcal{N}(c)}^{1} \leftarrow A G G R E G A T E_{1}\left\{h_{d}^{0}, h_{e}^{0}\right\} \\h_{c}^{1} \leftarrow \sigma\left(W^{1} \cdot \operatorname{CONCAT}\left(h_{c}^{0}, h_{\mathcal{N}(1)}^{1}\right)\right)\end{array}$
2.2 Learning the parameters of GraphSAGE
損失函數分為基於圖的無監督損失和有監督損失。
- 基於圖的無監督損失:目標是使節點 $u$ 與 “鄰居” $v$ 的 Embedding 相似,與無邊相連的節點 $v_n$ 不相似。
$J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right)$
其中:
-
- 節點 $v$ 是節點 $u$ 經過固定長度的 Random walk 到達的鄰居節點;
- $v_{n} \sim P_{n}(u)$ 表示負采樣:節點 $v_{n}$ 是從節點 $u$ 的負采樣分布 $P_{n}$ 采樣的, $Q$ 為采樣樣本數;
- 基於圖的有監督損失:無監督損失函數的設定來學習節點 Embedding 可以供下游多個任務使用,若僅使用在特定某個任務上,則可以替代上述損失函數符合特定任務目標,如交叉熵。
2.3 Aggregator Architectures
由於節點是無序的,所以聚合器需要滿足排列不變性。
排列不變性(permutation invariance):指輸入的順序改變不會影響輸出的值。
- Mean aggregator
$h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{mean}\left(\left\{h_{v}^{k-1}\right\} \cup\left\{h_{u}^{k-1}, \forall u \in N(v)\right\}\right)\right.$
- LSTM aggregator
LSTM函數不符合 "排列不變性" 的性質,需要先對鄰居隨機排序,然后將隨機的鄰居序列 Embedding $ \left\{x_{t}, t \in N(v)\right\}$ 作為 LSTM 輸入。
-
Pooling aggregator
一個 element-wise max pooling 操作應用在鄰居集合上來聚合信息:
$\text { AGGREGATE }_{k}^{\mathrm{pool}}=\max \left(\left\{\sigma\left(\mathbf{W}_{\text {pool }} \mathbf{h}_{u_{i}}^{k}+\mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}\right)$
$\mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right)$
3 Experiments
基線實驗
消融實驗
修改時間
2022-01-17 創建文章
2022-06-07 修改文中關於直推式和歸納式學習的定義