Paper Reading (GRACE): Deep Graph Contrastive Representation Learning


Paper Reading

Paper title: Deep Graph Contrastive Representation Learning
Authors: Yanqiao Zhu, Yichen Xu, Feng Yu, Q. Liu, Shu Wu, Liang Wang
Source: arXiv, 2020
Paper link: download
Code link: download (the code is well written)

1 Introduction

  A node-level graph contrastive learning framework.

  Data augmentations: removing edges and masking node features.

2 Method

  The GRACE framework is illustrated below:

  

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class Encoder(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int, activation,
                 base_model=GCNConv, k: int = 2):
        super(Encoder, self).__init__()
        self.base_model = base_model

        # k-layer GNN; intermediate layers use a hidden width of 2 * out_channels
        assert k >= 2
        self.k = k
        self.conv = [base_model(in_channels, 2 * out_channels)]
        for _ in range(1, k - 1):
            self.conv.append(base_model(2 * out_channels, 2 * out_channels))
        self.conv.append(base_model(2 * out_channels, out_channels))
        self.conv = nn.ModuleList(self.conv)

        self.activation = activation

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        for i in range(self.k):
            x = self.activation(self.conv[i](x, edge_index))
        return x


class Model(torch.nn.Module):
    def __init__(self, encoder: Encoder, num_hidden: int, num_proj_hidden: int,
                 tau: float = 0.5):
        super(Model, self).__init__()
        self.encoder: Encoder = encoder
        self.tau: float = tau

        # Two-layer MLP projection head g(.)
        self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        return self.encoder(x, edge_index)

    def projection(self, z: torch.Tensor) -> torch.Tensor:
        z = F.elu(self.fc1(z))
        return self.fc2(z)

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        # Pairwise cosine similarity between L2-normalized rows
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor):
        # Eq. (1): positives sit on the diagonal of between_sim; negatives are
        # all other inter-view (between_sim) and intra-view (refl_sim) pairs
        f = lambda x: torch.exp(x / self.tau)
        refl_sim = f(self.sim(z1, z1))
        between_sim = f(self.sim(z1, z2))
        return -torch.log(
            between_sim.diag()
            / (refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()))

    def batched_semi_loss(self, z1: torch.Tensor, z2: torch.Tensor,
                          batch_size: int):
        # Space complexity: O(BN) (semi_loss: O(N^2))
        device = z1.device
        num_nodes = z1.size(0)
        num_batches = (num_nodes - 1) // batch_size + 1
        f = lambda x: torch.exp(x / self.tau)
        indices = torch.arange(0, num_nodes).to(device)
        losses = []

        for i in range(num_batches):
            mask = indices[i * batch_size:(i + 1) * batch_size]
            refl_sim = f(self.sim(z1[mask], z1))     # [B, N]
            between_sim = f(self.sim(z1[mask], z2))  # [B, N]

            losses.append(-torch.log(
                between_sim[:, i * batch_size:(i + 1) * batch_size].diag()
                / (refl_sim.sum(1) + between_sim.sum(1)
                   - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag())))

        return torch.cat(losses)

    def loss(self, z1: torch.Tensor, z2: torch.Tensor,
             mean: bool = True, batch_size: int = 0):
        # Eq. (2): symmetric average of the two directional losses
        h1 = self.projection(z1)
        h2 = self.projection(z2)

        if batch_size == 0:
            l1 = self.semi_loss(h1, h2)
            l2 = self.semi_loss(h2, h1)
        else:
            l1 = self.batched_semi_loss(h1, h2, batch_size)
            l2 = self.batched_semi_loss(h2, h1, batch_size)

        ret = (l1 + l2) * 0.5
        ret = ret.mean() if mean else ret.sum()

        return ret
Model Code

2.1 The Contrastive Learning Framework

  First, two views $\widetilde{\mathcal{G}}_{1}$ and $\widetilde{\mathcal{G}}_{2}$ are generated by randomly corrupting the original graph.

  Next, node representations are obtained through the encoder: $U=f\left(\widetilde{\boldsymbol{X}}_{1}, \widetilde{\boldsymbol{A}}_{1}\right)$ and $V=f\left(\widetilde{\boldsymbol{X}}_{2}, \widetilde{\boldsymbol{A}}_{2}\right)$.

  Contrastive objective: node-level agreement between views. For any node $v_{i}$, its embedding $\boldsymbol{u}_{i}$ in one view is treated as the anchor, the embedding $\boldsymbol{v}_{i}$ of the same node in the other view forms the positive sample, and the embeddings of all nodes other than $v_{i}$, in both views, are treated as negative samples.

  The pairwise objective for each positive pair $\left(\boldsymbol{u}_{i}, \boldsymbol{v}_{i}\right)$ is defined as:

    ${\large \ell\left(\boldsymbol{u}_{i}, \boldsymbol{v}_{i}\right)=-\log \frac{e^{\theta\left(\boldsymbol{u}_{i}, \boldsymbol{v}_{i}\right) / \tau}}{\underbrace{e^{\theta\left(\boldsymbol{u}_{i}, \boldsymbol{v}_{i}\right) / \tau}}_{\text {the positive pair }}+\underbrace{\sum\limits_{k=1}^{N} \mathbb{1}_{[k \neq i]} e^{\theta\left(\boldsymbol{u}_{i}, \boldsymbol{v}_{k}\right) / \tau}}_{\text {inter-view negative pairs }}+\underbrace{\sum\limits_{k=1}^{N} \mathbb{1}_{[k \neq i]} e^{\theta\left(\boldsymbol{u}_{i}, \boldsymbol{u}_{k}\right) / \tau}}_{\text {intra-view negative pairs }}}} \quad\quad\quad\quad(1)$

  where $\theta(\boldsymbol{u}, \boldsymbol{v})=s(g(\boldsymbol{u}), g(\boldsymbol{v}))$; here $s(\cdot, \cdot)$ is the cosine similarity and $g(\cdot)$ is a non-linear projection head. In the code above, $g(\cdot)$ corresponds to the projection method (a two-layer MLP with ELU) and $s(\cdot, \cdot)$ to the sim method.

  Since the two views play symmetric roles, the overall objective averages the pairwise loss over both anchor directions:

    $\mathcal{J}=\frac{1}{2 N} \sum\limits _{i=1}^{N}\left[\ell\left(\boldsymbol{u}_{i}, \boldsymbol{v}_{i}\right)+\ell\left(\boldsymbol{v}_{i}, \boldsymbol{u}_{i}\right)\right]\quad\quad\quad\quad(2)$

  The corresponding code is the semi_loss and loss methods of the Model class shown above: semi_loss implements Eq. (1) (with batched_semi_loss as its O(BN)-memory variant), and loss averages the two anchor directions as in Eq. (2).

  The GRACE training algorithm:
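
  In code form, one training step looks roughly like the sketch below, assuming the Model class above and the drop_feature helper from Section 2.2.2; the corruption rates are illustrative, not the paper's tuned values:

import torch
from torch_geometric.utils import dropout_adj

def train_step(model, optimizer, x, edge_index,
               p_r1=0.2, p_r2=0.4,   # edge-removal probability per view
               p_m1=0.3, p_m2=0.4):  # feature-masking probability per view
    model.train()
    optimizer.zero_grad()
    # Generate two corrupted views with different hyperparameters
    edge_index_1 = dropout_adj(edge_index, p=p_r1)[0]
    edge_index_2 = dropout_adj(edge_index, p=p_r2)[0]
    x_1 = drop_feature(x, p_m1)  # drop_feature is given in Section 2.2.2
    x_2 = drop_feature(x, p_m2)
    # Encode both views and minimize the symmetric contrastive loss of Eq. (2)
    z1 = model(x_1, edge_index_1)
    z2 = model(x_2, edge_index_2)
    loss = model.loss(z1, z2)
    loss.backward()
    optimizer.step()
    return loss.item()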

  

2.2 Graph View Generation

2.2.1 Removing edges (RE)

  First, sample a random masking matrix $\widetilde{\boldsymbol{R}} \in\{0,1\}^{N \times N}$ whose entries are drawn from a Bernoulli distribution: if $\boldsymbol{A}_{i j}=1$, then $\widetilde{\boldsymbol{R}}_{i j} \sim \mathcal{B}\left(1-p_{r}\right)$; otherwise $\widetilde{\boldsymbol{R}}_{i j}=0$. Here $p_{r}$ is the probability of each edge being removed. The resulting adjacency matrix is computed as

    $\widetilde{\boldsymbol{A}}=\boldsymbol{A} \circ \widetilde{\boldsymbol{R}}\quad\quad\quad(3)$

  where $(\boldsymbol{x} \circ \boldsymbol{y})_{i}=x_{i} y_{i}$ denotes the Hadamard (element-wise) product.
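
  As a dense-matrix sketch of Eq. (3) (illustrative only; symmetry of undirected graphs is ignored here, and the actual code works on edge lists via dropout_adj, shown below):

import torch

A = torch.tensor([[0., 1., 1.],
                  [1., 0., 0.],
                  [1., 0., 0.]])
p_r = 0.5
# Bernoulli(1 - p_r) mask, defined only over existing edges
R = torch.bernoulli((1 - p_r) * torch.ones_like(A)) * (A > 0)
A_tilde = A * R  # Hadamard product, Eq. (3)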

from torch_geometric.utils import dropout_adj

# Drop each edge independently; [0] keeps the edge_index and discards edge attributes
edge_index_1 = dropout_adj(edge_index, p=drop_edge_rate_1)[0]
edge_index_2 = dropout_adj(edge_index, p=drop_edge_rate_2)[0]
Drop edge Code

2.2.2 Masking node features (MF)

  First, sample a random vector $\widetilde{m} \in\{0,1\}^{F}$ whose dimensions are drawn independently from a Bernoulli distribution with probability $1-p_{m}$, i.e., $\widetilde{m}_{i} \sim \mathcal{B}\left(1-p_{m}\right)$. The generated node features $\widetilde{\boldsymbol{X}}$ are then:

    $\widetilde{\boldsymbol{X}}=\left[\boldsymbol{x}_{1} \circ \widetilde{\boldsymbol{m}} ; \boldsymbol{x}_{2} \circ \widetilde{\boldsymbol{m}} ; \cdots ; \boldsymbol{x}_{N} \circ \widetilde{\boldsymbol{m}}\right]^{\top}\quad\quad\quad\quad(4)$

  where $[\cdot\,;\,\cdot]$ denotes the concatenation operator.

import torch

def drop_feature(x, drop_prob):
    # Sample one Bernoulli mask over the feature dimensions, shared by all nodes
    drop_mask = torch.empty(
        (x.size(1), ),
        dtype=torch.float32,
        device=x.device).uniform_(0, 1) < drop_prob
    x = x.clone()
    x[:, drop_mask] = 0  # zero out the masked feature columns for every node
    return x
drop_feature Code

  Both methods are used jointly to generate the views. The generation of $\widetilde{\mathcal{G}}_{1}$ and $\widetilde{\mathcal{G}}_{2}$ is controlled by the two hyperparameters $p_{r}$ and $p_{m}$. To provide different contexts in the two views, their generation processes use two distinct sets of hyperparameters, $p_{r, 1}, p_{m, 1}$ and $p_{r, 2}, p_{m, 2}$. Experiments show the model is insensitive to the choice of $p_{r}$ and $p_{m}$ provided the original graph is not overly corrupted, e.g., $p_{r} \leq 0.8$ and $p_{m} \leq 0.8$.

3 Experiments

3.1 Dataset

  

3.2 Experimental Setup

Transductive learning

  In the transductive setting, a two-layer GCN is used as the encoder:

    $\mathrm{GC}_{i}(\boldsymbol{X}, \boldsymbol{A}) =\sigma\left(\hat{\boldsymbol{D}}^{-\frac{1}{2}} \hat{\boldsymbol{A}} \hat{\boldsymbol{D}}^{-\frac{1}{2}} \boldsymbol{X} \boldsymbol{W}_{i}\right)\quad\quad\quad\quad(7)$

    $f(\boldsymbol{X}, \boldsymbol{A})=\mathrm{GC}_{2}\left(\mathrm{GC}_{1}(\boldsymbol{X}, \boldsymbol{A}), \boldsymbol{A}\right)\quad\quad\quad\quad(8)$
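
  For instance, the Encoder class from Section 2 with k = 2 and base_model=GCNConv realizes Eqs. (7)-(8). The dimensions below are illustrative, not the paper's tuned values:

import torch.nn.functional as F
from torch_geometric.nn import GCNConv

encoder = Encoder(in_channels=1433,   # e.g., Cora's feature dimension
                  out_channels=128,   # embedding dimension (illustrative)
                  activation=F.relu,
                  base_model=GCNConv,
                  k=2)                # two GC layers, matching Eq. (8)
model = Model(encoder, num_hidden=128, num_proj_hidden=128, tau=0.5)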

Inductive learning on large graphs

  Given the large scale of the Reddit dataset, a three-layer GraphSAGE-GCN with residual connections is adopted as the encoder, formulated as

    $\widehat{\mathrm{MP}}_{i}(\boldsymbol{X}, \boldsymbol{A}) =\sigma\left(\left[\hat{\boldsymbol{D}}^{-1} \hat{\boldsymbol{A}} \boldsymbol{X} ; \boldsymbol{X}\right] \boldsymbol{W}_{i}\right) \quad\quad\quad\quad(9)$

    $f(\boldsymbol{X}, \boldsymbol{A}) =\widehat{\mathrm{MP}}_{3}\left(\widehat{\mathrm{MP}}_{2}\left(\widehat{\mathrm{MP}}_{1}(\boldsymbol{X}, \boldsymbol{A}), \boldsymbol{A}\right), \boldsymbol{A}\right)\quad\quad\quad\quad(10)$
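
  A dense sketch of this propagation rule (my reading of Eq. (9), not the authors' code; adj_hat stands for $\hat{\boldsymbol{D}}^{-1} \hat{\boldsymbol{A}}$):

import torch
import torch.nn as nn

class MeanPoolLayer(nn.Module):
    # One MP_i layer of Eq. (9): sigma([D^-1 A X ; X] W_i)
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Concatenating aggregated and raw features doubles the input width
        self.lin = nn.Linear(2 * in_dim, out_dim)

    def forward(self, x, adj_hat):
        agg = adj_hat @ x  # mean aggregation over neighbors
        return torch.relu(self.lin(torch.cat([agg, x], dim=-1)))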

  For large-scale datasets such as Reddit, a subsampling strategy is applied: first randomly select a batch of nodes, then obtain a subgraph centered at each selected node by sampling its neighborhood with replacement. Specifically, 30, 25, and 20 neighbors are sampled at the 1-hop, 2-hop, and 3-hop levels, respectively.
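
  One way to realize this neighborhood sampling is PyG's NeighborLoader (an assumption for illustration; the paper's implementation may differ):

from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    data,                        # a torch_geometric.data.Data object
    num_neighbors=[30, 25, 20],  # fan-out at 1-hop, 2-hop, 3-hop
    batch_size=256,              # illustrative batch size
    replace=True,                # sample neighbors with replacement
    shuffle=True,
)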

Inductive learning on multiple graphs

  For inductive learning on the multi-graph PPI dataset, three mean-pooling layers with skip connections are stacked, similar to DGI. The graph convolutional encoder can be expressed as

    $\boldsymbol{H}_{1}=\widehat{\mathrm{MP}}_{1}(\boldsymbol{X}, \boldsymbol{A}) \quad\quad\quad\quad(11)$

    $\boldsymbol{H}_{2}=\widehat{\mathrm{MP}}_{2}\left(\boldsymbol{X} \boldsymbol{W}_{\mathrm{skip}}+\boldsymbol{H}_{1}, \boldsymbol{A}\right)\quad\quad\quad\quad(12)$

    $f(\boldsymbol{X}, \boldsymbol{A})=\boldsymbol{H}_{3} =\widehat{\mathrm{MP}}_{3}\left(\boldsymbol{X} \boldsymbol{W}_{\mathrm{skip}}^{\prime}+\boldsymbol{H}_{1}+\boldsymbol{H}_{2}, \boldsymbol{A}\right)\quad\quad\quad\quad(13)$
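
  Continuing the hypothetical MeanPoolLayer sketch above, Eqs. (11)-(13) could be wired up as follows, where w_skip and w_skip2 stand for the learned projections $\boldsymbol{W}_{\mathrm{skip}}$ and $\boldsymbol{W}_{\mathrm{skip}}^{\prime}$ (again a sketch, not the reference implementation):

class SkipEncoder(nn.Module):
    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.mp1 = MeanPoolLayer(in_dim, hid_dim)
        self.mp2 = MeanPoolLayer(hid_dim, hid_dim)
        self.mp3 = MeanPoolLayer(hid_dim, hid_dim)
        self.w_skip = nn.Linear(in_dim, hid_dim, bias=False)   # W_skip
        self.w_skip2 = nn.Linear(in_dim, hid_dim, bias=False)  # W'_skip

    def forward(self, x, adj_hat):
        h1 = self.mp1(x, adj_hat)                            # Eq. (11)
        h2 = self.mp2(self.w_skip(x) + h1, adj_hat)          # Eq. (12)
        return self.mp3(self.w_skip2(x) + h1 + h2, adj_hat)  # Eq. (13)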

3.3 Results and Analysis

  

4 Conclusion

  Contrastive learning via cross-view node agreement.

 

Revision History

2022-03-28 Created this article
2022-06-12 Close reading

 

Paper Reading Series Index

