終於又有時間和成果拿出來和大家分享,實在不容易,之前由於臨時更換任務加上入職事情多斷更了很久,現在主要在做一些KG和KGQA方面的工作。今天要和大家分享的是最近在工作中實現的分布式tensorflow。(BTW打個廣告~NLP和DL討論歡迎加群~二維碼在末尾~)
理論在這里就不詳細介紹了,說說對一些概念自己的理解吧:
(1)task->server->cluster:
這里其實應該也是分布式計算的一些基本概念,在分布式tensorflow中,采用的主從模式,即master-slave模式。有一個總控服務器來負責傳遞數據和調度,若干從節點服務器負責計算。在這里,我們所說的每一個服務器也就對應一個server。在tensorflow中,總控服務器其實叫做參數服務器(Parameter Server),在實際操作中負責參數的更新,但是並不負責圖的計算。那么負責計算的是什么呢?在這里就是工作節點(工作服務器)。在每個工作服務器上,tensorflow都會保存整張計算圖並且獨立的進行計算。不過值得注意的是,盡管叫server級別,但是不一定一個節點就只能是一個服務器,他僅僅對應服務器上的一個端口,使用某個服務器的一部分資源(或者所有資源),同時若干個工作節點也可以放在一個資源足夠的服務器上,在后面的代碼中你會看到我就是這么做的。注意到之前說的參數服務器和工作服務器都是server級別的,在這個級別下,每個服務器可以有若干個task,每個task對應一個具體的計算操作。在這個級別之上,若干個工作節點可以構成一個計算集群,而若干個參數服務器可以構成一個參數服務器集群。
(2)gRPC:
這里主要放一些干貨,介紹一些谷歌自己開發的通信協議gRPC,這也是分布式tensorflow用來做多機進程間通信的協議。額外想提以下的其實是一些tradeoff,由於現在只是跑通了demo而沒有在大的模型上做實驗,有一個需要驗證的問題就是:在沒有足夠多台服務器的情況下,到底是使用兩台服務器,將參數更新和圖計算分開,降低整個服務器的壓力,還是應該單機多卡,減少任務之間的通信開銷,這個問題需要在后面的工作中驗證,也希望有經驗的同學給出意見。
gRPC是一個高性能、開源和通用的RPC框架,面向移動和HTTP/2設計。目前提供C、Java和Go語言版本,分別是grpc、grpc-java、grpc-go。gRPC基於HTTP/2標准設計,帶來諸如雙向流、流控、頭部壓縮、單TCP連接上的多復用請求等特性。這些特性使得其在移動設備上表現更好,更省電和節省空間占用。gRPC由google開發,是一款語言中立、平台中立、開源的遠程過程調用系統。在gRPC里客戶端應用可以像調用本地對象一樣直接調用另一台不同機器上服務端應用的方法,使得你能夠更容易地創建分布式應用和服務。與許多RPC系統類似,gRPC也是基於以下理念:定義一個服務,指定其能夠被遠程調用的方法(包括參數和返回類型)。在服務端實現這個接口,並運行一個gRPC服務器來處理客戶端調用。在客戶端擁有一個存根能夠像服務端一樣的方法。
好了,理論說完了,現在要展現我和其他博主不一樣的地方了:直接上能跑的代碼!對於代碼的解釋直接見注釋部分。注意:運行代碼需要在每一個節點分別運行一次,並不是一勞永逸的哦(雖然我最開始也是這么覺得的)運行的命令如下:(demo修改自https://github.com/TracyMcgrady6/Distribute_MNIST,特別感謝)
python distributed.py --job_name=ps --task_index=0 #在參數服務器上運行,啟動參數服務器 python distributed.py --job_name=worker --task_index=0 #在工作節點上運行,啟動工作節點0 python distributed.py --job_name=worker --task_index=1 #在工作節點上運行,啟動工作節點1
上代碼~這個代碼其實是用來訓練minist的,我用的是兩個RTX2080(有木有很羡慕~),速度有多快呢?大概不到30秒就訓練完了10000步,差點沒來得及給同事看~如果有同學跑下面的代碼遇到問題可以找我要源碼~郵箱見上一條~
# encoding:utf-import math import os os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import tempfile import time import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import os flags = tf.app.flags IMAGE_PIXELS = 28 # 定義默認訓練參數和數據路徑 #tf.flags其實就是定義一些命令行參數 flags.DEFINE_string('data_dir', '/tmp/mnist-data', 'Directory for storing mnist data') flags.DEFINE_integer('hidden_units', 100, 'Number of units in the hidden layer of the NN') flags.DEFINE_integer('train_steps', 10000, 'Number of training steps to perform') flags.DEFINE_integer('batch_size', 100, 'Training batch size ') flags.DEFINE_float('learning_rate', 0.01, 'Learning rate') # 定義分布式參數 # 參數服務器parameter server節點 flags.DEFINE_string('ps_hosts', '192.168.6.156:22223', 'Comma-separated list of hostname:port pairs') # 兩個worker節點 flags.DEFINE_string('worker_hosts', '192.168.6.164:22221,192.168.6.164:22220', 'Comma-separated list of hostname:port pairs') # 設置job name參數 flags.DEFINE_string('job_name', None, 'job name: worker or ps') # 設置任務的索引 flags.DEFINE_integer('task_index', None, 'Index of task within the job') # 選擇異步並行,同步並行,在本程序中其實沒有用到 flags.DEFINE_integer("issync", None, "是否采用分布式的同步模式,1表示同步模式,0表示異步模式") FLAGS = flags.FLAGS def main(unused_argv): mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) if FLAGS.job_name is None or FLAGS.job_name == '': raise ValueError('Must specify an explicit job_name !') else: print ('job_name : %s' % FLAGS.job_name) if FLAGS.task_index is None or FLAGS.task_index == '': raise ValueError('Must specify an explicit task_index!') else: print ('task_index : %d' % FLAGS.task_index) ps_spec = FLAGS.ps_hosts.split(',') worker_spec = FLAGS.worker_hosts.split(',') # 創建集群 num_worker = len(worker_spec) cluster = tf.train.ClusterSpec({'ps': ps_spec, 'worker': worker_spec}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == 'ps': server.join() is_chief = (FLAGS.task_index == 0) # worker_device = '/job:worker/task%d/cpu:0' % FLAGS.task_index #難點其實在這里,通過worker_device指定在同一台服務器上的不同顯卡作為工作節點 with tf.device(tf.train.replica_device_setter( worker_device = '/job:worker/task:%d/gpu:%d' %(FLAGS.task_index, FLAGS.task_index), ps_device = '/job:ps/cpu:0', cluster=cluster )): global_step = tf.Variable(0, name='global_step', trainable=False) # 創建紀錄全局訓練步數變量 hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], stddev=1.0 / IMAGE_PIXELS), name='hid_w') hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name='hid_b') sm_w = tf.Variable(tf.truncated_normal([FLAGS.hidden_units, 10], stddev=1.0 / math.sqrt(FLAGS.hidden_units)), name='sm_w') sm_b = tf.Variable(tf.zeros([10]), name='sm_b') x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS]) y_ = tf.placeholder(tf.float32, [None, 10]) hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) hid = tf.nn.relu(hid_lin) y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) opt = tf.train.AdamOptimizer(FLAGS.learning_rate) train_step = opt.minimize(cross_entropy, global_step=global_step) # 生成本地的參數初始化操作init_op init_op = tf.global_variables_initializer() train_dir = tempfile.mkdtemp() sv = tf.train.Supervisor(is_chief=is_chief, logdir=train_dir, init_op=init_op, recovery_wait_secs=1, global_step=global_step) if is_chief: print ('Worker %d: Initailizing session...' % FLAGS.task_index) else: print ('Worker %d: Waiting for session to be initaialized...' % FLAGS.task_index) #sess = sv.prepare_or_wait_for_session(server.target) #第二個坑在這里,必須要設置allow_soft_placement為True讓tensorflow可以自動找到最適合的設備,否則會說不存在gpu的kernel,同時建議運行時只安裝tensorflow_gpu config = tf.ConfigProto(allow_soft_placement = True) sess = sv.prepare_or_wait_for_session(server.target, config=config) print ('Worker %d: Session initialization complete.' % FLAGS.task_index) time_begin = time.time() print ('Traing begins @ %f' % time_begin) local_step = 0 while True: batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size) train_feed = {x: batch_xs, y_: batch_ys} _, step = sess.run([train_step, global_step], feed_dict=train_feed) local_step += 1 now = time.time() print ('%f: Worker %d: traing step %d dome (global step:%d)' % (now, FLAGS.task_index, local_step, step)) if step >= FLAGS.train_steps: break time_end = time.time() print ('Training ends @ %f' % time_end) train_time = time_end - time_begin print ('Training elapsed time:%f s' % train_time) val_feed = {x: mnist.validation.images, y_: mnist.validation.labels} val_xent = sess.run(cross_entropy, feed_dict=val_feed) print ('After %d training step(s), validation cross entropy = %g' % (FLAGS.train_steps, val_xent)) sess.close() if __name__ == '__main__': tf.app.run()
如果你喜歡博主的分享或者覺得這個分享對你有用,可以支持博主一下,以便他寫出更好的文章~
NLP討論群二維碼~