背景
[作者:DeepLearningStack,阿里巴巴算法工程師,開源TensorFlow Contributor]
歡迎大家關注我的公眾號,“互聯網西門二少”,我將繼續輸出我的技術干貨~
本篇是TensorFlow通信機制系列的第二篇文章,主要梳理使用gRPC網絡傳輸部分模塊的結構和源碼。如果讀者對TensorFlow中Rendezvous部分的基本結構和原理還不是非常了解,那么建議先從這篇文章開始閱讀。TensorFlow在最初被開源時還只是個單機的異構訓練框架,在迭代到0.8版本開始正式支持多機分布式訓練。與其他分布式訓練框架不同,Google選用了開源項目gRPC作為TensorFlow的跨機通信協議作為支持。gRPC的編程和使用其實是相對復雜的,TensorFlow為了能讓gRPC的調用更加平滑,在調用鏈封裝和抽象上面做了較多工作,甚至有些工作例如創建和管理gRPC channel涉及到了GrpcSession模塊。從個人角度來看,利用gRPC進行Tensor通信的過程已經足夠豐富,所以我們只針對gRPC傳輸Tensor過程進行梳理,至於涉及到gRPC管理方面的內容會在另一篇介紹分布式Session創建和管理的文章中集中梳理。
跨進程通信過程
根據之前寫博客的經驗,直接介紹類圖結構和源碼部分可能會讓人懵圈,還是先從邏輯上把通信過程梳理清楚更能做到深入淺出。其實對於不是非常了解分布式系統或大規模並發系統的讀者而言,TensorFlow中通信過程是有些“別扭”的。那么有的讀者可能會覺得詫異,跨進程通信過程不就是一方做Send,另一方做Recv嗎?這是一個理所當然的過程,為什么會“別扭”呢?是的,整個過程依然是一方做Send,另一方做Recv。而它的“別扭”之處就在於——真正的通信過程由Recv方觸發,而不是Send方!這就是理解TensorFlow中使用gRPC傳輸Tensor過程的最關鍵點。
前一篇文章分析過在本地傳輸的場景下Tensor通信的大體過程,從機制和邏輯上來說,跨進程傳輸過程和本地傳輸沒有很大的差異:TensorFlow使用Rendezvous通信Tensor,借助一個類似Table的數據結構作為傳輸的中轉,並且Send方和Recv方依靠ParsedKey這一唯一傳輸標識符,跨進程通信也是如此。如果讀者對這部分內容不了解,可以參考這篇文章。
Send方——將Ready的Tensor掛入本地Table
和本地傳輸場景下的Send過程相同,本地Tensor處於Ready狀態后就被放掛了本地Worker的Table中,至此Send過程就全部完成了。所以Send過程完全沒有涉及到任何跨網絡傳輸的內容,並且Send過程是非阻塞的。
Recv方——向Send方主動發出請求,觸發通信過程
Recv方是Tensor的接收方,它的處理過程是:將所需要的Tensor對應的ParsedKey拼出后,主動向Send方主動發出Request,Send方在接收到Request后立即在本地Table中查找方所需要的Tensor,找到后將Tensor封裝成Response發送回Recv方。在這個過程中,Recv方可以認為是Client,Send方可以認為是Server,通過發送Request和Response來完成Tensor的傳輸。
結構設計解析
建議讀者在閱讀本節時適當翻開TensorFlow C++部分源碼,但只需要理解結構關系即可(比如類之間的繼承、組合、依賴關系),暫時不要閱讀類的實現內容。因為RemoteRendezvous部分涉及到的類結構非常多,直接陷入細節的閱讀會深陷其中不能自拔,甚至弄得一頭霧水十分疲憊。在梳理結構時一邊參照下文中的類圖結構,一邊從設計模式和架構的角度嘗試去理解每個模塊的司職是理解本篇細節的關鍵。先理解宏觀結構看懂架子,再去深入理解實現細節嘗試去優化是讀任何代碼的正確順序。
任何場景下,通信過程幾乎都是可以通過簡單的圖將功能描述清楚的。但是不可否認的是,任何涉及到分布式通信的系統在架構上都會對通信層做相對復雜的封裝。一方面是因為通信雖然功能簡單,但其實現本身具有相對較高的復雜性(大家可以嘗試閱讀gRPC源碼感受下底層軟件的復雜度)。另一方面,應用層也需要與通信底層通過抽象盡量實現較好的解耦,這樣也方便將應用層模塊被其他團隊擴展編寫。下面我們一起來探究TensorFlow中涉及到跨進程通信的Rendezvous系列。
兩層抽象繼承關系——RemoteRendezvous與BaseRemoteRendezvous
前一篇在介紹本地傳輸時我們熟悉了Rendezvous模塊中與本地傳輸相關的類,例如LocalRendezvousImpl,IntraProcessRendezvous和SimpleRendezvous。對應地,跨進程傳輸也有不同的Rendezvous,從根源上來說,它們也繼承於Rendezvous接口,並且不同的傳輸協議也有各自的Rendezvous。在這里,我們再次將前文中展示的總體類結構圖展示出來,這次我們將涉及到遠程傳輸的類用特殊顏色標出,如下圖所示。
綜合來看,從Rendezvous的繼承結構來看,涉及到跨進程傳輸的Rendezvous有層:
1. RemoteRendezvous:只增加了一個Initialize方法,並標記為純虛函數。這是因為跨進程Rendezvous需要借助Session做一些初始化工作,所以TensorFlow中所有涉及到跨進程通信的Rendezvous都需要重寫Initialize函數,使用前也必須強制調用該函數。
2. 各種具體協議Rendezvous的基類——BaseRemoteRendezvous:既然所有涉及跨進程通信的Rendezvous都需要提供各自協議下實現的Initialize函數,那么沒有比在RemoteRendezvous和真正特化的Rendezvous之間再添加一層繼承關系更合適的做法了。事實上TensorFlow在此處也是這么設計的,這個承上啟下的類就是BaseRemoteRendezvous。它還提供了公共的Send和Recv方法,這可以讓繼承它的特化Rendezvous盡最大可能做到代碼復用。
BaseRecvTensorCall是通信的實體抽象,后面分析時會有更深的體會,在這里先有個印象即可。
開始特化——各種各樣的RemoteRendezvous
TensorFlow目標是通用可擴展,所以被設計成允許底層支持多種通信協議的結構。事實上到目前為止,算上contrib目錄的內容(contrib目錄是廣大TensorFlow貢獻者添加的內容),TensorFlow已經支持包括gRPC,RDMA(Remote Direct Memroy Access),GDR(GPU Dirrect)和MPI四種通信協議,因此包含了四種對應的Rendezvous,他們分別是RpcRemoteRendezvous,RDMARemoteRendezvous,GdrRemoteRendezvous和MPIRemoteRendezvous。每種通信協議各有其特點,有時候其可用性也取決於硬件和軟件條件(比如RDMA需要支持RDMA協議的網卡,通常跑在Infiniband和RoCE網絡上,如果沒有硬件支持,那么RDMA將無法使用,GDR也是這個道理)。從代碼中可以看出,實現每種具體的RemoteRendezvous都有一定的復雜性,所以很難想象在沒有封裝抽象和代碼復用的結構里如何實現這些內容。在本篇我們關注RpcRemoteRendezvous,它是gRPC協議實現的RemoteRendezvous。
令人熟悉的管理器模式——RendezvousMgr
為了更好地管理RemoteRendezvous,TensorFlow設計了相應的管理器——RendezvousMgr相關類,並為每種具體的RemoteRendevzous做了特化。熟悉設計模式的讀者都知道,管理器是一種經典的設計模式,它能使管理職責的變化獨立於類本身。RendezvousMgr主要負責RemoteRendezvous的創建和銷毀,它也定義了兩個本地版本的Recv接口。有的讀者可能會問,管理器為什么還允許做Recv?並且只能做本地的Recv?我個人判斷添加這兩個接口純粹是為了方便某些地方的使用。至於RendezvousMgr的創建時機和RemoteRendezvous的初始化過程並不是本篇解析的范疇,因為這涉及到分布式場景下創建Server的較長鏈路,這部分內容會在以后的博客中詳細解析。下面是RendezvousMgr相關的類圖結構,我們可以看到其接口類中已經定義了Recv接口。
RpcRemoteRendezvous通信過程與源碼解析
上一小節中對RemoteRendezvous相關類結構和類間的關系做了解析,旨在從架構層面幫助讀者理解各個類的職能。雖然涉及到的內容比較多,但是整體的結構和邏輯還是非常清晰的。如果讀者嘗試通過閱讀源碼輔助理解上述內容之后仍然感覺有些眼花繚亂,沒有關系,我們在這里暫時做一個簡單地梳理,將重點內容梳理到以下幾條。
1. 本地Rendezvous和RemoteRendezvous共同繼承了同一個接口;
2. RemoteRendezvous需要支持不同的通信協議,因此派生了各種各樣的實現類;
3. RemoteRendezvous的使用較為復雜,為此引入了管理器模式——RendezvousMgr,它負責RemoteRendezvous的創建和銷毀,並添加了兩個額外的Recv接口方便某些場景直接調用;
4. RemoteRendezvous做了兩層繼承結構只是為了添加一個Initialize方法。
本篇我們梳理使用gRPC協議的部分,從上文中梳理的結構中不難看出,這部分涉及到的類並不多。
1. Rendezvous相關類——RemoteRendezvous,BaseRemoteRendezvous,RpcRemoteRendezvous;
2. 管理器——BaseRendezvousMgr,RpcRendezvousMgr
3. 其他類——BaseRecvTensorCall,RpcRecvTensorCall和DefferedCall
畢竟是涉及到了gRPC協議本身的使用,所以有必要在梳理源碼之前從宏觀上對gRPC的工作流程做一個簡單地梳理。
gRPC編程中的代理模式——Stub與Service
在此我們假設同學們對gRPC的原理和使用有一些基本的了解,比如需要使用Protobuf預先定義Service接口,並且區分Stub和Service等。對此不了解的同學還是建議先認真閱讀一下gRPC的使用文檔和范例,下面這段文字只對gRPC做一個非常簡單的描述。
在一次RPC調用中,客戶端需要調用服務端的服務,然后將處理結果返回給客戶端。而gRPC做到了“讓客戶端調用遠端函數時就像調用本地函數一樣”的體驗,這得益於一種經典的設計模式——代理模式。負責為客戶端代理的節點(gRPC中稱之為Stub)會將請求和參數傳到服務端,並由Service進行實際的處理,然后將結果返回給Stub,最終返回到客戶端中。我們甚至可以認為負責代理的Stub就是客戶端,因為它的職責就是與遠端交互並取得結果。另外,為了能夠讓傳輸量盡可能少,也為了能夠讓傳輸不受客戶端和服務端具體的類型限制,gRPC在做跨網絡傳輸前將消息統一序列化成Protobuf格式。下圖是從gRPC官網教程中摘出的工作原理圖。
Send過程
因為Send過程並不涉及跨進程傳輸,只是將Ready的Tensor掛入本地Table之中,所以它和LocalRendezvousImpl的Send完全相同。不僅如此,TensorFlow中的任何RemoteRendezvous的Send過程都要遵循這樣的原理,基於代碼復用的考慮,將這部分內容都被抽象到了公共基類BaseRemoteRendezvous的Send函數里是一個很好的設計。事實上,BaseRemoteRendezvous的Send過程就是調用了LocalRendezvousImpl的Send過程,所以LocalRendezvousImpl必須要作為BaseRemoteRendezvous的成員之一。下面的代碼展示了這一過程。
1 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, 2 const Rendezvous::Args& args, 3 const Tensor& val, const bool is_dead) { 4 VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey(); 5 { 6 mutex_lock l(mu_); 7 if (!status_.ok()) return status_; 8 DCHECK(is_initialized_locked()); 9 if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { 10 return errors::InvalidArgument( 11 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", 12 session_->worker_name); 13 } 14 } 15 // Buffers "val" and "device_context" in local_. 16 return local_->Send(parsed, args, val, is_dead); 17 }
Recv過程
Recv過程就非常復雜了,因為每種RemoteRendezvous都涉及到不同的通信協議以及管理方式,所以Recv函數是真正需要繼承重寫的模塊。在看RpcRemoteRendezvous具體的實現之前,我們必須先將gRPC定義服務的接口部分梳理清楚。
gRPC的服務定義接口文件
在TensorFlow的core/protobuf文件中,我們需要研究一下worker_service.proto文件,這個文件中定義了若干RPC Service接口。
雖然它定義了很多RPC服務接口,但是我們只需要關注和Tensor接收相關的接口定義即可。准確地說,目前我們必須要知道的是下面這個Service定義。
// See worker.proto for details. rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) { // RecvTensor Method }
顯然,這是一個讓服務端處理“接收Tensor”的服務(注意是讓服務端處理名為“接收Tensor”的服務,而不是讓服務端去接收Tensor。因為客戶端有接收Tensor的需求,但需要服務端發送Tensor,為客戶端發送Tensor的服務被稱之為“接收Tensor”),按照注釋提示,我們可以在worker.proto中找到RecvTensorRequest和RecvTensorResponse的數據結構,這部分結構讀者可以自己查閱,非常容易理解。在編譯時,擴展的Protobuf編譯器會對worker_service.proto中的rpc接口生成C++服務接口代碼和Stub代碼(畢竟Stub代碼比較純粹並且和業務邏輯無關,它只是一個向對應Service端發送處理請求的過程),TensorFlow只需要對具體的Service提供實現即可。
與gRPC生成的代碼聯系起來
gRPC會為worker_service.proto中每一個rpc服務生成C++接口代碼,為了區分多個rpc服務,特意為每個服務生成了特殊的名字。比如RecvTensor服務的名字就是/tensorflow.WorkerService/RecvTensor。為了不直接使用冗長的字符串,TensorFlow為worker_service.proto中的每個服務都做了enumeration的映射,這部分代碼在tensorflow/core/distributed_runtime/grpc_worker_service_impl.h和同名實現文件中。
1 // Names of worker methods. 2 enum class GrpcWorkerMethod { 3 kGetStatus, 4 kCreateWorkerSession, 5 kDeleteWorkerSession, 6 kRegisterGraph, 7 kDeregisterGraph, 8 kRunGraph, 9 kCleanupGraph, 10 kCleanupAll, 11 kRecvTensor, 12 kRecvBuf, 13 kLogging, 14 kTracing, 15 kCompleteGroup, 16 kCompleteInstance, 17 kGetStepSequence, 18 };
下面是從enumeration類型映射到具體字符串的函數。
1 const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { 2 switch (id) { 3 case GrpcWorkerMethod::kGetStatus: 4 return "/tensorflow.WorkerService/GetStatus"; 5 case GrpcWorkerMethod::kCreateWorkerSession: 6 return "/tensorflow.WorkerService/CreateWorkerSession"; 7 case GrpcWorkerMethod::kDeleteWorkerSession: 8 return "/tensorflow.WorkerService/DeleteWorkerSession"; 9 case GrpcWorkerMethod::kRegisterGraph: 10 return "/tensorflow.WorkerService/RegisterGraph"; 11 case GrpcWorkerMethod::kDeregisterGraph: 12 return "/tensorflow.WorkerService/DeregisterGraph"; 13 case GrpcWorkerMethod::kRunGraph: 14 return "/tensorflow.WorkerService/RunGraph"; 15 case GrpcWorkerMethod::kCleanupGraph: 16 return "/tensorflow.WorkerService/CleanupGraph"; 17 case GrpcWorkerMethod::kCleanupAll: 18 return "/tensorflow.WorkerService/CleanupAll"; 19 case GrpcWorkerMethod::kRecvTensor: 20 return "/tensorflow.WorkerService/RecvTensor"; 21 case GrpcWorkerMethod::kRecvBuf: 22 return "/tensorflow.WorkerService/RecvBuf"; 23 case GrpcWorkerMethod::kLogging: 24 return "/tensorflow.WorkerService/Logging"; 25 case GrpcWorkerMethod::kTracing: 26 return "/tensorflow.WorkerService/Tracing"; 27 case GrpcWorkerMethod::kCompleteGroup: 28 return "/tensorflow.WorkerService/CompleteGroup"; 29 case GrpcWorkerMethod::kCompleteInstance: 30 return "/tensorflow.WorkerService/CompleteInstance"; 31 case GrpcWorkerMethod::kGetStepSequence: 32 return "/tensorflow.WorkerService/GetStepSequence"; 33 } 34 // Shouldn't be reached. 35 LOG(FATAL) << "Invalid id: this line shouldn't be reached."; 36 return "invalid id"; 37 }
另外,還需要為每個RPC服務注冊為異步服務,這需要使用gRPC自帶的AddMethod接口和MarkMethodAsync接口,如下所示。
1 WorkerService::AsyncService::AsyncService() { 2 for (int i = 0; i < kGrpcNumWorkerMethods; ++i) { 3 AddMethod(new ::grpc::internal::RpcServiceMethod( 4 GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)), 5 ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); 6 ::grpc::Service::MarkMethodAsync(i); 7 } 8 }
好了,接下來就是解析源碼中具體的交互過程了。其實TensorFlow在框架層面對gRPC的使用了一些Best Practice,比如異步處理請求的架構和多線程輪詢Completion Queue等。將這些連在一起梳理需要更多的篇幅,一次性展示大量的內容也不利於閱讀,所以我們只對發送和接收過程做一個梳理。
Client端的調用鏈
從BaseRemoteRendeezvous的RecvAsync出發,逐漸深入調用鏈底層。時序圖是分析調用鏈的最好工具,下面給出了Client端到Stub的調用過程,這里面涉及到了幾個新的類。
1. RpcRecvTensorCall:這是一次gRPC調用的抽象,繼承了BaseRecvTensorCall這個抽象基類,它封裝了復雜的后續調用鏈。
2. GrpcRemoteWorker:它也是client端的內容,只不過它是Remote端的代理。
3. RpcState:這是真正封裝了一次RPC調用及狀態的類,它會直接對Stub以及GenericClientAsyncResponseReader進行管理,比如向服務端發送異步請求並等待結果等。
Client端是一個虛擬角色,它可以是調用RpcRemoteRendezvous的任何一個模塊。我們可以看到,RpcRemoteRendezvous的一次RecvRemoteAsync過程非常長,並且Stub的調用時異步的。這里的代碼確實有些多,所以我們只展示一下關鍵代碼段,但是建議讀者打開源碼仔細閱讀每個調用鏈。
下面是RecvRemoteAsync的代碼段,主要做了RpcRecvTensorCall的初始化,注冊以及啟動工作。
1 void RpcRemoteRendezvous::RecvFromRemoteAsync( 2 const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, 3 DoneCallback done) { 4 CHECK(is_initialized()); 5 Status s; 6 7 // Prepare a RecvTensor call that can handle being aborted. 8 RpcRecvTensorCall* call = get_call_freelist()->New(); 9 10 // key.src_device identifies a remote device. 11 if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_, 12 &call->src_rel_device_)) { 13 s = errors::Internal(parsed.src_device, 14 " is invalid remote source device."); 15 } 16 WorkerSession* sess = session(); 17 WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_); 18 if (s.ok() && rwi == nullptr) { 19 s = errors::Internal("No worker known as ", call->src_worker_); 20 } 21 22 Device* dst_device; 23 if (s.ok()) { 24 s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device); 25 } 26 if (!s.ok()) { 27 if (rwi != nullptr) { 28 sess->worker_cache->ReleaseWorker(call->src_worker_, rwi); 29 } 30 get_call_freelist()->Release(call, sess->worker_cache.get()); 31 done(s, Args(), recv_args, Tensor{}, false); 32 return; 33 } 34 35 call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device, 36 recv_args, std::move(done)); 37 38 // Record "call" in active_ so that it can be aborted cleanly. 39 RegisterCall(call); 40 41 // RendezvousMgr already aborted, shouldn't send RPC call any more 42 if (!call->status().ok()) { 43 call->done()(call->status(), Args(), Args(), Tensor(), false); 44 session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_); 45 call->wi_ = nullptr; 46 get_call_freelist()->Release(call, session()->worker_cache.get()); 47 return; 48 } 49 50 // Start "call". 51 Ref(); 52 call->Start([this, call]() { 53 // Removes "call" from active_. Prevent StartAbort(). 54 DeregisterCall(call); 55 // If StartAbort was called prior to DeregisterCall, then the 56 // current status should be bad. 57 Status s = call->status(); 58 call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); 59 session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_); 60 call->wi_ = nullptr; 61 get_call_freelist()->Release(call, session()->worker_cache.get()); 62 Unref(); 63 }); 64 }
下面是GrpcRemoteWorker調用RPCState的過程,最后的IssueRequest即開始創建RPCState並觸發stub的調用。
void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request, TensorResponse* response, StatusCallback done) override { VLOG(1) << "RecvTensorAsync req: " << request->DebugString(); int64 start_usec = Env::Default()->NowMicros(); // Type-specialized logging for this method. bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2); StatusCallback wrapper_done; const StatusCallback* cb_to_use; if (!logging_active) { cb_to_use = &done; // No additional work to do, so just use done directly } else { wrapper_done = [this, request, response, done, start_usec](Status s) { if (logger_->LoggingActive()) { int64 end_usec = Env::Default()->NowMicros(); int64 step_id = request->step_id(); int64 bytes = response->tensor().TotalBytes(); int64 send_start_usec = start_usec; // If a send start time was reported by the other side, use // that instead. Maybe we should mark the display if we're using // our local time instead of the remote start time? if (response->metadata().send_start_micros()) { // send_start_micros is the timestamp taken when the // remote machine began to send the RecvTensor response. // Due to clock skew between source and dest machines, it // is possible that send_start_micros can be larger than // end_usec or less than start_usec. // // To respect causality, we enforce the invariants that // the RecvTensor response can not have been sent before // the RecvTensor request, and must have been sent before // it was received. send_start_usec = std::max( start_usec, static_cast<int64>(response->metadata().send_start_micros())); send_start_usec = std::min(send_start_usec, end_usec - 1); } const string& key = request->rendezvous_key(); std::vector<string> key_parts = str_util::Split(key, ';'); if (key_parts.size() != 5) { LOG(WARNING) << "Bad key: " << key; } else { logger_->RecordRecvTensor(step_id, send_start_usec, end_usec, key_parts[3], // tensor name key_parts[0], // src_device key_parts[2], // dst_device bytes); } } VLOG(2) << "done callback, req: " << request->DebugString() << " response " << response->metadata().DebugString(); done(s); }; cb_to_use = &wrapper_done; } IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts); }
最后展示一下Stub的觸發位置,這個函數在RPCState類中,並且在創建RPCState對象時立即被調用。
1 void StartCall() { 2 context_.reset(new ::grpc::ClientContext()); 3 context_->set_fail_fast(fail_fast_); 4 5 if (timeout_in_ms_ > 0) { 6 context_->set_deadline( 7 gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN)); 8 } 9 if (call_opts_) { 10 call_opts_->SetCancelCallback([this]() { context_->TryCancel(); }); 11 } 12 13 VLOG(2) << "Starting call: " << method_; 14 15 call_ = std::move( 16 stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_)); 17 call_->StartCall(); 18 call_->Finish(&response_buf_, &status_, this); 19 }
Server端負責查找Tensor的Service
如果我們把異步處理請求的架構和多線程輪詢Completion Queue的Best Practice去除,那么Service端其實並不復雜,調用鏈相對Client端短了很多,下面的時序圖展示了自Server端接收請求后的調用過程,這里面也涉及到了幾個新的類。
1. GrpcWorkerServiceThread:這是服務端處理請求的線程類。
2. GrpcWorker:這是真正負責處理請求的Worker,是GrpcRemoteWorker的服務端版本;
3. WorkerCall:這是服務端處理一次gRPC請求和響應的類,抽象為WorkerCall,其實這也是個別名,真實的名稱較長;
4. ServerAsyncResponseWriter:這是gRPC為用戶端提供的Response writer,是承載響應的實體。
5. Utils:這其實不是一個類,而是多個工具的組合,為了在時序圖表達方便,統稱為Utils。
可以看出,服務端接收到請求后,會調用RecvLocalAsync在本地將客戶端所需要的Tensor查找出來,然后拷貝到CPU上,最后利用gRPC發送回客戶端。同樣,我們展示關鍵代碼段。
下面是GrpcWorker調用RendezvousMgr的RecvLocalAsync為客戶端尋找真正Tensor的過程。回調函數中能夠看出,在找到對應Tensor后,需要將Tensor做Encode,然后拷貝到CPU端。
1 env_->rendezvous_mgr->RecvLocalAsync( 2 step_id, parsed, 3 [opts, response, done, src_dev, request]( 4 const Status& status, const Rendezvous::Args& send_args, 5 const Rendezvous::Args& recv_args, const Tensor& val, 6 const bool is_dead) { 7 opts->ClearCancelCallback(); 8 if (status.ok()) { 9 // DMA can only be used for Tensors that do not fall into 10 // the following three odd edge cases: 1) a zero-size 11 // buffer, 2) a dead tensor which has an uninit value, and 12 // 3) the tensor has the on_host allocation attribute, 13 // i.e. it's in CPU RAM *independent of its assigned 14 // device type*. 15 const bool on_host = send_args.alloc_attrs.on_host(); 16 { 17 // Non-DMA cases. 18 if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { 19 DeviceContext* send_dev_context = send_args.device_context; 20 AllocatorAttributes alloc_attrs; 21 alloc_attrs.set_gpu_compatible(true); 22 alloc_attrs.set_on_host(true); 23 Allocator* alloc = src_dev->GetAllocator(alloc_attrs); 24 Tensor* copy = new Tensor(alloc, val.dtype(), val.shape()); 25 CHECK(send_dev_context) 26 << "send dev name: " << src_dev->name() 27 << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); 28 // "val" is on an accelerator device. Uses the device_context to 29 // fill the copy on host. 30 StatusCallback copy_ready = [response, done, copy, 31 is_dead](const Status& s) { 32 // The value is now ready to be returned on the wire. 33 grpc::EncodeTensorToByteBuffer(is_dead, *copy, response); 34 done(s); 35 delete copy; 36 }; 37 38 send_dev_context->CopyDeviceTensorToCPU( 39 &val, request->rendezvous_key(), src_dev, copy, copy_ready); 40 } else { 41 grpc::EncodeTensorToByteBuffer(is_dead, val, response); 42 done(Status::OK()); 43 } 44 } 45 } else { 46 // !s.ok() 47 done(status); 48 } 49 });
至此,我們的Rendezvous之gRPC傳輸之旅就圓滿結束了,在閱讀本篇時還是希望讀者能夠在理解結構設計后,對照C++源碼仔細閱讀反復推敲里面的每一個細節,這樣才能有更深的理解。
一個需要思考的問題——gRPC傳輸Tensor很低效?
是的,確實很低效。為什么?從設計哲學上說,gRPC本身設計並不適合深度學習訓練場景。從細節上來說它有以下幾個缺陷:
1. gRPC發送Tensor前,接收Tensor后必須要做序列化,在Tensor很大的時候這是一個非常討厭的overhead,發送接收延遲過大;
2. 序列化根本沒有對數據做任何壓縮,這是因為Tensor都是稠密的,所以序列化沒有意義;
3. 不能支持RDMA和GPU Direct。雖然這依賴於硬件,但是gRPC在軟件層面也並沒有做這些適配。
所以大部分人使用TensorFlow分布式時都會對性能有很大的抱怨,這里面很大的原因和gRPC有關。如果你使用NCCL或者MPI,那么你會得到不一樣的性能。
總結
本篇文章篇幅較長,是Rendezvous機制系列的第二篇,主要梳理了涉及到gRPC傳輸的模塊架構設計和源碼細節,並且詳細梳理了通信過程。理解TensorFlow跨機傳輸的關鍵在於理解一個事實:真正的通信過程由Recv方觸發,而不是Send方!Send依然將Ready的Tensor掛入本地Table中,而Recv會向Send端發送gRPC請求查詢所需要的Tensor,然后返回所需要的結果,這個過程雖然有些別扭,但邏輯上並不稀奇。從結構設計上來說,RemoteRendezvous沿用了Rendezvous接口,並且完全復用了LocalRendezvousImpl的Send代碼,而Recv由於涉及到具體的通信細節和管理機制,則各有各的不同。另外,RemoteRendezvous相對LocalRendezvous復雜很多,需要管理器進行管理。最后一大部分是Send和Recv的源碼細節展示,因為無論是客戶端還是服務端,其調用鏈都比較長,所以以時序圖的形式展示各個類之間的調用關系和協作關系較為清晰,具體每個調用的細節建議讀者結合源碼逐一分析,並連同本篇文章一起理解較為深刻。最后,我們總結了gRPC傳輸Tensor的明顯缺陷,當然這也是為性能優化開辟了新的空間。