tensorflow源碼剖析之framework-kernel


目錄

  1. 什么是kernel
  2. kernel_def
  3. op_kernel
  4. kernel的注冊
  5. op_segment
  6. 關系圖
  7. 涉及的文件
  8. 迭代記錄

1. 什么是kernel

如果說op相當於操作的聲明,那么kernel就是操作的實現。同一份聲明在不同的設備上,最優的實現方式是不一樣的,比如對於MatMul矩陣相乘這個操作,在CPU上可以用SSE指令優化加速,在GPU上可以用GPU實現高性能計算。因此就會對應CPU和GPU兩種不同的實現。所以,在定義一個kernel時,除了要指明kernel對應的op之外,還需要指明kernel所在的設備類型。

另外,kernel是一個運行期的概念。在定義圖時,每個節點的op並不知道具體是由哪個kernel實現這個操作的,因為這時節點還沒有被分配到具體的設備上,因此也就沒法為其選擇合適的kernel。

2. kernel_def

雖然kernel是一個運行期的概念,但我們仍然需要用一些靜態的信息對kernel進行描述,這個靜態的信息就是KernelDef這個proto,定義如下:

message KernelDef {
    string op = 1;//對應操作的名稱
    string device_type = 2;//對應設備的類型
    message AttrConstraint {
        string name = 1;
        AttrValue allowed_values = 2;
    }
    repeated AttrConstraint constraint = 3;//對應op中對參數的限制
    repeated string host_memory_arg = 4;//操作的輸入或輸出參數中,存在於host內存而不是device內存中的參數
    string label = 5;
}

其中的label字段我們解釋一下。有時候用戶會編寫一些實驗性的kernel,然后注冊到某一個op上去。但這種kernel默認情況下是不會被使用的,除非用戶用戶在op中定義了一個_kernel字段,並且把這個字段賦值為某個kernel的label對應的值。舉個例子,如果我們定義了一個MulKernel,且把它的label設置為'mulkernel',假設這個kernel對應的操作叫做MulOp,那么只有MulOp也包含一個字段_kernel,並且_kernel="mulkernel"時,MulKernel才可以被MulOp使用。

按照慣例,TF也會為KernelDef設計一個構建類,這就是KernelDefBuilder。與之前的構建類類似,KernelDefBuilder也只是提供了一系列的屬性設置API,私有數據成員也只有一個KernelDef指針:

class KernelDefBuilder {
    //...
  private:
    KernelDef* kernel_def_;
}

3. op_kernel

3.1 OpKernel

如果說KernelDef只是對kernel的靜態信息描述,那這里就要介紹kernel的本尊了,OpKernel是所有真正kernel的基類,它除了像KernelDef一樣包含數值屬性之外,還提供了kernel的核心API,compute函數。接下來我們仔細看一下OpKernel類的結構:

class OpKernel {
  public:
    explicit OpKernel(OpKernelConstruction* context);
    virtual void Compute(OpKernelContext* context) = 0;//執行同步計算
    virtual AsyncOpKernel* AsAsync() { return nullptr; }
  private:
    const std::unique_ptr<const NodeDef> def_;
    const DataTypeVector input_types_;//輸入的數據類型
    const MemoryTypeVector input_memory_types_;//輸入的內存類型
    const DataTypeVector output_types_;//輸出的數據類型
    const MemoryTypeVector output_memory_types_;//輸出的內存類型
    const bool is_internal_;//是否是內部操作
    NameRangeMap input_name_map_;
    NameRangeMap output_name_map_;
    bool expensive_;//是否是復雜操作
}

以下對這個類進行一些說明:

  • 在構造函數中,我們發現它需要的是一個OpKernelConstruction指針,我們猜想,這個指針應該包含了構建一個OpKernel所必須的數據成員和功能函數,我們將在下文中詳細介紹;
  • 另外,在核心API Compute中,接收的是一個OpKernelContext指針的參數,我們猜想,這個指針應該包含了OpKernel執行實際計算所需要的數據成員和功能函數。下文中也會專門介紹這個類;
  • 除此之外,我們還發現了一個指向NodeDef的指針。之前說過了,kernel是一個運行期的概念,雖然kernel是op的具體實現,但運行期中,op是跟具體的節點綁定在一起的,所以每個kernel都需要綁定一個具體的節點;
  • kernel可以被分為同步kernel和異步kernel兩類,大部分的kernel都應該是同步的,Compute函數在計算結束后返回結果狀態,但有些操作本身就是異步的,因此需要將某些kernel設計為異步的,比如網絡數據接收操作,如果這個操作被設計成同步,那么如果有其它線程在使用同一個網絡接收服務,當前線程就會被阻塞,從而造成資源的浪費。

剛才講到了異步kernel,現在我們就來看一下異步kernel對應的類,AsyncOpKernel:

class AsyncOpKernel : public OpKernel {
  public:
    typedef std::function<void()> DoneCallback;
    virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
    //...
};

可見,異步計算的API除了context之外,還需要提供一個回調函數,在異步計算執行結束之后調用。

3.2 OpKernelConstruction

剛才我們提到,OpKernel的構造函數中,需要一個類型為OpKernelConstruction指針的參數,並且猜想這個參數包含了OpKernel構建所必須的數據成員和功能函數,實際上也確實如此,我們先來看下這個類的結構:

class OpKernelConstruction {
  public:
    Status allocate_temp(DataType type, const TensorShape& shape, Tensor* out_temp);//分配一塊臨時內存
    Status allocate_persistent(DataType type, const TensorShape& shape, PersistentTensor* out_persistent, Tensor** out_tensor);//分配一塊可復用內存
    //...
  private:
    const DeviceType device_type;
    DeviceBase* const device_;
    Allocator* allocator_;
    const NodeDef* def_;
    const OpDef* op_def_;
    FunctionLibraryRuntime* flib_;
    DataTypeSlice input_types_;
    MemoryTypeSlice input_memory_types_;
    DataTypeSlice output_types_;
    MemoryTypeSlice output_memory_types_;
    const int graph_def_version_;
    Status* status_;
}

其中有幾點需要說明:

  • 關於臨時內存和可復用內存。在OpKernel構建的過程中,我們可能需要分配一些內存,有些內存是臨時性的,在OpKernel構建結束之后就沒用了,我們會自主進行申請和釋放。另外,我們也希望有些內存是可以在OpKernel的多次執行之間共享的,比如,有些kernel是有狀態的,例如Variable,我們必須在kernel構建時就給這些內容申請內存;
  • 關於可復用內存,還有一點需要說明。對於在GPU上申請的可復用內存,由於GPU不像CPU那樣內存方便管理,因此運行時需要對每一份內存的使用情況了如指掌,因此對於可復用的內存,我們必須對其每一次使用都了解。TF為此專門設計了一個PersistentTensor類,這個類是對Tensor的封裝,但對於內部張量數據只能通過一個AccessData的接口來訪問,只要我們在這個接口里設置一個Watcher,就能監控所有可復用內存的使用了;
  • 關於PersistentTensor還有一個疑點,我們知道OpKernelConstruction類能夠在OpKernel初始化時申請永久內存,但在這兩個類中,都沒有發現對它的存儲。既然永久內存時可以供kernel在不同啟動之間共享的,那在第二次啟動時怎樣找到這個張量呢?還記得在resource章節中,我們介紹的resource_op_kernel嗎?這個kernel實際上就是存儲永久張量的地方。在申請了一個永久張量之后,我們為之申請一個新的resource_op_kernel來管理它,需要這個張量時,就向這個kernel索取。具體的做法是,可以為resource_op_kernel申請一個新的節點,並在新節點與原kernel所在節點之間連接一條邊,使新節點作為老節點的輸入;
  • 私有數據中還有一個FunctionLibraryRuntime結構的指針,顧名思義,這個結構表示一個運行時的函數庫,我們將在function章節中詳細描述;

3.3 OpKernelContext

還記得我們剛才提到,OpKernel的核心API Compute函數,需要一個類型為OpKernelContext指針的輸入參數嗎?剛才我們猜想,這個類包含了執行kernel計算所需要的數據成員和功能函數,實際上也確實如此。下面我們來看下這個類的構成:

class OpKernelContext {
  public:
    //構造函數
    explicit OpKernelContext(Params* params);
    
    //輸入獲取
    const Tensor& input(int index);//獲取不可變的輸入張量
    Status input(StringPiece name, const Tensor** tensor);//獲取不可變的輸入張量
    Status input_list(StringPiece name, OpInputList* list);//獲取不可變的輸入數據列表
    Status input_ref_mutex(StringPiece name, mutex** out_mutex);//獲取可變的引用輸入
    Tensor mutable_input(int index, bool lock_held);//獲取可變的引用輸入
    Status mutable_input_list(StringPiece name, OpMultableInputList* list);//獲取可變的引用輸入列表
    void replace_ref_input(int index, const Tensor& tensor, bool lock_held);//替換某個引用輸入
    Status replace_ref_input(StringPiece name, const Tensor& tensor, bool lock_held);
    
    //輸入向輸出傳遞
    //把引用輸入轉換為引用輸出
    void forward_ref_input_to_ref_output(int input_index, int output_index);
    //將輸入轉換為指定形狀的輸出
    bool forward_input_to_output_with_shape(int input_index, int output_index, const TensorShape& output_shape, Tensor** output);
    Status forward_input_to_output_with_shape(StringPiece input_name, StringPiece output_name, const TensorShape& output_shape, Tensor** output);//同上
    //如果指定輸入1.不是引用輸入,2.與給出的屬性描述一致,3.底層的buffer的引用計數為1,那么就返回一個指向該輸入的底層數據的指針
    std::unique_ptr<Tensor> forward_input(int input_index, DataType dtype, const TensorShape& shape, MemoryType memory_type, const AllocatorAttributes& attr);
    //嘗試把指定輸入傳遞到指定輸出,如果沒有任何一個輸入可以被傳遞,那么就申請一個新的內存作為輸出
    Status forward_input_or_allocate_output(gtl::ArraySlice<int> candidate_input_indices, int output_index, const TensorShape& output_shape, Tensor** output);
    //嘗試把指定輸入用作臨時變量,如果沒有輸入可用,則使用allocate_temp申請一個臨時buffer
    Status forward_input_or_allocate_temp(gtl::ArraySlice<int> candidate_input_indices, DataType type, const TensorShape& shape, const AllocatorAttributes& allocator_attr, Tensor* out_temp);
    
    //輸出獲取
    Status output_list(StringPiece name, OpOutputList* list);
    
    //內存分配
    Status allocate_output(int index, const TensorShape& shape, Tensor** tensor);
    Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, AllocatorAttributes attr);
    Status allocate_temp(DataType type, const TensorShape& shape, Tensor* out_temp, AllocatorAttributes allocator_attr, const AllocationAttributes& allocation_attr);
    Status allocate_persistent(DataType type, const TensorShape& shape, PersistentTensor* out_persistent, Tensor** out_tensor, AllocatorAttributes attr);
    
    //設置輸出
    Status set_output(StringPiece name, const Tensor& tensor);
    Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
    Status mutable_output(StringPiece name, Tensor** tensor);
    
    //...
    
  private:
    Status status_;
    Params* params_;
    mutable mutex mu_;
    gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
    gtl::InlinedVector<TensorValue,4> outputs_;
    ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
    bool is_output_dead_ = false;
    int64 host_temp_memory_size_;
    int64 device_temp_memory_size_;
    gtl::InlinedVector<int64, 2> host_persistent_alloc_ids_;
    gtl::InlinedVector<int64, 2> device_persistent_alloc_ids_;
    int64 host_persistent_memory_allocated_;
    int64 device_persistent_memory_allocated_;
}

為了看清楚結構,代碼我做了刪減,即便如此,代碼量仍然很大。對外的API包括了,獲取輸入,輸入傳遞到輸出,輸出設置,輸出獲取,內存分配,下面我們重點講幾個核心概念:

  • 關於輸入和輸出的類型,相信看過代碼已經有印像了,以輸入為例(輸出類似),被分為兩類,正常輸入和引用輸入。正常輸入是不可改變的,但引用輸入是可以改變的。可以理解為,如果你想改變正常輸入所在張量的內部數據,只能新建一個張量,將正常輸入的數據拷貝過來,然后改變新張量里的數據,但對於引用輸入,就可以直接對其修改,有時候還可以把修改后的引用輸入直接當作輸出。因此輸入獲取的API里,所有的API都有針對正常輸入和引用輸入兩個版本。而在輸入到輸出傳輸的API里,也專門有一個引用輸入到引用輸出的傳輸;
  • 關於內存分配。在kernel執行期間,有三種分配內存的方式,第一種是分配永久內存,也就是上文中提到的可復用內存,因為某些操作是有狀態的,在同一個kernel的多次調用之間,我們可以保留一些共享數據。第二種是分配輸出內存,一個kernel可能會輸出數據,這個數據必須先申請內存。第三種是分配臨時內存,kernel計算中可能會用到一些臨時的內存,計算結束之后就不用了;
  • 在某些情況下,一個張量即便不是被分配為輸出,也可能會被當做輸出使用。這個張量可能是一個輸入,或者是存儲在一個永久張量中,或者在究竟輸出哪個張量還沒被確定之前就被分配了。這種情況下,我們可以使用set_output或者set_output_ref函數來指定,這個張量被用作輸出。我們可以使用任何之前分配的張量作為輸出,即便這個張量是被當做臨時張量分配的。使用那些不是被allocate_output函數分配的張量當做輸出會有一定的性能損耗,因為allocate_output使用了存儲在output_attr_array中的AllocatorAttributes屬性來決定怎樣分配內存,如果使用的張量與這個內存分配的要求不符,可能會引起額外的張量內存拷貝;

在OpKernelContext的私有數據成員中,還有一個Params參數我們沒有解釋,下面就來看下它的結構:

