如何解决图神经网络(GNN)训练中过度平滑的问题?
转自知乎
https://www.zhihu.com/question/346942899/answer/835292740
泻药..首先要搞清楚图神经网络不能加深的原因是什么。常见的原因有三种:1)数据集太小,overfitting的问题,在一些数据上training acc为100%的大概率是这个问题,需要通过防止过拟合的技术来解决 2)vanishing gradient,这是CNN里一样存在的问题,当层数太深导致网络的参数不能得到有效的训练。这个问题可以加skip connections可以有效解决 3)over smoothing
其他同学@也提到了我们ICCV Oral的工作:DeepGCNs,这个工作主要是解决了vanishing gradient和over smoothing的问题,最开始是在点云上做的实验,正在做的TPAMI版本我们把14层的图网络MRConv用到了PPI数据,达到了F1 score 99.4的效果,是目前的start-of-the-art。PPI部分的实验代码近期会开源。
点云实验的代码、论文、slides都已开源。论文还有很多可以改善的地方,我们也还在做一些后续工作,欢迎交流:
Arxiv paper:
DeepGCNs: Can GCNs Go as Deep as CNNs?Github:
Tensorflow:
lightaime/deep_gcnsPytorch:
lightaime/deep_gcns_torch
都说GNN实际是个热传导,所以如果导热率太高,时间太长,最终就是温度达到单一温度。所以要降低导热率,或者缩短传导时间,才能形成有局部特征的分布模式。从消息传递的角度,就是要增加势能函数的差异性,或者说是降低系统温度,以及减少消息传递的循环次数。
更正一下题目中的几个小误区:
原题:如何解决 图神经网络(GNN)训练中过度平滑的问题?即在图神经网络的训练过程中,随着网络层数的增加和迭代次数的增加, 每个节点的隐层表征会趋向于收敛到同一个值(即空间上的同一个位置)。
不是所有图神经网络都有 over-smooth 的问题,例如,基于 RandomWalk + RNN、基于 Attention 的模型大多不会有这个问题,是可以放心叠深度的~只有部分图卷积神经网络会有该问题。
不是每个节点的表征都趋向于收敛到同一个值,更准确的说,是同一连通分量内的节点的表征会趋向于收敛到同一个值。这对表征图中不通簇的特征、表征图的特征都有好处。但是,有很多任务的图是连通图,只有一个连通分量,或较少的连通分量,这就导致了节点的表征会趋向于收敛到一个值或几个值的问题。
注:在图论中,无向图的 连通分量是一个子图,其中任何两个顶点通过路径相互连接。
为什么 GCN 中会存在 over-smooth 的问题
首先,回顾一下全连接神经网络和 Kipf 图卷积神经网络的公式:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1GQ04lM0QlNUNzaWdtYSUyOFhXJTI5JTVDJTVDK0dDTiUzRCU1Q3NpZ21hJTI4JTVDaGF0JTdCRCU3RCU1RSU3Qi0lNUNmcmFjJTdCMSU3RCU3QjIlN0QlN0QlNUNoYXQlN0JBJTdEJTVDaGF0JTdCRCU3RCU1RSU3Qi0lNUNmcmFjJTdCMSU3RCU3QjIlN0QlN0RYVyUyOSs=.png)
其中,
为激活函数,
为节点特征,
为训练参数,
,
为邻接矩阵,
,
为图中的所有节点。可以发现图卷积神经网络只多了对节点信息进行汇聚的权重
。从
(无归一化)到
(归一化),再到
(对称归一化),对于该权重的研究已然汗牛充栋。
学有余力的同学可以往下看通式上 over-smooth 的证明,这里先以
为例,进行一个直观的解释:
首先,中间层的
由任务相关的
反向传播进行优化,可以理解为任务相关的模式提取能力,我们将其统一在图卷积后进行,多层卷积公式可以近似为:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1HQ04lM0RmJTI4JTI4JTVDaGF0JTdCRCU3RCU1RSU3Qi0xJTdEJTVDaGF0JTdCQSU3RCUyOSU1RWtYJTI5JTVDJTVD.png)
其中,
可以看作被提取的多个隐藏层。化简该式:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1HQ04lM0RmJTI4JTVDZnJhYyU3QiU1Q2hhdCU3QkElN0QlNUVrJTdEJTdCJTVDaGF0JTdCRCU3RCU1RWslN0RYJTI5JTVDJTVD.png)
其中,邻接矩阵的幂,
表示节点
和节点
之间长度为
的 walk 的数量。而它的度,
代表节点
到所有节点之间长度为
的 walk 的数量。
这时,
则代表以节点
为起点,随机完成
步的 walk 最后抵达节点
的概率。
随着 walk 步数的增多,远距离节点的抵达难度越来越小,被随机选中的概率越来越大。当
时,连通分量中的节点
到达连通分量中任意节点的概率都趋于一致,为
,其中
代表连通分量中节点的总数,即
,其中
、
代表连通分量的邻接矩阵和度矩阵。
令连通分量中的特征向量为
,且
,
代表连通分量中节点的特征维度。节点信息的汇聚可以表示为:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUIlNUNmcmFjJTdCMSU3RCU3Qm4lN0QlNURfJTdCbiU1Q3RpbWVzK24lN0R4XyU3Qm4lNUN0aW1lcytxJTdEJTVDJTVD.png)
连通分量中每个节点的特征都为所有节点特征的平均,也就是我们开始的时候说的,同一连通分量内的节点的表征趋向于收敛到同一个值。
在感性地认识到图卷积与连通分量之间的关联后,有的工作想到利用特征分解(特征向量对应连通分量)给出 over-smooth 定理的证明[1]:
over-smooth 定理:假设图
由
个连通分量
构成,其中第
个连通分量可以用向量
表示:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1sJTVFJTdCJTI4aSUyOSU3RF9qJTNEJTVDbGVmdCU1QyU3QislNUNiZWdpbiU3QmFsaWduZWQlN0QrMSslMjYlMkMrdl9qKyU1Q2luK0NfaSslNUMlNUMrMCslMjYlMkMrdl9pKyU1Q25vdGluK0NfaSslNUMlNUMrJTVDZW5kJTdCYWxpZ25lZCU3RCslNUNyaWdodC4lNUMlNUM=.png)
那么,当图中不存在二分连通分量时,有:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNsaW0rXyU3Qm0rJTVDcmlnaHRhcnJvdyUyQiU1Q2luZnR5JTdEJTVDbGVmdCUyOEktJTVDYWxwaGErTF8lN0JyK3clN0QlNUNyaWdodCUyOSU1RSU3Qm0lN0QrVyUzRCU1Q2xlZnQlNUIlNUNtYXRoYmYlN0IxJTdEJTVFJTdCJTI4MSUyOSU3RCUyQyslNUNtYXRoYmYlN0IxJTdEJTVFJTdCJTI4MiUyOSU3RCUyQyslNUNsZG90cyUyQyslNUNtYXRoYmYlN0IxJTdEJTVFJTdCJTI4ayUyOSU3RCU1Q3JpZ2h0JTVEKyU1Q3RoZXRhXyU3QjElN0QlNUMlNUMrJTVDbGltK18lN0JtKyU1Q3JpZ2h0YXJyb3clMkIlNUNpbmZ0eSU3RCU1Q2xlZnQlMjhJLSU1Q2FscGhhK0xfJTdCcyt5K20lN0QlNUNyaWdodCUyOSU1RSU3Qm0lN0QrVyUzREQlNUUlN0ItJTVDZnJhYyU3QjElN0QlN0IyJTdEJTdEJTVDbGVmdCU1QiU1Q21hdGhiZiU3QjElN0QlNUUlN0IlMjgxJTI5JTdEJTJDKyU1Q21hdGhiZiU3QjElN0QlNUUlN0IlMjgyJTI5JTdEJTJDKyU1Q2xkb3RzJTJDKyU1Q21hdGhiZiU3QjElN0QlNUUlN0IlMjhrJTI5JTdEJTVDcmlnaHQlNUQrJTVDdGhldGFfJTdCMiU3RA==.png)
其中,
和
表示线性组合
的系数,且:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lN0JMXyU3QiU1Q21hdGhybSU3QnJ3JTdEJTdEKyUzREQlNUUlN0ItMSU3RCtMJTNESS1EJTVFJTdCLTElN0QrVyU3RCU1QyU1QyslN0JMXyU3QiU1Q21hdGhybSU3QnN5bSU3RCU3RCslM0REJTVFJTdCLTErJTJGKzIlN0QrTCtEJTVFJTdCLTErJTJGKzIlN0QlM0RJLUQlNUUlN0ItMSslMkYrMiU3RCtXK0QlNUUlN0ItMSslMkYrMiU3RCU3RCslNUMlNUMr.png)
本想写自己的证明过程,但由于篇幅较长喧宾夺主,有机会再贴~
如何解决 over-smooth 的问题
在了解为什么 GCN 中会存在 over-smooth 问题后,剩下的工作就是对症下药了:
问题:图卷积会使同一连通分量内的节点的表征会趋向于收敛到同一个值。
- 针对“图卷积”:在当前任务上,是否能够使用 RNN + RandomWalk(数据为图结构,边已然存在)或是否能够使用 Attention(数据为流形结构,边不存在,但含有隐式的相关关系)?
- 针对“同一连通分量内的节点”:在当前任务上,是否可以对图进行 cut 等预处理?如果可以,将图分为越多的连通分量,over-smooth 就会越不明显。极端情况下,节点都不相互连通,则完全不存在 over-smooth 现象(但也无法获取周围节点的信息)。
如果上述方法均不适用,仍有以下 deeper 和 wider 的措施可以保证 GCN 在过参数化时对模型的训练和拟合不产生负面影响。个人感觉,这类方法的实质是不同深度的 GCN 模型的 ensamble:
巨人肩膀上的模型深度 —— residual 等
Kipf 在提出 GCN 时,就发现了添加更多的卷积层似乎无法提高图模型的效果,并通过试验将其归因于 over-smooth:多层 GCN 可能导致节点趋同化,没有区别性。但是,早期的研究认为这是由 GCN 过分强调了相邻节点的关联而忽视了节点自身的特点导致的。 所以 Kipf 给出的解决方案是添加残差连接[2],将节点自身特点从上一层直接传输到下一层:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1aJTVFJTdCJTI4bCUyQjElMjklN0QlM0QlNUNoYXQlN0JEJTdEJTVFJTdCLSU1Q2ZyYWMlN0IxJTdEJTdCMiU3RCU3RCU1Q2hhdCU3QkElN0QlNUNoYXQlN0JEJTdEJTVFJTdCLSU1Q2ZyYWMlN0IxJTdEJTdCMiU3RCU3RFglNUUlN0IlMjhsJTI5JTdEK1clNUUlN0IlMjhsJTI5JTdEJTJDJTVDJTVDWCU1RSU3QiUyOGwlMkIxJTI5JTdEJTNEJTVDc2lnbWElNUNsZWZ0JTI4WiU1RSU3QiUyOGwlMkIxJTI5JTdEJTVDcmlnaHQlMjklMkJYJTVFJTdCJTI4bCUyOSU3RCU1QyU1Qw==.png)
在这个思路下,陆续有工作借鉴 DenseNet,将 residual 连接替换为 dense 连接,提出了自己的 module [3][4]:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1YJTVFJTdCJTI4bCUyQjElMjklN0QlM0QlNUNtYXRoY2FsJTdCVCU3RCUyOCU1Q3NpZ21hJTVDbGVmdCUyOFolNUUlN0IlMjhsJTJCMSUyOSU3RCU1Q3JpZ2h0JTI5JTJDWCU1RSU3QiUyOGwlMjklN0QlMjklNUMlNUM=.png)
其中,
表示拼接节点的特征向量。
最近,也有些工作认为直接将使用残差连接矫枉过正,残差模块完全忽略了相邻节点的权重,因而选择在
的基础上,对节点自身进行加强[5]:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1YJTVFJTdCJTI4bCUyQjElMjklN0QlM0QlNUNzaWdtYSU1Q2xlZnQlMjglNUNsZWZ0JTI4JTVDaGF0JTdCRCU3RCU1RSU3Qi0lNUNmcmFjJTdCMSU3RCU3QjIlN0QlN0QlNUNoYXQlN0JBJTdEJTVDaGF0JTdCRCU3RCU1RSU3Qi0lNUNmcmFjJTdCMSU3RCU3QjIlN0QlN0QlMkJJJTVDcmlnaHQlMjkrWCU1RSU3QiUyOGwlMjklN0QrVyU1RSU3QiUyOGwlMjklN0QlNUNyaWdodCUyOSU1QyU1Qw==.png)
在此基础上,作者进一步考虑了相邻节点的数量,提出了新的正则化方法:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1YJTVFJTdCJTI4bCUyQjElMjklN0QlM0QlNUNzaWdtYSU1Q2xlZnQlMjglMjglNUNoYXQlN0JEJTdEJTVFJTdCLTElN0QlNUNoYXQlN0JBJTdEJTJCJTVDbGFtYmRhKyU1Q29wZXJhdG9ybmFtZSU3QmRpYWclN0QlMjglNUNoYXQlN0JEJTdEJTVFJTdCLTElN0QlNUNoYXQlN0JBJTdEJTI5K1glNUUlN0IlMjhsJTI5JTdEK1clNUUlN0IlMjhsJTI5JTdEJTVDcmlnaHQlMjklNUMlNUM=.png)
另辟蹊径的模型宽度 —— multi-hops 等
随着图卷积渗透到各个领域,一些研究开始放弃深度上的拓展,选择效仿 Inception 的思路拓宽网络的宽度,通过不同尺度感受野的组合对提高模型对节点的表征能力。N-GCN[6]通过在不同尺度下进行卷积,再融合所有尺度的卷积结果得到节点的特征表示:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1OLUdDTiUzRHNvZnRtYXglMjglNUNtYXRoY2FsJTdCVCU3RCU1Q2xlZnQlMjglNUNiZWdpbiU3Qm1hdHJpeCU3RCtHQ04lMjglNUNiYXIrQSU1RTAlMkNYJTNCJTVDdGhldGElNUUlN0IlMjgwJTI5JTdEJTI5JTVDJTVDK0dDTiUyOCU1Q2JhcitBJTVFMSUyQ1glM0IlNUN0aGV0YSU1RSU3QiUyODElMjklN0QlMjklNUMlNUMrLi4uJTVDJTVDK0dDTiUyOCU1Q2JhcitBJTVFbiUyQ1glM0IlNUN0aGV0YSU1RSU3QiUyOG4lMjklN0QlMjkrJTVDZW5kJTdCbWF0cml4JTdEJTVDcmlnaHQlMjlXJTI5JTVDJTVD.png)
其中,
,
表示拼接节点的特征向量。原文中尝试了
和
等不同的归一化方法对当前节点
阶临域的进行信息汇聚,取得了还不错的效果。
也有一些工作认为 GCN 的各层的卷积结果是一个有序的序列:对于一个
层的 GCN,第
层捕获了
-hop 邻居节点的信息,其中
,相邻层
和
之间有依赖关系。因而,这类方法选择使用 RNN 对各层之间的长期依赖建模[7]:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1YJTVFJTdCJTI4bCUyQjElMjklN0QlM0RSK04rTiU1Q2xlZnQlMjhHQ04lNUNsZWZ0JTI4WCU1RSU3QiUyOGwlMjklN0QlMkMrQSslM0IrJTVDdGhldGElNUUlN0IlMjhsJTI5JTdEJTVDcmlnaHQlMjklMkMrWCU1RSU3QiUyOGwlMjklN0QlNUNyaWdodC4lMjklNUMlNUM=.png)
即为:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrWCU1RSU3QiUyOGwlMkIxJTI5JTdEKyUyNiUzRCU1Q2hhdCU3QkQlN0QlNUUlN0ItJTVDZnJhYyU3QjElN0QlN0IyJTdEJTdEKyU1Q2hhdCU3QkElN0QrJTVDaGF0JTdCRCU3RCU1RSU3Qi0lNUNmcmFjJTdCMSU3RCU3QjIlN0QlN0QrWCU1RSU3QiUyOGwlMjklN0QrVyU1RSU3QiUyOGwlMjklN0QrJTVDJTVDK0klNUUlN0IlMjhsJTJCMSUyOSU3RCslMjYlM0QlNUNzaWdtYSU1Q2xlZnQlMjhYJTVFJTdCJTI4bCUyQjElMjklN0QrV18lN0JpJTdEJTJCWCU1RSU3QiUyOGwlMjklN0QrVV8lN0JpJTdEJTJCJTVDbGVmdCU1QmJfJTdCaSU3RCU1Q3JpZ2h0JTVEXyU3Qk4lN0QlNUNyaWdodCUyOSslNUMlNUMrRiU1RSU3QiUyOGwlMkIxJTI5JTdEKyUyNiUzRCU1Q3NpZ21hJTVDbGVmdCUyOFglNUUlN0IlMjhsJTJCMSUyOSU3RCtXXyU3QmYlN0QlMkJYJTVFJTdCJTI4bCUyOSU3RCtVXyU3QmYlN0QlMkIlNUNsZWZ0JTVCYl8lN0JmJTdEJTVDcmlnaHQlNURfJTdCTiU3RCU1Q3JpZ2h0JTI5KyU1QyU1QytPJTVFJTdCJTI4bCUyQjElMjklN0QrJTI2JTNEJTVDc2lnbWElNUNsZWZ0JTI4WCU1RSU3QiUyOGwlMkIxJTI5JTdEK1dfJTdCbyU3RCUyQlglNUUlN0IlMjhsJTI5JTdEK1VfJTdCbyU3RCUyQiU1Q2xlZnQlNUJiXyU3Qm8lN0QlNUNyaWdodCU1RF8lN0JOJTdEJTVDcmlnaHQlMjkrJTVDJTVDKyU1Q2hhdCU3QkMlN0QlNUUlN0IlMjhsJTJCMSUyOSU3RCslMjYlM0QlNUN0YW5oKyU1Q2xlZnQlMjhYJTVFJTdCJTI4bCUyQjElMjklN0QrV18lN0JjJTdEJTJCWCU1RSU3QiUyOGwlMjklN0QrVV8lN0JjJTdEJTJCJTVDbGVmdCU1QmJfJTdCYyU3RCU1Q3JpZ2h0JTVEXyU3Qk4lN0QlNUNyaWdodCUyOSslNUMlNUMrQyU1RSU3QiUyOGwlMkIxJTI5JTdEKyUyNiUzREYlNUUlN0IlMjhsJTJCMSUyOSU3RCslNUNjaXJjK0MlNUUlN0JsJTdEJTJCSSU1RSU3QiUyOGwlMkIxJTI5JTdEKyU1Q2NpcmMrJTVDaGF0JTdCQyU3RCU1RSU3QiUyOGwlMkIxJTI5JTdEKyU1QyU1QytYJTVFJTdCJTI4bCUyQjElMjklN0QrJTI2JTNETyU1RSU3QiUyOGwlMkIxJTI5JTdEKyU1Q2NpcmMrJTVDdGFuaCslNUNsZWZ0JTI4QyU1RSU3QiUyOGwlMkIxJTI5JTdEJTVDcmlnaHQlMjkrJTVDZW5kJTdCYWxpZ25lZCU3RCU1QyU1Qw==.png)
随着图卷积的日益成熟,深层的图卷积已经在各个领域开花结果啦~ 相信在不久的将来,pruning 和 NAS 还会碰撞出新的火花,童鞋们加油呀!另外,有的同学私信想看我的论文中是怎样处理 over-smooth 的~可是由于写作技巧太差我的论文还没发粗去(最开始导师都看不懂我写的是啥,感谢一路走来没有放弃我的导师和师兄,现在已经勉强能看了),等以后有机会再分享叭~

