原文鏈接:https://stackoverflow.com/questions/59027150/keras-training-freezes-during-fit-generator
一般來說,我們可以使用Keras
包中的fit
函數進行模型的訓練,其參數如下:
Model.fit(
x=None,
y=None,
batch_size=None,
epochs=1,
verbose="auto",
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
)
其中x
和validation_data
可以是事先加載好的數據組成的tuple
——(inputs, targets)
,也可以是根據Keras相關API構建的Data Generator(包括ImageDataGenerator
、keras.utils.Sequence
等),在訓練的過程中,這些數據會按設定好的batch_size被喂給模型,從而完成train和evaluate。
當我們在使用generator向模型中輸入數據的時候,在部分高版本的Keras(>2.0.0)
中可能會出現第一個epoch訓練結束,但是evalute過程不結束,表現為第一個epoch卡住的情況。
根據相關資料和筆者自身經驗,強烈建議在調用fit
函數時,顯式地指出step_per_epoch
和validation_steps
的值,從而解決epoch卡住無法結束的問題