鄭重聲明:原文參見標題,如有侵權,請聯系作者,將會撤銷發布!
35th Conference on Neural Information Processing Systems (NeurIPS 2021), Sydney, Australia. (同組工作)
Abstract
由於離散二元激活和復雜的時空動態,深度脈沖神經網絡(SNN)給基於梯度的方法帶來了優化困難。考慮到ResNet在深度學習中的巨大成功,使用殘差學習訓練深度SNN是很自然的。之前的Spiking ResNet模仿了ANN中的標准殘差塊,簡單地將ReLU激活層替換為脈沖神經元,存在退化問題,難以實現殘差學習。在本文中,我們提出了spike-element-wise (SEW) ResNet來實現深度SNN中的殘差學習。我們證明SEW ResNet可以輕松實現身份映射並克服Spiking ResNet的梯度消失/爆炸問題。我們在ImageNet、DVS Gesture和CIFAR10-DVS數據集上評估了我們的SEW ResNet,並表明SEW ResNet在准確性和時間步長上都優於最先進的直接訓練的SNN。此外,SEW ResNet可以通過簡單地添加更多層來獲得更高的性能,為訓練深度SNN提供了一種簡單的方法。據我們所知,這是第一次直接訓練超過100層的深度SNN成為可能。我們的代碼可在https://github.com/fangwei123456/Spike-Element-Wise-ResNet獲得。
1 Introduction
人工神經網絡(ANN)在許多任務中取得了巨大成功,包括圖像分類[28, 52, 55]、對象檢測[9, 34, 44]、機器翻譯[2]和游戲[37, 51]。ANN成功的關鍵因素之一是深度學習[29],它使用多層來學習具有多個抽象級別的數據表征。已經證明,較深的網絡在計算成本和泛化能力方面優於較淺的網絡[3]。由深度網絡表示的函數可能需要具有一個隱藏層的淺層網絡的指數數量的隱藏單元[38]。此外,網絡的深度與網絡在實際任務中的表現密切相關[52, 55, 27, 52]。然而,最近的證據[13, 53, 14]表明,隨着網絡深度的增加,准確度會飽和,然后迅速下降。為了解決這個退化問題,殘差學習[14, 15]被提出,並且殘差結構在"非常深"的網絡中被廣泛利用,實現了領先的性能[22, 59, 18, 57]。
脈沖神經網絡(SNN)被認為是ANN的潛在競爭對手,因為它們具有高生物合理性、事件驅動特性和低功耗[45]。最近,深度學習方法被引入到SNN中,並且深度SNN在一些簡單的分類數據集[56]中取得了與ANN相近的性能,但在復雜任務中仍然比ANN差,例如對ImageNet數據集進行分類[47]。為了獲得更高性能的 SNN,自然會探索更深的網絡結構,如ResNet。Spiking ResNet[25, 60, 21, 17, 49, 12, 30, 64, 48, 42, 43]作為ResNet的脈沖版本,是通過模仿ANNs中的殘差塊並用脈沖神經元替換ReLU激活層而提出的。從ANN轉換而來的Spiking ResNet在幾乎所有數據集上都實現了最先進的准確性,而直接訓練的Spiking ResNet尚未經過驗證可以解決退化問題。
在本文中,我們表明Spiking ResNet不適用於所有神經元模型來實現身份映射。即使滿足恆等映射條件,Spiking ResNet也存在梯度消失/爆炸的問題。因此,我們提出了Spike-Element-Wise (SEW) ResNet來實現SNN中的殘差學習。我們證明SEW ResNet可以輕松實現身份映射並同時克服梯度消失/爆炸問題。我們在靜態ImageNet數據集和神經形態DVS手勢數據集[1]、CIFAR10-DVS數據集[32]上評估了Spiking ResNet和SEW ResNet。實驗結果與我們的分析一致,表明較深的Spiking ResNet存在退化問題——較深的網絡比較淺的網絡具有更高的訓練損失,而SEW ResNet可以通過簡單地增加網絡深度來獲得更高的性能。此外,我們表明SEW ResNet在准確性和時間步長上都優於最先進的直接訓練的SNN。據我們所知,這是第一次探索直接訓練的超過100層的深度SNN。
2 Related Work
2.1 Learning Methods of Spiking Neural Networks
ANN到SNN的轉換(ANN2SNN)[20, 4, 46, 49, 12, 11, 6, 54, 33]和具有替代梯度的反向傳播[40]是獲得深度SNN的兩種主要方法。ANN2SNN方法首先用ReLU激活訓練ANN,然后通過用脈沖神經元替換ReLU並添加縮放操作(如權重歸一化和閾值平衡)將ANN轉換為SNN。最近的一些轉換方法已經使用VGG-16和ResNet[12, 11, 6, 33]實現了接近無損的精度。然而,轉換后的SNN需要更長的時間才能在精度上與原始ANN相媲美,因為轉換基於發放率編碼[46],這增加了SNN的延遲並限制了實際應用。反向傳播方法可以分為兩類[26]。第一類中的方法通過在模擬時間步長 [31, 19, 58, 50, 30, 40]上展開網絡來計算梯度,這類似於時間反向傳播(BPTT)的思想。由於與閾值觸發發放相關的梯度是不可微的,因此經常使用替代梯度。由替代方法訓練的SNN不僅限於發放率編碼,還可以應用於時間任務,例如對神經形態數據集進行分類[58, 8, 16]。第二種方法計算現有脈沖時間相對於脈沖時間的膜電位的梯度[5, 39, 24, 65, 63]。
2.2 Spiking Residual Structure
3 Methods
3.1 Spiking Neuron Model
3.2 Drawbacks of Spiking ResNet
3.3 Spike-Element-Wise ResNet
4 Experiments
4.1 ImageNet Classification
4.2 DVS Gesture Classification
4.3 CIFAR10-DVS Classification
5 Conclusion
A Appendix
A.1 Hyper-Parameters
A.2 Random Temporal Delete
A.3 Firing rates on DVS Gesture
A.4 Gradients in Spiking ResNet with Firing Rates
A.5 0/1 Gradients Experiments
A.6 Reproducibility