tensorflow-MonitoredTrainingSession解讀


  MonitoredTrainingSession是tensorflow管理分布式訓練中一個重要方法,它相當於集成了一些監控訓練組件,如init、summary、log、save等。在早期的版本,一般使用tf.train.Supervisor來管理session,后來框架升級后,官方就推薦用MonitoredTrainingSession了。

一、訓練為什么要管理?

  搭建一個簡單的分布式訓練是不需要管理的,只需要定義好ClusterSpec,給每個節點分配Server,建好圖,就可以開始迭代了。最簡單的代碼如下:

  

import tensorflow as tf

ps_hosts = [xx.xx.xx.xx: xxxx]
worker_hosts = [xx.xx.xx.xx:xxxx, xx.xx.xx.xx:xxxx]

cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)
if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":

sess = tf.Session()
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):
# build_graph()
    step = 0
    while step < FLAGS.total_step:
        sess.run()

 

  隨着問題和模型的復雜化,我們也許會有監控訓練的需求,如記錄日志、訓練可視化、checkpoint、early-stop、訓練效率調優等,tensorflow提供了大量的工具支持,但這就加重了代碼的復雜度。所以tensorflow封裝了MonitoredTrainingSession,將各種監控訓練的組件外掛到一個類里.

 

二、MonitoredTrainingSession參數

tf.train.MonitoredTrainingSession(
    master='',
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None
)

  args:

    master: server.target

    is_chief: 是否為chief(一般把task_index=0定為chief)。chief節點會負責初始化和模型restore,其他節點只需等待chief初始化完成

    checkpoint_dir: checkpoint文件路徑

    scaffold:用於完成圖表

    hooks:最重要的參數。它是一個SessionRunHook對象的列表,包含了所有希望外掛的組件,如CheckpointSaverHook、FeedFnHook、LoggerTensorHook、NanTensorHook、ProfileHook、StopAtStepHook等,也可以自定義Hook,只要繼承SessionRunHook類就行。下面會詳細介紹幾個重要Hook

    chief_only_hooks:只有chief節點才會生效的hook

    save_checkpoint_secs:保存checkpoint的頻率

    save_summaries_steps:按步數保存summary的頻率 ;save_summaries_secs是按時間

    config:session配置,是ConfigProtoproto格式

 

  實例化后就得到一個MonitoredSession對象,可以當作普通session使用
 

三、Hook的使用

  Hook顧名思義,是一個“外掛”的組件,用於執行訓練中的各種功能。
  Hook的基類是tf.train.SessionRunHook,需要實現下面幾個方法:
  1. 
  after_create_session(
    session,
    coord
)  
  
  在session被創建后調用
  2.

after_run(
    run_context,
    run_values
)
  在每次session.run后被調用
  3.
    before_run(run_context)
    每次run前調用
   4. 
    begin()
    調用后,圖就不能再修改
   5. 
     end(session)
    結束session

 幾個常用的內置的Hook如下:
  1. tf.train.StopAtStepHook:在一定步數停止。

    __init__(
        num_steps=None,
        last_step=None
    )

    兩個參數只能設一個,num_steps是執行步數,last_step是終止步數。


  2. tf.train.CheckpointSaverHook:checkpoint保存

    __init__(
        checkpoint_dir,
        save_secs=None,
        save_steps=None,
        saver=None,
        checkpoint_basename='model.ckpt',
        scaffold=None,
        listeners=None
    )

    參數設置了checkpoint的路徑、保存頻率、saver等

  3. tf.train.FeedFnHook:創建feed_dict

    __init__(feed_fn)

    指定生成feed的函數

  4. tf.train.FinalOpsHook:在session結束時的評估操作

    __init__(
        final_ops,
        final_ops_feed_dict=None
    )

    在訓練結束時,final_ops_feed_dict 喂給final_ops這個tensor,得到final_ops_values。一般用來做測試集的評估

  5. tf.train.NanTensorHook:監控loss是否為NAN

    __init__(
        loss_tensor,
        fail_on_nan_loss=True
    )

    調試和終結訓練用。如果可以正常訓練,建議不用這個Hook,對效率影響比較大

  6. tf.train.SummarySaverHook:記錄summary,訓練可視化

    __init__(
        save_steps=None,
        save_secs=None,
        output_dir=None,
        summary_writer=None,
        scaffold=None,
        summary_op=None
    )

    給定summary_op,定期輸出。

  7. 自定義Hook。可以自己實現Hook,只要繼承SessionRunHook,實現幾個方法即可。給一個cifar10中定義LoggerHook的例子:

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""
    
      def begin(self):
        self._step = -1
        self._start_time = time.time()
    
      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.
    
      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time#duration持續的時間
          self._start_time = current_time
    
          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)
    
          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

    該Hook定制了各種記錄日志的方法 

 

 四、總結

  MonitoredTrainingSession和Hook的結合使得可以自由組裝訓練過程,配合分布式訓練和tensorboard的使用,可以提高調試效率。

 

五、參考

  https://www.tensorflow.org/deploy/distributed

  https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession

  https://www.tensorflow.org/api_docs/python/tf/train/SessionRunHook

  

 

 

 

 
  

    


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM