While debugging TensorFlow distributed training code, I hit a puzzling error: Global step should be created to use StopAtStepHook.
The error was triggered by the following code:
    stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.total_steps)
    checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=train_dir,
                                                   save_steps=1000,
                                                   saver=saver)
    hooks = [stop_hook, checkpoint_hook]

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=is_chief,
                                           checkpoint_dir=None,
                                           hooks=hooks) as sess:
        while not sess.should_stop():
            _, step, acc = sess.run([optimizer, global_step, accuracy],
                                    feed_dict=feed)
The traceback ended with: RuntimeError: Global step should be created to use StopAtStepHook.
At first glance this means global_step was never defined. But the code does define one, so I went to read the source:
    class StopAtStepHook(session_run_hook.SessionRunHook):
      def begin(self):
        self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
          raise RuntimeError("Global step should be created to use StopAtStepHook.")
The error is raised in the begin method of StopAtStepHook. begin is called once while the monitored session is being set up, before any run call; its job is simply to fetch the global_step tensor.
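To make the timing concrete, here is a minimal sketch (TF 1.x; the PrintBeginHook class is just an illustration, not part of my code). MonitoredTrainingSession calls every hook's begin() while the session is being built, before the first run call, which is why the missing global step is reported immediately:

    import tensorflow as tf

    class PrintBeginHook(tf.train.SessionRunHook):
        """Hypothetical demo hook: shows when begin() fires."""
        def begin(self):
            # Called while MonitoredTrainingSession is being constructed,
            # before the graph is finalized and before any run() call.
            print("begin() called")

    tf.train.get_or_create_global_step()
    with tf.train.MonitoredTrainingSession(hooks=[PrintBeginHook()]) as sess:
        pass  # "begin() called" was already printed before this line ran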
Next I looked at training_util._get_or_create_global_step_read(). I won't walk through the code in detail; the upshot is that it locates the global step by looking up a variable named "global_step" (registered in the graph's GLOBAL_STEP collection), and the global_step variable I had defined myself was never given that name...
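A quick way to see this lookup in action (a sketch, assuming TF 1.x): tf.train.get_global_step() returns None for a hand-made, unnamed variable, but finds the one created by the standard helper.

    import tensorflow as tf

    # A plain variable: not named "global_step", not in the GLOBAL_STEP collection.
    g1 = tf.Graph()
    with g1.as_default():
        my_step = tf.Variable(0, trainable=False)
        print(tf.train.get_global_step())  # None -> StopAtStepHook would raise

    # The standard helper registers the variable under the expected name
    # and in the GLOBAL_STEP collection, so the hooks can find it.
    g2 = tf.Graph()
    with g2.as_default():
        step = tf.train.get_or_create_global_step()
        print(tf.train.get_global_step())  # the 'global_step' variable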
Lesson learned: for a variable like global_step, it is best to let the standard helper create it: tf.train.get_or_create_global_step().
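For reference, here is how the snippet above might look after the fix. This is only a sketch: AdamOptimizer and FLAGS.learning_rate stand in for whatever optimizer and hyperparameters the real code uses, and loss, accuracy, feed, server, is_chief, train_dir and saver are assumed to exist as before. The key points are creating the global step with the helper and passing it to minimize() so it actually gets incremented.

    global_step = tf.train.get_or_create_global_step()
    train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
        loss, global_step=global_step)  # the optimizer increments global_step each step

    stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.total_steps)
    checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=train_dir,
                                                   save_steps=1000,
                                                   saver=saver)

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=is_chief,
                                           checkpoint_dir=None,
                                           hooks=[stop_hook, checkpoint_hook]) as sess:
        while not sess.should_stop():
            _, step, acc = sess.run([train_op, global_step, accuracy],
                                    feed_dict=feed)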