損失函數是通過keras已經封裝好的函數進行的線性組合, 如下:
def spares_mse_mae_2scc(y_true, y_pred):
return mean_squared_error(y_true, y_pred) + categorical_crossentropy(y_true, y_pred) + 2 * mean_absolute_error(y_true, y_pred)
在訓練的過程中出現Nan, 發現是因為使用categorical_crossentropy(交叉熵)函數是0出現在了log的位置, 是的出現log(0)的情況出現.
可能的原因:
1 學習率的原因, 可以適當降低學習率,並設置學習率衰減;
2 BatchNormlization原因, 可能在正則化的過程中出現大量的0.
3 數據不干凈
我所遇到的問題基本排除上面三種, 我的解決方法:
def mse_mae_2bcc(y_true, y_pred):
return mean_squared_error(y_true, y_pred) + binary_crossentropy(y_true, y_pred) + 2 * mean_absolute_error(y_true, y_pred)
這樣定義損失函數就可以直接避免這個問題, 原因還不太清楚, 有時間推導一下在補充.
參考:
1. https://stackoverflow.com/questions/33712178/tensorflow-nan-bug
2. https://oldpan.me/archives/careful-train-loss-nan-inf
3. https://blog.csdn.net/hahajinbu/article/details/84035486