TensorFlow: troubleshooting a global step problem


  While debugging TensorFlow distributed training code, I ran into a strange error: Global step should be created to use StopAtStepHook.

 

  The error was raised at the following code:

  

# Stop once the global step reaches FLAGS.total_steps.
stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.total_steps)
# Save a checkpoint every 1000 steps.
checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=train_dir, save_steps=1000, saver=saver)
hooks = [stop_hook, checkpoint_hook]
with tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief, checkpoint_dir=None, hooks=hooks) as sess:
    while not sess.should_stop():
        # Fetch into a separate name so the accuracy tensor is not overwritten.
        _, step, acc = sess.run([optimizer, global_step, accuracy], feed_dict=feed)

  The error message was: RuntimeError: Global step should be created to use StopAtStepHook.

  At first glance this means global_step was never defined. My code did define it, though, so I went to read the source:

  

class StopAtStepHook(session_run_hook.SessionRunHook):
    def begin(self):
        self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use StopAtStepHook.")

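  The error is raised in the begin method of the StopAtStepHook class. This method is called once before the session is used, and its job here is to obtain global_step.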

 

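  Next I looked at training_util._get_or_create_global_step_read(). I won't reproduce the code here, but the conclusion is that it does not create the variable for you. It only looks for an existing global step in the graph: first in the GLOBAL_STEP collection, then by the tensor name "global_step:0". The global_step I had defined myself had not been given that name (and was not in the collection), so the lookup returned None.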
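  The lookup behaviour can be checked in isolation. The following is a minimal sketch, not the code from my training script: it assumes the public helper tf.compat.v1.train.get_global_step (tf.train.get_global_step under TF 1.x) performs the same search that the hook relies on, and the function name find_step_counter is made up for illustration.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API; under TF 1.x use `tf` directly


def find_step_counter(register):
    """Create a step counter and report what the global-step lookup finds.

    `find_step_counter` is a hypothetical helper for this illustration.
    """
    with tf.Graph().as_default() as graph:
        collections = [tf1.GraphKeys.GLOBAL_VARIABLES]
        if register:
            # Registering the variable here is what makes it discoverable.
            collections.append(tf1.GraphKeys.GLOBAL_STEP)
        # No name is given, so the variable gets a default name
        # rather than "global_step".
        tf1.Variable(0, trainable=False, dtype=tf.int64,
                     collections=collections)
        return tf1.train.get_global_step(graph)


# An unnamed, unregistered counter is invisible to the lookup.
print(find_step_counter(register=False))
# The same counter is found once it is in the GLOBAL_STEP collection.
print(find_step_counter(register=True))
```

  The first call is the situation I was in: a perfectly valid variable that the hook simply cannot see.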

 

  Lesson learned: for a variable like global_step, it is best to create it the standard way: tf.train.get_or_create_global_step()
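  For reference, here is a small single-process sketch of that approach. It is an assumption-laden stand-in for the distributed setup above: there is no cluster, no checkpoint hook and no real model, the assign_add op replaces optimizer.minimize(loss, global_step=global_step), and train_until is a made-up name. It uses tf.compat.v1 so that it also runs under TF 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API; under TF 1.x use `tf` directly


def train_until(last_step):
    """Run a dummy training loop until StopAtStepHook requests a stop."""
    with tf.Graph().as_default():
        # Creates the variable with the expected name and registers it
        # in the GLOBAL_STEP collection, so hooks can find it.
        global_step = tf1.train.get_or_create_global_step()
        # Stand-in for optimizer.minimize(loss, global_step=global_step).
        train_op = tf1.assign_add(global_step, 1)
        hooks = [tf1.train.StopAtStepHook(last_step=last_step)]
        step = 0
        with tf1.train.MonitoredTrainingSession(hooks=hooks) as sess:
            while not sess.should_stop():
                step = sess.run(train_op)
        return int(step)


print(train_until(5))
```

  Because the hook now finds the global step in begin(), the RuntimeError does not occur and the loop ends when the counter reaches last_step.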

  

