1、知識點
""" 單機多卡:一台服務器上多台設備(GPU) 參數服務器:更新參數,保存參數 工作服務器:主要功能是去計算 更新參數的模式: 1、同步模型更新 2、異步模型更新 工作服務器會默認一個機器作為老大,創建會話 tensorflow設備命名規則: /job:ps/task:0 job:ps,服務器類型 task:0,服務器第幾台 /job:worker/task:0/cpu:0 /job:worker/task:0/gpu:0 /job:worker/task:0/gpu:1 設備使用: 1、對集群當中的一些ps,worker進行指定 2、創建對應的服務, ps:創建ps服務 join() worker創建worker服務,運行模型,程序,初始化會話等等 指定一個默認的worker去做 3、worker使用設備: with tf.device("/job:worker/task:0/gup:0"): 計算操作 4、分布式使用設備: tf.train.replica_device_setter(worker_device=worker_device,cluster=cluster) 作用:通過此函數協調不同設備上的初始化操作 worker_device:為指定設備, “/job:worker/task:0/cpu:0” or "/job:worker/task:0/gpu:0" cluster:集群描述對象 API: 1、分布式會話函數:MonitoredTrainingSession(master="",is_chief=True,checkpoint_dir=None, hooks=None,save_checkpoint_secs=600,save_summaries_steps=USE_DEFAULT,save_summaries_secs=USE_DEFAULT,config=None) 參數: master:指定運行會話協議IP和端口(用於分布式) "grpc://192.168.0.1:2000" is_chief:是否為主worker(用於分布式)如果True,它將負責初始化和恢復基礎的TensorFlow會話。 如果False,它將等待一位負責人初始化或恢復TensorFlow會話。 checkpoint_dir:檢查點文件目錄,同時也是events目錄 config:會話運行的配置項, tf.ConfigProto(log_device_placement=True) hooks:可選SessionRunHook對象列表 should_stop():是否異常停止 run():跟session一樣可以運行op 2、tf.train.SessionRunHook Hook to extend calls to MonitoredSession.run() 1、begin():在會話之前,做初始化工作 2、before_run(run_context)在每次調用run()之前調用,以添加run()中的參數。 ARGS: run_context:一個SessionRunContext對象,包含會話運行信息 return:一個SessionRunArgs對象,例如:tf.train.SessionRunArgs(loss) 3、after_run(run_context,run_values)在每次調用run()后調用,一般用於運行之后的結果處理 該run_values參數包含所請求的操作/張量的結果 before_run()。 該run_context參數是相同的一個發送到before_run呼叫。 ARGS: run_context:一個SessionRunContext對象 run_values一個SessionRunValues對象, run_values.results 注:再添加鈎子類的時候,繼承SessionRunHook 3、tf.train.StopAtStepHook(last_step=5000)指定執行的訓練輪數也就是max_step,超過了就會拋出異常 tf.train.NanTensorHook(loss)判斷指定Tensor是否為NaN,為NaN則結束 注:在使用鈎子的時候需要定義一個全局步數:global_step = tf.contrib.framework.get_or_create_global_step() """
2、代碼
import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string("job_name", " ", "啟動服務的類型ps or worker") tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker當中的那一台服務器以task:0 ,task:1") def main(argv): # 定義全集計數的op ,給鈎子列表當中的訓練步數使用 global_step = tf.contrib.framework.get_or_create_global_step() # 1、指定集群描述對象, ps , worker cluster = tf.train.ClusterSpec({"ps": ["10.211.55.3:2223"], "worker": ["192.168.65.44:2222"]}) # 2、創建不同的服務, ps, worker server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) # 根據不同服務做不同的事情 ps:去更新保存參數 worker:指定設備去運行模型計算 if FLAGS.job_name == "ps": # 參數服務器什么都不用干,是需要等待worker傳遞參數 server.join() else: worker_device = "/job:worker/task:0/cpu:0/" # 3、可以指定設備取運行 with tf.device(tf.train.replica_device_setter( worker_device=worker_device, cluster=cluster )): # 簡單做一個矩陣乘法運算 x = tf.Variable([[1, 2, 3, 4]]) w = tf.Variable([[2], [2], [2], [2]]) mat = tf.matmul(x, w) # 4、創建分布式會話 with tf.train.MonitoredTrainingSession( master= "grpc://192.168.65.44:2222", # 指定主worker is_chief= (FLAGS.task_index == 0),# 判斷是否是主worker config=tf.ConfigProto(log_device_placement=True),# 打印設備信息 hooks=[tf.train.StopAtStepHook(last_step=200)] ) as mon_sess: while not mon_sess.should_stop(): print(mon_sess.run(mat)) if __name__ == "__main__": tf.app.run()
3、分布式架構圖