tensorflow源碼解析之common_runtime-direct_session


目錄

  1. 核心概念
  2. direct_session
    1. direct_session.h
    2. direct_session.cc

1. 核心概念

讀過之前文章的讀者應該還記得,session是一個執行代理。我們把計算圖和輸入交給session,由它來調度執行器,執行計算產生結果。TF給我們提供了一個最簡單的執行器direction_session。按照當前的理解,我們覺得direction_session的實現應該是非常簡單而直接的,畢竟執行器的復雜結構我們在executor那篇已經見到了。但實際上,問題的難點在於,有時候我們只是希望以計算圖中某些節點為輸入,某些節點為輸出,來執行圖中的一小部分計算,而不需要執行整張圖,另外一個方面,這種對圖部分執行的任務,在同一張圖上可能同時存在多個。為了應對這種情況,direct_session就衍生出了很多輔助數據。

2. direct_session

2.1 direct_session.h

DirectSession類提供了豐富的數據和接口,以下為了表達簡潔,我們略去了部分函數的形參:

class DirectSession : public Session {
  public:
    DirectionSession(const SessionOptions& options, const Device* device_mgr, DirectSessionFactory* factory);
    
    Status Create(const GraphDef& graph) override;
    Status Extend(const GraphDef& graph) override;
    Status Run(...) override;//運行圖
    
    Status PRunSetup(...);//部分運行圖准備
    Status PRun(...);//部分運行圖
    
    Status Reset(const std::vector<string>& containers);//清空device_mgr中的containers,如果containers本身就是空的,那么清空默認容器
    
    Status ListDevice(...) override;
    Status Close() overrides;
    Status LocalDeviceManager(const DeviceMgr** output) overrides;
    
    void ExportCostModels(...);

  private:
    Status MaybeInitializeExecutionState(...);//給定graph之后,如果執行器狀態沒有初始化,則初始化基礎的執行器狀態
    
    Status GetOrCreateExecutors(...);//對於一組給定的輸入和輸出,在一個給定的執行器集合中檢索,是否存在合適的執行器,如果沒有,則創造一個
    
    Status CreateGraphs(...);//給定graph_def_和設備,以及輸入和輸出,創造多張圖,這些新創建的圖共享一個公共的函數庫flib_def
    
    Status ExtendLocked(const GraphDef& graph);//Extend的內部執行類
    
    Status ResourceHandleToInputTensor(...);
    
    Status SendPRunInputs(...);//將更多的輸入提供給執行器,啟動后續的執行
    
    Status RecvPRunOutputs(...);//從執行器中獲取更多的輸出,它會等待直到輸出張量計算完成
    
    Status CheckFetch(...);//檢查需求的輸出能否根據給定的輸入計算出來
    
    Status WaitForNotification(...);
    Status CheckNotClosed();
    
    const SessionOptions options_;
    
    //設備相關的結構
    const std::unique_ptr<const DeviceMgr> device_mgr_;
    std::vector<Device*> devices_;
    DeviceSet device_set_;
    
    string session_handle_;
    bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
    mutex graph_def_lock_;
    GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
    
    std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;//被用來執行op的線程池,用一個布爾值來標志,是否擁有這個線程池
    
    Status init_error_;
    
    bool sync_on_finish_ = true;//如果為真,阻塞線程直到設備已經完成了某個步驟內的所有隊列中的操作
    void SchedClosure(thread::ThreadPool* pool, std::function<void()> c);//在線程池中調度c
    
    mutex executor_lock_;//保護執行器
    
    std::unordered_map<string, std::shared_ptr<ExecutorsAndkeys>> executor_ GUARDED_BY(executor_lock_);//由簽名映射到它的執行器,簽名包括了部分執行圖的輸入和輸出,由這兩個就能唯一確定一個部分執行圖
    
    std::unordered_map<string, std::shared_ptr<RunState>> partial_runs_ GUARDED_BY(executor_lock_);//從簽名到部分執行狀態,每一個部分執行都會有一個專門保存其狀態的結構
    
    SessionState session_state_;//保存了所有當前在會話中正在存活的張量
    
    DirectSessionFactory* const factory_;
    CancellationManager* cancellation_manager_;
    
    std::unordered_map<string, string> stateful_placements_ GUARDED_BY(graph_def_lock_);//對於有狀態的節點(比如params和queue),保存節點名稱到節點所在設備的映射,一旦這些節點被放置在了某個設備上,是不允許再移動的
    
    std::unique_ptr<SimpleGraphExecutionState> execution_state_ GUARDED_BY(graph_def_lock_);//放置整張圖時使用
    
    std::unique_ptr<FunctionLibraryDefinition> flib_def_;//在任何的重寫或優化之前的函數庫,特別是,CreateGraphs函數會修改函數庫
    
    mutex closed_lock_;
    bool closed_ GUARDED_BY(closed_lock_) = false;//如果會話已經被關閉,則為true
    
    //為這個會話生成唯一的名字
    std::atomic<int64> edge_name_counter_ = {0};
    std::atomic<int64> handle_name_counter_ = {0};
    
    static std::atomic_int_fast64_t step_id_counter_;//為所有的會話生成唯一的step id
    
    const int64 operation_timeout_in_ms_ = 0;//全局對阻塞操作的超時閾值
    
    CostModelManager cost_model_manager_;//為當前會話中執行的圖管理所有的損失模型
}

可見,DirectSession里面的很多內容都是為部分執行准備的。由於計算圖僅是一個計算的規划,我們可以通過為同一張圖選取不同的輸入和輸出,來執行不同的計算。而不同的計算需要不同的執行器,也需要不同的存儲結構來保存各個計算的當前狀態。為此,TF專門給出了幾個結構體,首先我們來看一下對不同計算執行器的封裝:

//為每一個partition准備的執行器和函數運行時庫
struct PerPartionExecutorAndLib {
    Graph* graph = nullptr;
    std::unique_ptr<FunctionLibraryRuntime> flib;
    std::unique_ptr<Executor> executor;
};

//為每一次計算提供的數據結構
struct ExecutorsAndKeys {
    std::atomic_int_fast64_t step_count;
    std::unique_ptr<Graph> graph;
    NameNodeMap name_to_node;
    std::unique_ptr<FunctionLibraryDefinition> flib_def;
    std::vector<PerPartitionExecutorsAndLib> items;
    std::unordered_map<string, size_t> input_name_to_index;
    std::unordered_map<string, string> input_name_to_rendezvous_key;
    std::unordered_map<string, size_t> output_name_to_index;
    std::unordered_map<string, string> output_name_to_rendezvous_key;
    
    DataTypeVector input_types;
    DataTypeVector output_types;
};

對於一張計算圖來說,我們的每一次計算的執行,不論是完整圖的計算還是部分圖的計算,都有可能是跨設備的,因此都需要先做節點放置,把圖的節點分割到不同的設備上,每一個設備上放置了一個圖的partition,每個partition有對應的運行時函數庫和執行器。而對於每一種計算來說,我們需要一個vector把不同partition的信息存儲起來。
另外,剛才提到我們還需要為每一次計算提供保存當前狀態的結構,下面就來看一下:

//對於每一個partition內的執行,會話保存了一個RunState
struct RunState {
    mutex mu_;
    Status status GUARDED_BY(mu_);
    IntraProcessRendezvous* rendez = nullptr;
    std::unique_ptr<StepStatsCollector> collector;
    Notification executors_done;
    std::unordered_map<string, bool> pending_inputs;//如果已經提供了輸入,則為true
    std::unordered_map<string, bool> pending_outputs;//如果已經獲得了輸出,則為true
    TensorStore tensor_store;
    ScopedStepContainer step-container;
    //...
};

struct RunStateArgs {
    RunStateArgs(const DebugOption& options) : debug_options(options) {}
    bool is_partial_run = false;
    string handle;
    std::unique_ptr<Graph> graph;
    const DebugOptions& debug_options;
};

其中,RunState為每一個partition的執行提供了狀態保存的功能,而RunStateArgs則為前者提供了用於調試的參數和配置。

2.2 direct_session.cc

在源文件里,給出了DirectSessionFactory的定義,它提供了對於DirectSession進行生成和管理的功能,簡要摘錄如下:

class DirectSessionFactory : public SessionFactory {
  public:
    Session* NewSession(const SessionOptions& options) override;
    Status Reset(...) override;
    void Deregister(const DirectSession* session);
  private:
    mutex session_lock_;
    std::vector<DirectSession*> session_ GUARDED_BY(sessions_lock_);//用於存儲生成的DirectSession
};

另外,還提供了一個對於直接工廠注冊的類:

class DirectSessionRegistrar {
  public:
    DirectSessionRegistrar() {
        SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
    }
};
static DirectSessionRegistrar registrar;

下面,我們會按照順序對DirectSession內重要的函數,進行拆解,由於部分函數細節比較多,除了核心代碼之外,我們僅給出功能解釋:

DirectSession::DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, DirectSessionFactory* const factory){
    //根據options准備線程池
    //根據device_mgr准備device_和device_set_和每個設備的op_segment()
}

Status DirectSession::Run(...){
    //提取對於當前會話的本次運行的輸入的名稱
    //檢查對於所需的輸入輸出,是否已經存在現成的執行器
    //構造一個調用幀(call frame),方便會話與執行器之間傳遞輸入和輸出
    //創建一個運行時狀態的結構(RunState)
    //開始並行執行,核心代碼如下
    for(const auto& item : executors_and_keys->items){
        item.executor->RunAsync(args, barrier->Get());
    }
    //獲取輸出
    //保存本次運行中我們希望保存的輸出張量
    //創建並返回損失模型(cost model)
    //如果RunOptions中有相關配置,輸出分割后的圖
}

Status DirectSession::GetOrCreateExecutors(...){
    //快速查找路徑
    //慢查找路徑,對輸入和輸出做排序,使得相同輸入和輸出集合會得到相同的簽名
    //如果未找到,則創建這個執行器並緩存
    //構建執行圖,核心代碼如下
    CreateGraphs(options, &graphs, &ek->flib_def, run_state_args, &ek->input_types, &ek->output_types));
    //為各子圖准備運行時環境
}

Status DirectSession::CreateGraphs(...){
    //前期預處理
    //圖分割算法,核心代碼如下
    Partition(popts, &client_graph->graph, &partitions);
    //檢查分割結果的有效性
    //圖優化遍歷,核心代碼如下
    OptimizationPassRegistry::Global()->RunGrouping(OptimizationPassRegistry::POST_PARTITIONING, optimization_options);
    //允許設備重寫它擁有的子圖
}

可見,具體的執行過程是在Run函數內部,調用executor->RunAsync函數來實現的,在具體執行之前,我們還需要通過GetOrCreateExecutors函數獲得執行器,在這個函數內部,我們通過CreateGraphs函數對原圖進行了分割,並利用圖優化遍歷算法對圖進行了優化。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM