tensorflow分布式運行


1、知識點

"""
單機多卡:一台服務器上多台設備(GPU)
參數服務器:更新參數,保存參數
工作服務器:主要功能是去計算

更新參數的模式:
    1、同步模型更新
    2、異步模型更新
工作服務器會默認一個機器作為老大,創建會話


tensorflow設備命名規則:
    /job:ps/task:0   job:ps,服務器類型   task:0,服務器第幾台

    /job:worker/task:0/cpu:0
    /job:worker/task:0/gpu:0
    /job:worker/task:0/gpu:1

設備使用:
    1、對集群當中的一些ps,worker進行指定
    2、創建對應的服務, ps:創建ps服務  join()
        worker創建worker服務,運行模型,程序,初始化會話等等
        指定一個默認的worker去做
    3、worker使用設備:
        with tf.device("/job:worker/task:0/gup:0"):
            計算操作
     4、分布式使用設備:
        tf.train.replica_device_setter(worker_device=worker_device,cluster=cluster)
                作用:通過此函數協調不同設備上的初始化操作
                worker_device:為指定設備, “/job:worker/task:0/cpu:0” or "/job:worker/task:0/gpu:0"
                cluster:集群描述對象
API:
    1、分布式會話函數:MonitoredTrainingSession(master="",is_chief=True,checkpoint_dir=None,   
                    hooks=None,save_checkpoint_secs=600,save_summaries_steps=USE_DEFAULT,save_summaries_secs=USE_DEFAULT,config=None)
            參數:
                master:指定運行會話協議IP和端口(用於分布式) "grpc://192.168.0.1:2000"
                is_chief:是否為主worker(用於分布式)如果True,它將負責初始化和恢復基礎的TensorFlow會話。
                        如果False,它將等待一位負責人初始化或恢復TensorFlow會話。
                checkpoint_dir:檢查點文件目錄,同時也是events目錄
                config:會話運行的配置項, tf.ConfigProto(log_device_placement=True)
                hooks:可選SessionRunHook對象列表
                should_stop():是否異常停止
                run():跟session一樣可以運行op
    2、tf.train.SessionRunHook
            Hook to extend calls to MonitoredSession.run()
            1、begin():在會話之前,做初始化工作
            2、before_run(run_context)在每次調用run()之前調用,以添加run()中的參數。
            ARGS:
            run_context:一個SessionRunContext對象,包含會話運行信息
            return:一個SessionRunArgs對象,例如:tf.train.SessionRunArgs(loss)
            3、after_run(run_context,run_values)在每次調用run()后調用,一般用於運行之后的結果處理
            該run_values參數包含所請求的操作/張量的結果 before_run()。
            該run_context參數是相同的一個發送到before_run呼叫。
             ARGS:
            run_context:一個SessionRunContext對象
            run_values一個SessionRunValues對象, run_values.results
        注:再添加鈎子類的時候,繼承SessionRunHook
    3、tf.train.StopAtStepHook(last_step=5000)指定執行的訓練輪數也就是max_step,超過了就會拋出異常
            tf.train.NanTensorHook(loss)判斷指定Tensor是否為NaN,為NaN則結束
            注:在使用鈎子的時候需要定義一個全局步數:global_step = tf.contrib.framework.get_or_create_global_step()
"""

2、代碼

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string("job_name", " ", "啟動服務的類型ps or  worker")
tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker當中的那一台服務器以task:0 ,task:1")

def main(argv):

    # 定義全集計數的op ,給鈎子列表當中的訓練步數使用
    global_step = tf.contrib.framework.get_or_create_global_step()

    # 1、指定集群描述對象, ps , worker
    cluster = tf.train.ClusterSpec({"ps": ["10.211.55.3:2223"], "worker": ["192.168.65.44:2222"]})

    # 2、創建不同的服務, ps, worker
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    # 根據不同服務做不同的事情 ps:去更新保存參數 worker:指定設備去運行模型計算
    if FLAGS.job_name == "ps":
        # 參數服務器什么都不用干,是需要等待worker傳遞參數
        server.join()
    else:
        worker_device = "/job:worker/task:0/cpu:0/"

        # 3、可以指定設備取運行
        with tf.device(tf.train.replica_device_setter(
            worker_device=worker_device,
            cluster=cluster
        )):
            # 簡單做一個矩陣乘法運算
            x = tf.Variable([[1, 2, 3, 4]])
            w = tf.Variable([[2], [2], [2], [2]])

            mat = tf.matmul(x, w)

        # 4、創建分布式會話
        with tf.train.MonitoredTrainingSession(
            master= "grpc://192.168.65.44:2222", # 指定主worker
            is_chief= (FLAGS.task_index == 0),# 判斷是否是主worker
            config=tf.ConfigProto(log_device_placement=True),# 打印設備信息
            hooks=[tf.train.StopAtStepHook(last_step=200)]
        ) as mon_sess:
            while not mon_sess.should_stop():
                print(mon_sess.run(mat))


if __name__ == "__main__":
    tf.app.run()

3、分布式架構圖

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM