原文链接:https://stackoverflow.com/questions/59027150/keras-training-freezes-during-fit-generator
一般来说,我们可以使用Keras包中的fit函数进行模型的训练,其参数如下:
Model.fit(
x=None,
y=None,
batch_size=None,
epochs=1,
verbose="auto",
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
)
其中x和validation_data可以是事先加载好的数据组成的tuple——(inputs, targets),也可以是根据Keras相关API构建的Data Generator(包括ImageDataGenerator、keras.utils.Sequence等),在训练的过程中,这些数据会按设定好的batch_size被喂给模型,从而完成train和evaluate。
当我们在使用generator向模型中输入数据的时候,在部分高版本的Keras(>2.0.0)中可能会出现第一个epoch训练结束,但是evalute过程不结束,表现为第一个epoch卡住的情况。
根据相关资料和笔者自身经验,强烈建议在调用fit函数时,显式地指出step_per_epoch和validation_steps的值,从而解决epoch卡住无法结束的问题
