While debugging some TensorFlow distributed training code, I ran into a strange error: Global step should be created to use StopAtStepHook.
The error was raised at the following code:
    stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.total_steps)
    checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=train_dir,
                                                   save_steps=1000,
                                                   saver=saver)
    hooks = [stop_hook, checkpoint_hook]
    # Bind the session with "as sess" so the loop below can use it.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=is_chief,
                                           checkpoint_dir=None,
                                           hooks=hooks) as sess:
        while not sess.should_stop():
            # Store the fetched value under a new name (acc) so the accuracy
            # tensor is not overwritten after the first iteration.
            _, step, acc = sess.run([optimizer, global_step, accuracy],
                                    feed_dict=feed)
The error message was: RuntimeError: Global step should be created to use StopAtStepHook.
At first glance this looks like global_step was never defined. But the code did define it, so I went to read the source:
    class StopAtStepHook(session_run_hook.SessionRunHook):
      def begin(self):
        self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
          raise RuntimeError("Global step should be created to use StopAtStepHook.")
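The error comes from the begin method of the StopAtStepHook class. This method is called before the session is used, and its job is to fetch global_step.
Next I looked at training_util._get_or_create_global_step_read(). I won't reproduce the code in detail; the conclusion is that it locates global_step by looking for a variable named "global_step" (or one registered in the tf.GraphKeys.GLOBAL_STEP collection), and the global_step I had defined myself was never given a name...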
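The lookup behaviour described above can be reproduced in isolation. The following is only a minimal sketch, assuming a TensorFlow install that provides tensorflow.compat.v1; the helper try_stop_hook and the two builder functions are hypothetical names made up for this illustration, and the unnamed variable is my guess at what the original definition looked like.

```python
import tensorflow.compat.v1 as tf


def try_stop_hook(make_step):
    """Build a fresh graph, create a step counter with make_step, then call
    StopAtStepHook.begin() the way a monitored session would before running."""
    with tf.Graph().as_default():
        make_step()
        # Public lookup: returns None when no global step can be found.
        found = tf.train.get_global_step()
        try:
            tf.train.StopAtStepHook(last_step=10).begin()
            error = None
        except RuntimeError as e:
            error = str(e)
    return found, error


def unnamed_step():
    # Assumed reproduction of the bug: no name= argument, so the variable
    # is not called "global_step" and is not in the GLOBAL_STEP collection.
    return tf.Variable(0, trainable=False, dtype=tf.int64)


def default_step():
    # The recommended way: creates (or reuses) the graph's global step.
    return tf.train.get_or_create_global_step()


bad_found, bad_error = try_stop_hook(unnamed_step)
good_found, good_error = try_stop_hook(default_step)

print("unnamed variable  :", bad_found, "|", bad_error)
print("get_or_create     :", good_found is not None, "|", good_error)
```

If the lookup works as described, the unnamed variable should not be found and begin() should raise the RuntimeError from the top of this post, while the variable created by tf.train.get_or_create_global_step() should be found without any error.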
Lesson learned: for a variable like global_step, it is best to create it the default way, with tf.train.get_or_create_global_step().
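As a rough illustration of that advice, here is a small single-machine sketch, again assuming tensorflow.compat.v1 is available. The toy model (a single variable w with a squared loss), the function name run_training, and the step count are all hypothetical. The distributed arguments from the original code (master, is_chief) are left out so the example can run locally.

```python
import tensorflow.compat.v1 as tf


def run_training(total_steps):
    """Train a toy model until StopAtStepHook requests a stop.

    Returns the number of sess.run calls that were made."""
    with tf.Graph().as_default():
        # Create the global step the default way so that hooks can find it.
        global_step = tf.train.get_or_create_global_step()

        # Hypothetical toy model: minimise w^2 with plain gradient descent.
        w = tf.get_variable("w", shape=[],
                            initializer=tf.constant_initializer(5.0))
        loss = tf.square(w)
        optimizer = tf.train.GradientDescentOptimizer(0.1)
        # Passing global_step makes each training step increment it by one.
        train_op = optimizer.minimize(loss, global_step=global_step)

        hooks = [tf.train.StopAtStepHook(last_step=total_steps)]
        iterations = 0
        with tf.train.MonitoredTrainingSession(checkpoint_dir=None,
                                               hooks=hooks) as sess:
            while not sess.should_stop():
                sess.run(train_op)
                iterations += 1
    return iterations


iterations = run_training(total_steps=5)
print("training stopped after", iterations, "iterations")
```

The important detail is that the same global_step tensor is passed to minimize(), since StopAtStepHook only stops the loop once that counter reaches last_step.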