MonitoredTrainingSession是tensorflow管理分布式訓練中一個重要方法,它相當於集成了一些監控訓練組件,如init、summary、log、save等。在早期的版本,一般使用tf.train.Supervisor來管理session,后來框架升級后,官方就推薦用MonitoredTrainingSession了。
一、訓練為什么要管理?
搭建一個簡單的分布式訓練是不需要管理的,只需要定義好ClusterSpec,給每個節點分配Server,建好圖,就可以開始迭代了。最簡單的代碼如下:
import tensorflow as tf ps_hosts = [xx.xx.xx.xx: xxxx] worker_hosts = [xx.xx.xx.xx:xxxx, xx.xx.xx.xx:xxxx] cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() elif FLAGS.job_name == "worker": sess = tf.Session() with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): # build_graph() step = 0 while step < FLAGS.total_step: sess.run()
隨着問題和模型的復雜化,我們也許會有監控訓練的需求,如記錄日志、訓練可視化、checkpoint、early-stop、訓練效率調優等,tensorflow提供了大量的工具支持,但這就加重了代碼的復雜度。所以tensorflow封裝了MonitoredTrainingSession,將各種監控訓練的組件外掛到一個類里.
二、MonitoredTrainingSession參數
tf.train.MonitoredTrainingSession( master='', is_chief=True, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=USE_DEFAULT, save_summaries_steps=USE_DEFAULT, save_summaries_secs=USE_DEFAULT, config=None, stop_grace_period_secs=120, log_step_count_steps=100, max_wait_secs=7200, save_checkpoint_steps=USE_DEFAULT, summary_dir=None )
args:
master: server.target
is_chief: 是否為chief(一般把task_index=0定為chief)。chief節點會負責初始化和模型restore,其他節點只需等待chief初始化完成
checkpoint_dir: checkpoint文件路徑
scaffold:用於完成圖表
hooks:最重要的參數。它是一個SessionRunHook對象的列表,包含了所有希望外掛的組件,如CheckpointSaverHook、FeedFnHook、LoggerTensorHook、NanTensorHook、ProfileHook、StopAtStepHook等,也可以自定義Hook,只要繼承SessionRunHook類就行。下面會詳細介紹幾個重要Hook
chief_only_hooks:只有chief節點才會生效的hook
save_checkpoint_secs:保存checkpoint的頻率
save_summaries_steps:按步數保存summary的頻率 ;save_summaries_secs是按時間
config:session配置,是ConfigProtoproto格式
實例化后就得到一個MonitoredSession對象,可以當作普通session使用
三、Hook的使用
Hook顧名思義,是一個“外掛”的組件,用於執行訓練中的各種功能。 Hook的基類是tf.train.SessionRunHook,需要實現下面幾個方法: 1.after_create_session(
session,
coord
)
在session被創建后調用
2.after_run(
run_context,
run_values
)
在每次session.run后被調用
3.
幾個常用的內置的Hook如下:
-
tf.train.StopAtStepHook:在一定步數停止。
-
tf.train.CheckpointSaverHook:checkpoint保存
__init__( checkpoint_dir, save_secs=None, save_steps=None, saver=None, checkpoint_basename='model.ckpt', scaffold=None, listeners=None )
參數設置了checkpoint的路徑、保存頻率、saver等
-
tf.train.FeedFnHook:創建feed_dict
__init__(feed_fn)
指定生成feed的函數
-
tf.train.FinalOpsHook:在session結束時的評估操作
__init__( final_ops, final_ops_feed_dict=None )
在訓練結束時,final_ops_feed_dict 喂給final_ops這個tensor,得到final_ops_values。一般用來做測試集的評估
-
tf.train.NanTensorHook:監控loss是否為NAN
__init__( loss_tensor, fail_on_nan_loss=True )
調試和終結訓練用。如果可以正常訓練,建議不用這個Hook,對效率影響比較大
-
tf.train.SummarySaverHook:記錄summary,訓練可視化
__init__( save_steps=None, save_secs=None, output_dir=None, summary_writer=None, scaffold=None, summary_op=None )
給定summary_op,定期輸出。
-
自定義Hook。可以自己實現Hook,只要繼承SessionRunHook,實現幾個方法即可。給一個cifar10中定義LoggerHook的例子:
class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time#duration持續的時間 self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print (format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch))
該Hook定制了各種記錄日志的方法
四、總結
MonitoredTrainingSession和Hook的結合使得可以自由組裝訓練過程,配合分布式訓練和tensorboard的使用,可以提高調試效率。
五、參考
https://www.tensorflow.org/deploy/distributed
https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
https://www.tensorflow.org/api_docs/python/tf/train/SessionRunHook