原文鏈接:https://stackoverflow.com/questions/59027150/keras-training-freezes-during-fit-generator
一般來說,我們可以使用Keras包中的fit函數進行模型的訓練,其參數如下:
Model.fit(
x=None,
y=None,
batch_size=None,
epochs=1,
verbose="auto",
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
)
其中x和validation_data可以是事先加載好的數據組成的tuple——(inputs, targets),也可以是根據Keras相關API構建的Data Generator(包括ImageDataGenerator、keras.utils.Sequence等),在訓練的過程中,這些數據會按設定好的batch_size被喂給模型,從而完成train和evaluate。
當我們在使用generator向模型中輸入數據的時候,在部分高版本的Keras(>2.0.0)中可能會出現第一個epoch訓練結束,但是evalute過程不結束,表現為第一個epoch卡住的情況。
根據相關資料和筆者自身經驗,強烈建議在調用fit函數時,顯式地指出step_per_epoch和validation_steps的值,從而解決epoch卡住無法結束的問題
