如何解決圖神經網絡(GNN)訓練中過度平滑的問題?
轉自知乎
https://www.zhihu.com/question/346942899/answer/835292740
瀉葯..首先要搞清楚圖神經網絡不能加深的原因是什么。常見的原因有三種:1)數據集太小,overfitting的問題,在一些數據上training acc為100%的大概率是這個問題,需要通過防止過擬合的技術來解決 2)vanishing gradient,這是CNN里一樣存在的問題,當層數太深導致網絡的參數不能得到有效的訓練。這個問題可以加skip connections可以有效解決 3)over smoothing
其他同學@也提到了我們ICCV Oral的工作:DeepGCNs,這個工作主要是解決了vanishing gradient和over smoothing的問題,最開始是在點雲上做的實驗,正在做的TPAMI版本我們把14層的圖網絡MRConv用到了PPI數據,達到了F1 score 99.4的效果,是目前的start-of-the-art。PPI部分的實驗代碼近期會開源。
點雲實驗的代碼、論文、slides都已開源。論文還有很多可以改善的地方,我們也還在做一些后續工作,歡迎交流:
Arxiv paper:
DeepGCNs: Can GCNs Go as Deep as CNNs?Github:
Tensorflow:
lightaime/deep_gcnsPytorch:
lightaime/deep_gcns_torch
都說GNN實際是個熱傳導,所以如果導熱率太高,時間太長,最終就是溫度達到單一溫度。所以要降低導熱率,或者縮短傳導時間,才能形成有局部特征的分布模式。從消息傳遞的角度,就是要增加勢能函數的差異性,或者說是降低系統溫度,以及減少消息傳遞的循環次數。
更正一下題目中的幾個小誤區:
原題:如何解決 圖神經網絡(GNN)訓練中過度平滑的問題?即在圖神經網絡的訓練過程中,隨着網絡層數的增加和迭代次數的增加, 每個節點的隱層表征會趨向於收斂到同一個值(即空間上的同一個位置)。
不是所有圖神經網絡都有 over-smooth 的問題,例如,基於 RandomWalk + RNN、基於 Attention 的模型大多不會有這個問題,是可以放心疊深度的~只有部分圖卷積神經網絡會有該問題。
不是每個節點的表征都趨向於收斂到同一個值,更准確的說,是同一連通分量內的節點的表征會趨向於收斂到同一個值。這對表征圖中不通簇的特征、表征圖的特征都有好處。但是,有很多任務的圖是連通圖,只有一個連通分量,或較少的連通分量,這就導致了節點的表征會趨向於收斂到一個值或幾個值的問題。
注:在圖論中,無向圖的 連通分量是一個子圖,其中任何兩個頂點通過路徑相互連接。
為什么 GCN 中會存在 over-smooth 的問題
首先,回顧一下全連接神經網絡和 Kipf 圖卷積神經網絡的公式:
其中, 為激活函數,
為節點特征,
為訓練參數,
,
為鄰接矩陣,
,
為圖中的所有節點。可以發現圖卷積神經網絡只多了對節點信息進行匯聚的權重
。從
(無歸一化)到
(歸一化),再到
(對稱歸一化),對於該權重的研究已然汗牛充棟。
學有余力的同學可以往下看通式上 over-smooth 的證明,這里先以 為例,進行一個直觀的解釋:
首先,中間層的 由任務相關的
反向傳播進行優化,可以理解為任務相關的模式提取能力,我們將其統一在圖卷積后進行,多層卷積公式可以近似為:
其中, 可以看作被提取的多個隱藏層。化簡該式:
其中,鄰接矩陣的冪, 表示節點
和節點
之間長度為
的 walk 的數量。而它的度,
代表節點
到所有節點之間長度為
的 walk 的數量。
這時, 則代表以節點
為起點,隨機完成
步的 walk 最后抵達節點
的概率。
隨着 walk 步數的增多,遠距離節點的抵達難度越來越小,被隨機選中的概率越來越大。當 時,連通分量中的節點
到達連通分量中任意節點的概率都趨於一致,為
,其中
代表連通分量中節點的總數,即
,其中
、
代表連通分量的鄰接矩陣和度矩陣。
令連通分量中的特征向量為 ,且
,
代表連通分量中節點的特征維度。節點信息的匯聚可以表示為:
連通分量中每個節點的特征都為所有節點特征的平均,也就是我們開始的時候說的,同一連通分量內的節點的表征趨向於收斂到同一個值。
在感性地認識到圖卷積與連通分量之間的關聯后,有的工作想到利用特征分解(特征向量對應連通分量)給出 over-smooth 定理的證明[1]:
over-smooth 定理:假設圖 由
個連通分量
構成,其中第
個連通分量可以用向量
表示:
那么,當圖中不存在二分連通分量時,有:
其中, 和
表示線性組合
的系數,且:
本想寫自己的證明過程,但由於篇幅較長喧賓奪主,有機會再貼~
如何解決 over-smooth 的問題
在了解為什么 GCN 中會存在 over-smooth 問題后,剩下的工作就是對症下葯了:
問題:圖卷積會使同一連通分量內的節點的表征會趨向於收斂到同一個值。
- 針對“圖卷積”:在當前任務上,是否能夠使用 RNN + RandomWalk(數據為圖結構,邊已然存在)或是否能夠使用 Attention(數據為流形結構,邊不存在,但含有隱式的相關關系)?
- 針對“同一連通分量內的節點”:在當前任務上,是否可以對圖進行 cut 等預處理?如果可以,將圖分為越多的連通分量,over-smooth 就會越不明顯。極端情況下,節點都不相互連通,則完全不存在 over-smooth 現象(但也無法獲取周圍節點的信息)。
如果上述方法均不適用,仍有以下 deeper 和 wider 的措施可以保證 GCN 在過參數化時對模型的訓練和擬合不產生負面影響。個人感覺,這類方法的實質是不同深度的 GCN 模型的 ensamble:
巨人肩膀上的模型深度 —— residual 等
Kipf 在提出 GCN 時,就發現了添加更多的卷積層似乎無法提高圖模型的效果,並通過試驗將其歸因於 over-smooth:多層 GCN 可能導致節點趨同化,沒有區別性。但是,早期的研究認為這是由 GCN 過分強調了相鄰節點的關聯而忽視了節點自身的特點導致的。 所以 Kipf 給出的解決方案是添加殘差連接[2],將節點自身特點從上一層直接傳輸到下一層:
在這個思路下,陸續有工作借鑒 DenseNet,將 residual 連接替換為 dense 連接,提出了自己的 module [3][4]:
其中, 表示拼接節點的特征向量。
最近,也有些工作認為直接將使用殘差連接矯枉過正,殘差模塊完全忽略了相鄰節點的權重,因而選擇在 的基礎上,對節點自身進行加強[5]:
在此基礎上,作者進一步考慮了相鄰節點的數量,提出了新的正則化方法:
另辟蹊徑的模型寬度 —— multi-hops 等
隨着圖卷積滲透到各個領域,一些研究開始放棄深度上的拓展,選擇效仿 Inception 的思路拓寬網絡的寬度,通過不同尺度感受野的組合對提高模型對節點的表征能力。N-GCN[6]通過在不同尺度下進行卷積,再融合所有尺度的卷積結果得到節點的特征表示:
其中, ,
表示拼接節點的特征向量。原文中嘗試了
和
等不同的歸一化方法對當前節點
階臨域的進行信息匯聚,取得了還不錯的效果。
也有一些工作認為 GCN 的各層的卷積結果是一個有序的序列:對於一個 層的 GCN,第
層捕獲了
-hop 鄰居節點的信息,其中
,相鄰層
和
之間有依賴關系。因而,這類方法選擇使用 RNN 對各層之間的長期依賴建模[7]:
即為:
隨着圖卷積的日益成熟,深層的圖卷積已經在各個領域開花結果啦~ 相信在不久的將來,pruning 和 NAS 還會碰撞出新的火花,童鞋們加油呀!另外,有的同學私信想看我的論文中是怎樣處理 over-smooth 的~可是由於寫作技巧太差我的論文還沒發粗去(最開始導師都看不懂我寫的是啥,感謝一路走來沒有放棄我的導師和師兄,現在已經勉強能看了),等以后有機會再分享叭~