1. torch.autograd.detect_anomaly()
轉自點擊 ,
import torch # 正向傳播時:開啟自動求導的異常偵測 torch.autograd.set_detect_anomaly(True) # 反向傳播時:在求導時開啟偵測 with torch.autograd.detect_anomaly(): loss.backward()
上面的代碼就會給出具體是哪句代碼求導出現的問題。
2.Debug
https://medium.com/@me_26124/debugging-neural-networks-6fa65742efd
- 通常在使用sqrt/exp的時候會出現非常大或非常小的數,從而導致溢出或者是除0,從而出現Nan值。