keras的Model支持兩種模式的訓練:
- 直接傳入數組,最終會調用train_array.py中的fit_loop()函數
- 直接傳入生成器,最終會調用train_generator.py中的fit_generator()函數
train_array.py文件只有三個函數,就是fit_loop()、predict_loop()、evaluate_loop()
train_generator.py文件只有三個函數,就是fit_generator()函數、predict_generator()、evaluate_generator()
這兩個函數非常相似,參數也都差不多。它們都接受一個model參數,它的類型是一個Model實例。因為這兩個函數比較重量級,所以從Model中拆分出來單獨作為一個文件,這兩個文件可以說是keras的心臟。這兩個文件把模型、損失、評測指標、回調等組件結合起來,仿佛各條小溪在此處匯聚成大河,仿佛各個樂器在此處齊鳴奏出交響樂。只要看懂這兩個文件,keras可以說是懂了半壁江山。以這兩個文件為出發點,順藤摸瓜按圖索驥就能夠把keras的各個模塊、各個部件理得很清。
fit_loop()有steps_per_epoch和batch_size兩個參數,但是這兩個參數不能同時指定。因為fit_loop傳入了全部數據,所以樣本總數是確定的。steps_per_epoch*batch_size
應該近似等於樣本總數。也就是說steps_per_epoch和batch_size這兩個變量在樣本總數已知的情況下可以互相推出。
fit_loop()函數適用場景包括:
- 樣本數可以全部加載到內存
- 樣本長度統一且固定
fit_generator()相比fit_loop()要靈活很多,但是用起來卻需要額外的步驟。它需要提供一個生成器,這個生成器應該是一個無窮無盡的生成器,也就是說它應該始終源源不斷地產生數據,通過steps_per_epoch來指明每個輪次包含的樣本數。對於生成器每次返回數據,樣本個數即為batch_size。生成器每次返回的數據batch_size是可以參差不齊的。