TensorFlow: troubleshooting a global_step problem


  While debugging TensorFlow distributed training code, I ran into a strange error: Global step should be created to use StopAtStepHook.

 

  The error occurred in the following code:

  

stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.total_steps)
checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=train_dir, save_steps=1000, saver=saver)
hooks = [stop_hook, checkpoint_hook]
# Bind the session to `sess`; each hook's begin() runs while the session is being set up.
with tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief, checkpoint_dir=None, hooks=hooks) as sess:
    while not sess.should_stop():
        # Fetch into `acc` so the `accuracy` tensor is not overwritten by its own value.
        _, step, acc = sess.run([optimizer, global_step, accuracy], feed_dict=feed)

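  The error message was: RuntimeError: Global step should be created to use StopAtStepHook.

  At first glance this looks as if global_step had never been defined. But my code did define one, so I went to read the source: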

  

class StopAtStepHook(session_run_hook.SessionRunHook):
    def begin(self):
        self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use StopAtStepHook.")

  The error is raised in the begin method of the StopAtStepHook class. This method is called before the session is used, and its job is to obtain global_step.
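  To see when begin runs relative to the rest of the session lifecycle, here is a minimal sketch. It assumes the TF 1.x-style API reached through `tensorflow.compat.v1` (on TF 1.x, a plain `import tensorflow as tf` should behave the same); `RecordingHook` and `record_hook_calls` are hypothetical names made up for this illustration.

```python
import tensorflow.compat.v1 as tf


class RecordingHook(tf.train.SessionRunHook):
    """Hypothetical hook that only records which callbacks were invoked."""

    def __init__(self):
        self.events = []

    def begin(self):
        # Called once, before the underlying session is created.
        self.events.append("begin")

    def after_create_session(self, session, coord):
        self.events.append("after_create_session")

    def before_run(self, run_context):
        # Called before every sess.run() issued through the monitored session.
        self.events.append("before_run")


def record_hook_calls():
    hook = RecordingHook()
    with tf.Graph().as_default():
        x = tf.constant(1)
        with tf.train.MonitoredSession(hooks=[hook]) as sess:
            sess.run(x)
    return hook.events


if __name__ == "__main__":
    print(record_hook_calls())
```

  Because begin fires before any sess.run call, StopAtStepHook fails as soon as the monitored session is constructed, not at the first training step.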

 

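  Next I looked at training_util._get_or_create_global_step_read(). I won't go through the code in detail; the conclusion is that it locates global_step by looking for a variable named "global_step", and the global_step I had defined myself was never given a name...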
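  The failure can be reproduced in isolation. The sketch below is an assumption-laden reconstruction, not my original training code: it uses `tensorflow.compat.v1`, and `my_step` and `reproduce_error` are hypothetical names. It creates a counter variable without a name, asks TensorFlow for the global step, and then calls the hook's begin method directly.

```python
import tensorflow.compat.v1 as tf


def reproduce_error():
    with tf.Graph().as_default():
        # A hand-made step counter: no name given, and not registered as the global step.
        my_step = tf.Variable(0, trainable=False, dtype=tf.int64)

        # The public lookup cannot find it, so it returns None.
        lookup = tf.train.get_global_step()

        # Calling begin() by hand triggers the same check the monitored session would.
        hook = tf.train.StopAtStepHook(last_step=10)
        try:
            hook.begin()
            message = None
        except RuntimeError as err:
            message = str(err)
    return my_step.name, lookup, message


if __name__ == "__main__":
    name, lookup, message = reproduce_error()
    print("variable name:", name)
    print("get_global_step():", lookup)
    print("error:", message)
```

  The variable receives an auto-generated name rather than "global_step", so the lookup comes back empty and the hook raises.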

 

  Lesson learned: for a variable like global_step, it is best to create it the default way: tf.train.get_or_create_global_step()
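  A minimal sketch of that approach, again assuming `tensorflow.compat.v1`. The `train_until` helper is hypothetical, and `tf.assign_add` stands in for a real `optimizer.minimize(loss, global_step=global_step)` call so the example needs no model or cluster.

```python
import tensorflow.compat.v1 as tf


def train_until(last_step):
    """Run a dummy training loop and return how many steps were executed."""
    iterations = 0
    with tf.Graph().as_default():
        # Creates (or reuses) the variable that hooks such as StopAtStepHook look for.
        global_step = tf.train.get_or_create_global_step()

        # Stand-in for optimizer.minimize(loss, global_step=global_step).
        train_op = tf.assign_add(global_step, 1)

        hooks = [tf.train.StopAtStepHook(last_step=last_step)]
        with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
            while not sess.should_stop():
                sess.run(train_op)
                iterations += 1
    return iterations


if __name__ == "__main__":
    print(train_until(5))
```

  In real training code, passing this same global_step to the optimizer's minimize call keeps the counter and the hooks in sync.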

  

