神經切線核(NTK)基本介紹
Neural Tangent Kernel (NTK)是描述無限寬深度神經網絡在梯度下降過程中演化的核。最早由Jacot et al.[1] 在2018年發表於NIPS上。
NTK用來描述神經網絡參數的訓練過程。雖然NTK在初始化時是隨機的,並且在訓練過程中會發生變化,但在無限寬網絡中,NTK會收斂到顯式限制核,並且在訓練過程中保持不變。
NTK的直觀理解[2]
簡單例子
假設有一個函數定義在-10和20之間,函數\(f(i)\)在每一點\(i\)上的值可以被一個離散參數描述\(\theta_i=f(i)\),初始化為\(\theta_i=3i+2\).

我們觀察一個數據點\((x,y)=(10,50)\),顯示為圖中的藍色的×。之后使用一個step梯度下降更新\(\theta\),使用squared error loss function,學習率\(\eta=0.1\)。由於函數在\(x=10\)點只取決於參數\(\theta_{10}\),所以只有這個參數會被更新,其與參數保持不變。紅色箭頭顯示函數在一個step梯度下降中的移動方式。
線性函數
考慮一個線性函數\(f(x,\theta)=\theta_1x+\theta_2\),並將參數初始化為\(\theta_1=3,\theta_2=1\)。使用點\((x,y)=(10,50)\)更新參數,紅色箭頭顯示函數移動方式:

一次梯度下降的計算過程:
NTK
給定參數為\(\theta\)的函數\(f_\theta(x)\),其NTK \(k_\theta(x,x^\prime)\)衡量函數在加入一個新的觀測\(x^{\prime}\)時,使用無限小的梯度步長,參數值\(\theta\)在點\(x\)處變化大小。或者說,\(k(x,x^\prime)\)度量了函數在\(x\)處對\(x^\prime\)的預測誤差的敏感度。
Loss function的結果告訴我們需要增加函數值\(f_\theta(x^\prime)\),反向傳播得到需要對參數\(\theta\)改變多少。然而移動\(f_\theta(x^\prime)\)同時也會移動\(f_\theta(x)\)。\(\tilde{k}_{\theta}\left(x, x^{\prime}\right)\)表達了改變的大小。
對函數\(f\)使用一階泰勒展開處理\(f_\theta(x)\)得到:
重參數化的線性模型
將線性模型重新參數化為\(f_\theta(x)=\theta_1x+10\cdot \theta_2\),參數取值為\(\theta_1=3, \theta_2=0.1\),函數的種類和原來一樣,單步梯度下降時的效果變成了如圖所示:

可以看到函數參數的變化情況與之前有很大變化,可以看出NTK對於函數的重參數化很敏感。
微小徑向基函數網絡(tiny raidal basis function network)
之后使用一個非線性模型:
其中參數值\((\theta_1, \theta_2,\theta_3,\theta_4,\theta_5)=(4.0,-10.0, 25.0, 10.0, 50.0)\).

將\(\tilde{k}_{\theta}\left(x, 10\right)\)除以\(\tilde{k}_{\theta}\left(10, 10\right)\)進行歸一化,並進行可視化,得到圖如下:

這個圖看起來很想一個核函數,可以看到在\(x=7\)的時候,這個核函數取得最大值,說明函數\(f(7)\)的變化更多。
NTK在訓練的過程中是會改變的:

可以看到隨着訓練的進行,NTK變得越來越平坦。
那么NTK能干什么?
NTK在研究無限寬度的網絡的訓練過程就變得十分有用。
- 如果我們選擇合適的分布對網絡參數\(\theta_0\)隨機初始化,隨着寬度的增加,網絡的初始NTK\(k_{\theta_0}\)接近一個確定的核,說明與初始化無關。
- 在無限寬網絡中,\(k_{\theta_t}\)不會隨着\(\theta_t\)的優化而變化,消除了訓練期間的參數依賴性。
- Xiao et al.[3]研究了NTK的譜(spectrum)控制了網絡的可訓練性和通用性,但NTK常因寬度不足、使用較大的學習率或參數改變無法和神經網絡建立對應關系。
那么有限寬網絡
雖然NTK在無限寬網絡中有很好的性質,但在有限寬網絡中就會失效,而且實際中根本沒有無限寬網絡。所以NTK在有限寬網絡中有什么好的性質也很重要。
Lee et al.[4]發現對於足夠寬的深度神經網絡,學習動態會大大簡化,並且在無窮寬的條件下,網絡由在初始參數處的一階泰勒展開式線性模型主導。作者通過實驗發現對於有限寬度的神經網絡,由神經網絡得到的估計與線性模型得到的估計也是基本一致的,且這個一致性對於不同優化方法、不同損失函數都是成立的。所以可能NTK對於寬網絡的很多結論,對於有限寬網絡也是部分成立的。
結論
NTK是描述無限寬深度神經網絡在梯度下降過程中演化的核,用來描述神經網絡的訓練過程。體現的是在無限小步長下,神經網絡在某個數據點\(x^\prime\)處的觀測進行優化,網絡參數在另一點\(x\)處的變化大小。
NTK能夠衡量足夠寬的網絡的可訓練性和通用性,但是NTK在寬網絡下的結論,對於有限寬網絡,並不完全適用,那么如何更好的衡量有限寬網絡的能力?
Jacot A, Gabriel F, Hongler C. Neural tangent kernel: Convergence and generalization in neural networks[J]. arXiv preprint arXiv:1806.07572, 2018. ↩︎
Xiao L, Pennington J, Schoenholz S. Disentangling trainability and generalization in deep neural networks[C]//International Conference on Machine Learning. PMLR, 2020: 10462-10472. ↩︎
Lee J, Xiao L, Schoenholz S, et al. Wide neural networks of any depth evolve as linear models under gradient descent[J]. Advances in neural information processing systems, 2019, 32: 8572-8583. ↩︎