struct Params {
    int step_id = 0;//執行的步驟編號
    OpKernel* op_kernel = nullptr;
    DeviceBase* device = nullptr;//該kernel執行時所在的設備類型
    PerOpGpuDevice* eigen_gpu_device = nullptr;
    
    //追蹤相關
    bool track_allocations = false;
    bool log_memory = false;
    bool record_tensor_accesses = false;
    
    const AllocatorAttributes* output_attr_array = nullptr;//輸出的內存分配器屬性
    ResourceMgr* resource_manager = nullptr;//當前的op_kernel可以訪問的共享資源
    ScopedStepContainer* step_container = nullptr;//屬於當前op_kernel的單步資源
    Rendezvous* rendezvous = nullptr;//通信機制
    TensorStore* tensor_store = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    
    const gtl::InlindecVector<TensorValue,4>* inputs = nullptr;//當前op_kernel的輸入
    bool is_input_dead = false;
    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs = nullptr;
    
    //設備上下文
    const gtl::InlineVector<DeviceContext*,4>* input_device_contexts = nullptr;
    DeviceContext* op_device_context = nullptr;
    
    //對控制流相關操作的支持
    FrameAndIter frame_iter;
    
    //函數調用支持
    FunctionCallFrame* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollector* stats_collector = nullptr;
};

雖然Params這個名字並不起眼,實際上內部別有乾坤。這里面包含了OpKernel計算所需要的資源,其中有很多屬於運行時的資源,我們當前還沒有介紹到,因此先略過不表,等這些內容都講完之后,再回過頭來看它吧。

4. kernel的注冊

面對眾多的OpKernel,我們也需要一個集中管理的地方,於是像OpRegistry一樣,TF也設計了一個OpKernelRegistrar,本質上還是一個映射,定義如下:

class OpKernelRegistrar {
  public:
    typedef OpKernel* (*Factory)(OpKernelConstruction*);
    OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, Factory factory){
        if(kernel_def != nullptr){
            InitInternal(kernel_def, kernel_class_name, factory);
        }
    }
  private:
    void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, Factory factory);
}

似乎沒有找到存儲數據的位置?實際上答案在這個InitInternal函數中,我們先講結果,追根溯源,得到了這樣的一個結構定義:

typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;

這個結構還需要兩個函數才能發揮作用:

void* GlobalKernelRegistry() {
    static KernelRegistry* global_kernel_registry = new KernelRegistry;
    return global_kernel_registry;
}
static KernelRegistry* GlobalKernelRegistryTyped() {
    return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}

通過把GlobalKernelRegistryTyped函數定義為static,使得它返回的數據唯一,因此我們也就得到了一個全局的KernelRegistry作為OpKernel的注冊中心。至於KernelRegistration的結構,比較簡單,還是留給讀者自己去探尋吧。

5. op_segment

有時候我們會為每個會話(Session)准備專用的kernel,因此就需要一個結構來管理每個會話的OpKernel,於是有了OpSegment類,我們看下它的結構:

class OpSegment {
  public:
    void AddHold(const string& session_handle);
    void RemoveHold(const string& session_handle);
    typedef std::function<Status(OpKernel**)> CreateKernelFn;
    Status FindOrCreate(const string& session_handle, const string& node_name, OpKernel** kernel, CreateKernelFn create_fn);
  private:
    typedef std::unordered_map<string, OpKernel*> KernelMap;
    struct Item {
        int num_holds = 1;
        KernelMap name_kernel;
        ~Item();
    };
    typedef std::unordered_map<string, Item*> SessionMap;
    mutalbe mutex mu_;
    SessionMap session_ GUARDED_BY(mu_);
    //...
};

可見,這個OpSegment類,本質上是一個SessionMap,它其實是一個SessionHandle到Item結構體的映射,而后者又是op名稱到OpKernel結構的映射。我們可以用下面的圖來表示:

graph LR SessionHandle-->|包含|Item Item-->|包含|OpKernel

我們看到在Item結構中,有一個num_holds成員,它表示有多少hold指向了某個SessionHandle,hold的作用可以理解為引用計數,防止Session被刪掉。向一個SessionHandle添加hold就是為了防止SessionHandle對應的OpKernel被刪除。

6. 關系圖

graph TB KernelDefBuilder-.創建.->KernelDef KernelDef-.描述.->OpKernel OpKernel-.包含.->OpKernel構造函數 OpKernel構造函數-.使用.->OpKernelConstruction OpKernel-.包含.->Compute函數 Compute函數-.使用.->OpKernelContext OpKernelContext-.包含.->Params OpKernel-.注冊.->OpKernelRegistrar OpSegment-.包含.->OpKernel

7. 涉及的文件

  • op_kernel
  • kernel_def_builder

8. 迭代記錄

  • v1.0 2018-08-28 文檔創建
  • v2.0 2018-09-09 文檔重構

github地址


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM