用Pytorch1.0進行半精度浮點型網絡訓練需要注意下問題:
1、網絡要在GPU上跑,模型和輸入樣本數據都要cuda().half()
2、模型參數轉換為half型,不必索引到每層,直接model.cuda().half()即可
3、對於半精度模型,優化算法,Adam我在使用過程中,在某些參數的梯度為0的時候,更新權重后,梯度為零的權重變成了NAN,這非常奇怪,但是Adam算法對於全精度數據類型卻沒有這個問題。
另外,SGD算法對於半精度和全精度計算均沒有問題。
還有一個問題是不知道是不是網絡結構比較小的原因,使用半精度的訓練速度還沒有全精度快。這個值得后續進一步探索。
對於上面的這個問題,的確是網絡很小的情況下,在1080Ti上半精度浮點型沒有很明顯的優勢,但是當網絡變大之后,半精度浮點型要比全精度浮點型要快。但具體快多少和模型的大小以及輸入樣本大小有關系,我測試的是要快1/6,同時,半精度浮點型在占用內存上比較有優勢,對於精度的影響尚未探究。
將網絡再變大些,epoch的次數也增大,半精度和全精度的時間差就表現出來了,在訓練的時候。