目錄
- 核心概念
- direct_session
- direct_session.h
- direct_session.cc
1. 核心概念
讀過之前文章的讀者應該還記得,session是一個執行代理。我們把計算圖和輸入交給session,由它來調度執行器,執行計算產生結果。TF給我們提供了一個最簡單的執行器direction_session。按照當前的理解,我們覺得direction_session的實現應該是非常簡單而直接的,畢竟執行器的復雜結構我們在executor那篇已經見到了。但實際上,問題的難點在於,有時候我們只是希望以計算圖中某些節點為輸入,某些節點為輸出,來執行圖中的一小部分計算,而不需要執行整張圖,另外一個方面,這種對圖部分執行的任務,在同一張圖上可能同時存在多個。為了應對這種情況,direct_session就衍生出了很多輔助數據。
2. direct_session
2.1 direct_session.h
DirectSession類提供了豐富的數據和接口,以下為了表達簡潔,我們略去了部分函數的形參:
class DirectSession : public Session {
public:
DirectionSession(const SessionOptions& options, const Device* device_mgr, DirectSessionFactory* factory);
Status Create(const GraphDef& graph) override;
Status Extend(const GraphDef& graph) override;
Status Run(...) override;//運行圖
Status PRunSetup(...);//部分運行圖准備
Status PRun(...);//部分運行圖
Status Reset(const std::vector<string>& containers);//清空device_mgr中的containers,如果containers本身就是空的,那么清空默認容器
Status ListDevice(...) override;
Status Close() overrides;
Status LocalDeviceManager(const DeviceMgr** output) overrides;
void ExportCostModels(...);
private:
Status MaybeInitializeExecutionState(...);//給定graph之后,如果執行器狀態沒有初始化,則初始化基礎的執行器狀態
Status GetOrCreateExecutors(...);//對於一組給定的輸入和輸出,在一個給定的執行器集合中檢索,是否存在合適的執行器,如果沒有,則創造一個
Status CreateGraphs(...);//給定graph_def_和設備,以及輸入和輸出,創造多張圖,這些新創建的圖共享一個公共的函數庫flib_def
Status ExtendLocked(const GraphDef& graph);//Extend的內部執行類
Status ResourceHandleToInputTensor(...);
Status SendPRunInputs(...);//將更多的輸入提供給執行器,啟動后續的執行
Status RecvPRunOutputs(...);//從執行器中獲取更多的輸出,它會等待直到輸出張量計算完成
Status CheckFetch(...);//檢查需求的輸出能否根據給定的輸入計算出來
Status WaitForNotification(...);
Status CheckNotClosed();
const SessionOptions options_;
//設備相關的結構
const std::unique_ptr<const DeviceMgr> device_mgr_;
std::vector<Device*> devices_;
DeviceSet device_set_;
string session_handle_;
bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
mutex graph_def_lock_;
GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;//被用來執行op的線程池,用一個布爾值來標志,是否擁有這個線程池
Status init_error_;
bool sync_on_finish_ = true;//如果為真,阻塞線程直到設備已經完成了某個步驟內的所有隊列中的操作
void SchedClosure(thread::ThreadPool* pool, std::function<void()> c);//在線程池中調度c
mutex executor_lock_;//保護執行器
std::unordered_map<string, std::shared_ptr<ExecutorsAndkeys>> executor_ GUARDED_BY(executor_lock_);//由簽名映射到它的執行器,簽名包括了部分執行圖的輸入和輸出,由這兩個就能唯一確定一個部分執行圖
std::unordered_map<string, std::shared_ptr<RunState>> partial_runs_ GUARDED_BY(executor_lock_);//從簽名到部分執行狀態,每一個部分執行都會有一個專門保存其狀態的結構
SessionState session_state_;//保存了所有當前在會話中正在存活的張量
DirectSessionFactory* const factory_;
CancellationManager* cancellation_manager_;
std::unordered_map<string, string> stateful_placements_ GUARDED_BY(graph_def_lock_);//對於有狀態的節點(比如params和queue),保存節點名稱到節點所在設備的映射,一旦這些節點被放置在了某個設備上,是不允許再移動的
std::unique_ptr<SimpleGraphExecutionState> execution_state_ GUARDED_BY(graph_def_lock_);//放置整張圖時使用
std::unique_ptr<FunctionLibraryDefinition> flib_def_;//在任何的重寫或優化之前的函數庫,特別是,CreateGraphs函數會修改函數庫
mutex closed_lock_;
bool closed_ GUARDED_BY(closed_lock_) = false;//如果會話已經被關閉,則為true
//為這個會話生成唯一的名字
std::atomic<int64> edge_name_counter_ = {0};
std::atomic<int64> handle_name_counter_ = {0};
static std::atomic_int_fast64_t step_id_counter_;//為所有的會話生成唯一的step id
const int64 operation_timeout_in_ms_ = 0;//全局對阻塞操作的超時閾值
CostModelManager cost_model_manager_;//為當前會話中執行的圖管理所有的損失模型
}
可見,DirectSession里面的很多內容都是為部分執行准備的。由於計算圖僅是一個計算的規划,我們可以通過為同一張圖選取不同的輸入和輸出,來執行不同的計算。而不同的計算需要不同的執行器,也需要不同的存儲結構來保存各個計算的當前狀態。為此,TF專門給出了幾個結構體,首先我們來看一下對不同計算執行器的封裝:
//為每一個partition准備的執行器和函數運行時庫
struct PerPartionExecutorAndLib {
Graph* graph = nullptr;
std::unique_ptr<FunctionLibraryRuntime> flib;
std::unique_ptr<Executor> executor;
};
//為每一次計算提供的數據結構
struct ExecutorsAndKeys {
std::atomic_int_fast64_t step_count;
std::unique_ptr<Graph> graph;
NameNodeMap name_to_node;
std::unique_ptr<FunctionLibraryDefinition> flib_def;
std::vector<PerPartitionExecutorsAndLib> items;
std::unordered_map<string, size_t> input_name_to_index;
std::unordered_map<string, string> input_name_to_rendezvous_key;
std::unordered_map<string, size_t> output_name_to_index;
std::unordered_map<string, string> output_name_to_rendezvous_key;
DataTypeVector input_types;
DataTypeVector output_types;
};
對於一張計算圖來說,我們的每一次計算的執行,不論是完整圖的計算還是部分圖的計算,都有可能是跨設備的,因此都需要先做節點放置,把圖的節點分割到不同的設備上,每一個設備上放置了一個圖的partition,每個partition有對應的運行時函數庫和執行器。而對於每一種計算來說,我們需要一個vector把不同partition的信息存儲起來。
另外,剛才提到我們還需要為每一次計算提供保存當前狀態的結構,下面就來看一下:
//對於每一個partition內的執行,會話保存了一個RunState
struct RunState {
mutex mu_;
Status status GUARDED_BY(mu_);
IntraProcessRendezvous* rendez = nullptr;
std::unique_ptr<StepStatsCollector> collector;
Notification executors_done;
std::unordered_map<string, bool> pending_inputs;//如果已經提供了輸入,則為true
std::unordered_map<string, bool> pending_outputs;//如果已經獲得了輸出,則為true
TensorStore tensor_store;
ScopedStepContainer step-container;
//...
};
struct RunStateArgs {
RunStateArgs(const DebugOption& options) : debug_options(options) {}
bool is_partial_run = false;
string handle;
std::unique_ptr<Graph> graph;
const DebugOptions& debug_options;
};
其中,RunState為每一個partition的執行提供了狀態保存的功能,而RunStateArgs則為前者提供了用於調試的參數和配置。
2.2 direct_session.cc
在源文件里,給出了DirectSessionFactory的定義,它提供了對於DirectSession進行生成和管理的功能,簡要摘錄如下:
class DirectSessionFactory : public SessionFactory {
public:
Session* NewSession(const SessionOptions& options) override;
Status Reset(...) override;
void Deregister(const DirectSession* session);
private:
mutex session_lock_;
std::vector<DirectSession*> session_ GUARDED_BY(sessions_lock_);//用於存儲生成的DirectSession
};
另外,還提供了一個對於直接工廠注冊的類:
class DirectSessionRegistrar {
public:
DirectSessionRegistrar() {
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
}
};
static DirectSessionRegistrar registrar;
下面,我們會按照順序對DirectSession內重要的函數,進行拆解,由於部分函數細節比較多,除了核心代碼之外,我們僅給出功能解釋:
DirectSession::DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, DirectSessionFactory* const factory){
//根據options准備線程池
//根據device_mgr准備device_和device_set_和每個設備的op_segment()
}
Status DirectSession::Run(...){
//提取對於當前會話的本次運行的輸入的名稱
//檢查對於所需的輸入輸出,是否已經存在現成的執行器
//構造一個調用幀(call frame),方便會話與執行器之間傳遞輸入和輸出
//創建一個運行時狀態的結構(RunState)
//開始並行執行,核心代碼如下
for(const auto& item : executors_and_keys->items){
item.executor->RunAsync(args, barrier->Get());
}
//獲取輸出
//保存本次運行中我們希望保存的輸出張量
//創建並返回損失模型(cost model)
//如果RunOptions中有相關配置,輸出分割后的圖
}
Status DirectSession::GetOrCreateExecutors(...){
//快速查找路徑
//慢查找路徑,對輸入和輸出做排序,使得相同輸入和輸出集合會得到相同的簽名
//如果未找到,則創建這個執行器並緩存
//構建執行圖,核心代碼如下
CreateGraphs(options, &graphs, &ek->flib_def, run_state_args, &ek->input_types, &ek->output_types));
//為各子圖准備運行時環境
}
Status DirectSession::CreateGraphs(...){
//前期預處理
//圖分割算法,核心代碼如下
Partition(popts, &client_graph->graph, &partitions);
//檢查分割結果的有效性
//圖優化遍歷,核心代碼如下
OptimizationPassRegistry::Global()->RunGrouping(OptimizationPassRegistry::POST_PARTITIONING, optimization_options);
//允許設備重寫它擁有的子圖
}
可見,具體的執行過程是在Run函數內部,調用executor->RunAsync函數來實現的,在具體執行之前,我們還需要通過GetOrCreateExecutors函數獲得執行器,在這個函數內部,我們通過CreateGraphs函數對原圖進行了分割,並利用圖優化遍歷算法對圖進行了優化。