tensorflow源碼解析之framework-op


目錄

  1. 什么是op
  2. op_def定義
  3. op注冊
  4. op構建與注冊輔助結構
  5. op重寫
  6. 關系圖
  7. 涉及的文件
  8. 迭代記錄

1. 什么是op

op和kernel是TF框架中最重要的兩個概念,如果一定要做一個類比的話,可以認為op相當於函數聲明,kernel相當於函數實現。舉個例子,對於矩陣相乘,我可以聲明一個op叫做MatMul,指明它的名稱,輸入,輸出,參數,以及對參數的限制等。op只是告訴我們,這個操作的目的是什么,操作內部有哪些可定制的東西,但不會提供具體實現。操作在某種設備上的具體實現方法,是由kernel決定的。TF的計算圖由節點構成,而每個節點對應了一個op,在構建計算圖時,我們只知道不同節點對應的操作是什么,而不知道運行時這個操作是怎樣實現的。也就是說,op是編譯期概念,而kernel是運行期概念

那為什么要把操作和它的實現分離呢?是為了實現TF代碼的可移植性。我們可以把TF構建的計算圖想象為Java的字節碼,而計算圖在執行的時候,需要考慮可用的設備資源,相當於我們在運行Java字節碼的時候,需要考慮當前所在的操作系統,選擇合適的字節碼實現。因為TF的目標是在多設備上運行,但我們在編碼的時候,是無法預先知道某一個操作具體是在哪種設備上運行的,因此,將操作和它的實現分離,可以讓我們在設計計算圖的時候,更專注於它的結構,而不是具體實現。當我們構建完成一個計算圖之后,在一個包含GPU的設備上,它可以利用對應操作在GPU上的kernel,充分利用GPU的高計算性能,在一個僅包含CPU的設備上,它也可以利用對應操作在CPU上的kenrel,完成計算功能。這就提高了TF代碼在不同設備之間的可移植性。

2. op_def定義

由於僅是操作的聲明,OpDef不需要包含太多的API,它被定義在一個proto中。由於這個概念極端重要,我們在這里完整列出它的代碼:

message OpDef {
    string name = 1;//操作的名稱
    message ArgDef { //對輸入輸出的定義
        string name = 1;
        string description = 2;
        DataType type = 3;//以下4個字段說明了數據的類型,詳見正文
        string type_attr = 4;
        string number_attr = 5;
        string type_list_attr = 6;
        bool is_ref = 16;//輸入或輸出是否為引用
    };
    repeated ArgDef input_arg = 2;//輸入描述
    repeated ArgDef output_arg = 3;//輸出描述
    message AttrDef {
        string name = 1;
        string type = 2;
        AttrValue default_value = 3;
        string description = 4;
        bool has_minimum = 5;
        int64 minumum = 6;
        AttrValue allowed_values = 7;
    }
    repeated AttrDef attr = 4;
    OpDeprecation deprecation = 8;
    string summary = 5;
    string description = 6;
    bool is_commutative = 18;//是否可交換,即op(a,b) == op(b,a)
    bool is_aggregate = 16;//是否可聚集
    bool is_stateful = 17;//是否帶有狀態
    bool allows_uninitialized_input = 19;//針對賦值操作
};
message OpDeprecation {
    int32 version = 1;
    string explanation = 2;
};
message OpList {
    repeated OpDef op = 1;
};

我們看到,OpDef中最核心的數據成員是操作名稱、輸入、輸出、參數。其中的參數怎樣理解呢?我們之前提到op相當於函數聲明,這個函數是帶參數的,具體使用該操作時,我們需要給參數賦予實際的數值,這個在接下來分析node_def時會詳細講到。

對於其中的幾個難理解的點,作出說明:

  • ArgDef中的3-6四個字段,是為了描述輸入或輸出的類型。當輸入或輸出是一個張量時,type或type_attr被設置為這個張量的數據類型,當輸入或輸出是一個由相同數據類型的張量構成的序列時,number_attr被設置為int對應的標識,當輸入或輸出是一個由張量構成的列表時,type_list_attr被設置為list(type)對應的標識;
  • AttrDef中的has_minimum字段,表明這個屬性是否有最小值,如果數據類型是int,那么minimum就是允許的最小值,如果數據類型是列表,那么minimum就是列表的最短長度;
  • is_aggregate這個字段,表明當前的操作是否是可聚集的,一個可聚集的操作是,能接受任意數量相同類型和形狀的輸入,並且保持輸出與每個輸入的類型和形狀相同,這個字段對於操作的優化非常重要,如果一個操作是可聚集的,並且其輸入來自多個不同的設備,那么我們就可以把聚集優化成一個樹形的操作,先在設備內部對輸入做聚集,最后在操作所在的設備集中,這樣可以提高效率。這種優化對於分布式的機器學習模型訓練非常有幫助,Spark ML中的TreeAggregate就實現了這樣的優化。可惜截止筆者看到的TF1.2版本,還沒有實現這個優化;
  • is_stateful這個字段,表明當前的操作是否是帶有狀態的,什么操作會帶有狀態呢?比如Variable;

為了方便進行OpDef的構建,TF還設計了OpDefBuilder類,它的私有數據成員如下:

class OpDefBuilder {
    //...
  private:
    OpRegistrationData op_reg_data_;
    std::vector<string> attrs_;
    std::vector<string> inputs_;
    std::vector<string> outputs_;
    string doc_;
    std::vector<string> errors_;
}

可以看到,除了errors_字段之外,其它內容幾乎就是把OpDef的結構原封不動的搬了過來。這里面我們發現了一個新的結構,OpRegistrationData,它的結構如下:

struct OpRegistrationData {
  public:
    //...
    OpDef op_def;
    OpShapeInferenceFn shape_inference_fn;
}

在這個結構中,除了我們熟知的OpDef之外,還包含了一個OpShapeInferenceFn結構,它的定義如下:

typedef std::function<Status(shape_inference::InferenceContext* c)> OpShapeInferenceFn;

這個結構的定義中,涉及到了我們后面要講到的形狀推斷的內容,這里我們只需要知道,OpShapeInferenceFn是一個幫助操作根據輸入形狀對輸出形狀進行推斷的函數即可。

3. op注冊

為了方便對操作進行統一管理,TF提出了操作注冊器的概念。對於核心數據的統一管理類型,我們並不陌生,回想之前介紹的ResourceMgr和AllocatorRegistry,原理如出一轍。因此,這個操作注冊器的作用,就是為各種操作提供一個統一的管理接口。

操作注冊類的繼承結構如下:

graph TB OpRegistryInterface-->|派生|OpRegistry OpRegistryInterface-->|派生|OpListOpRegistry

其中,OpRegistryInterface是一個接口類,它提供了注冊類最基礎的查找功能:

class OpRegistryInterface {
  public:
    //...
    //操作查找方法
    virtual Status LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const = 0;
    Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
}

OpRegistry就是操作注冊器,它的核心接口和數據如下:

class OpRegistry : public OpRegistryInterface {
  public:
    typedef std::function<Status(OpRegistrationData*)> OpRegistrationFactory;
    void Register(const OpRegistrationDataFactory& op_data_factory);//操作注冊
    static OpRegistry* Global();//返回一個全局靜態對象
    typedef std::function<Status<const Status&, const OpDef&)> Watcher;
    Status SetWatcher(const Watcher& watcher);
  private:
    mutable mutex mu_;
    mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
    mutable std::unordered_map<string, const OpRegistrationData*> registry_ GUARDED_BY(mu_);
    mutable bool initialized_ GUARDED_BY(mu_);
    mutable Watcher watcher_ GUARDED_BY(mu_);
}

這里面有幾個有意思的地方:

  • 注冊函數Register的輸入,是一個函數引用,這個函數接收一個OpRegistrationData指針作為輸入,那么這個函數引用的作用究竟是什么呢?它的源代碼如下,原來,我們先建立了一個OpRegistrationData的對象,然后將它作為參數傳入op_data_factory函數,這個函數會幫我們填充對象的內容,然后再用這個對象的信息進行注冊;
Status OpRegistry::RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) const {
    std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
    Status s = op_data_factory(op_reg_data.get())
    //...
}
  • Watcher是一個監視器,每次當我們注冊了一個操作的時候,在注冊步驟的最后都要調用一下這個Watcher函數,它可以方便我們對注冊的操作進行監控,所有的操作注冊動作都逃不過它的眼睛,我們可以根據自己的需要定制Watcher;
  • registry_是已注冊的操作真正存放的位置,它的結構很簡單,是一個操作名到操作數據的映射;
  • initialized_和deferred_是與注冊模式相關的兩個數據成員,注冊模式的概念接下來將會詳細闡述;

注冊器在注冊操作時,分為兩種模式,一種是即時注冊模式,一種是懶惰注冊模式。注冊模式通過initialized_字段區分,true表示即時注冊模式,false表示懶惰注冊模式。在懶惰注冊模式中,帶注冊的操作先被保存在deferred_向量中,在特定的函數調用時再將deferred_中的操作注冊到registry_,而即時注冊模式下,待注冊的操作不用經過deferred_,直接注冊到registry_。設計懶惰注冊模式的原因是,我們希望部分操作組合的注冊是原子的,即要么全部注冊,要么全部不注冊,因為這些操作之間可能會有相互依賴關系。

為了更加透徹的理解注冊模式的轉換,我們繪制了OpRegistry類中,與注冊相關的函數的調用關系,以及對initialized_的修改如下:

graph TB LookUpDef-->LookUp Register-->RegisterAlreadyLocked LookUp-->MustCallDeferred GetRegisteredOps-->MustCallDeferred Export-->MustCallDeferred ProcessRegistrations-->CallDeferred DebugString-->Export MustCallDeferred-->RegisterAlreadyLocked CallDeferred-->RegisterAlreadyLocked DeferRegistrations-.設置為false.->initialized_ MustCallDeferred-.設置為true.->initialized_ CallDeferred-.設置為true.->initialized_ OpRegistry-.設置為false.->initialized_

構造函數將initialized_設置為false,進入懶惰注冊模式,隨后一旦調用了MustCallDeferred或者CallDeferred中的任意一個,都會將initialized_設置為true,進入即時注冊模式。想要重新返回懶惰注冊模式也很簡單,只需要調用DeferRegistrations即可。

最后簡單介紹一下OpListRegistry,它允許我們用OpList初始化一個注冊器,請注意,OpList僅僅是OpDef的列表,它並不包含形狀推斷函數這個信息,因此這個注冊器中的操作,是不包含形狀推斷函數的。如果我們要查找的操作不需要形狀推斷函數,就可以使用這個注冊器。它的私有數據如下:

class OpListOpRegistry : public OpRegistryInterface {
  public:
    //...
  private:
    std::unordered_map<string, const OpRegistrationData*> index_;
}

4. op構建與注冊輔助結構

為了方便對操作的注冊,TF提出了專為注冊操作的宏,舉例如下:

REGISTER_OP("my_op_name")
    .Attr("<name>:<type>")
    .Attr("<name>:<type>=<default>")
    .Input("<name>:<type-expr>")
    .Output("<name>:<type-expr>")
    .Doc(R"(
        <1-line summary>
        <rest of the description (potensitally many lines)>
...
)");

這種寫法大大方便了注冊操作的過程。但想要實現這種宏操作,目前的類還滿足不了。TF設計了兩個類來實現這個功能,一個類為op的構建提供鏈式語法支持,另外一個類接受op構建結果,提供操作注冊功能。這兩個類分別是OpDefBuilderWrapper和OpDefBuilderReceiver。我們先來看前者:

class OpDefBuilderWrapper<true> {
  public:
    OpDefBuilderWrapper(const char name[]) : builder_(name){}
    OpDefBuilderWrapper<true>& Attr(StringPiece spec){
        builder_.Attr(spec);
        return *this;
    }
    //...
  private:
    mutable ::tensorflow::OpDefBuilder builder_;
}

有兩點比較有意思,首先顧名思義這個類基本上是對OpDefBuilder的一個封裝,提供了幾乎完全一致的API;其次,它的API都是設置型,且都返回對象本身,這就為鏈式的屬性設置提供了可能。值得注意的是,這個類名后面跟着一個true,它的含義我們待會兒揭曉。
再來看看OpDefBuilderReceiver:

struct OpDefBuilderReceiver {
    OpDefBuilderReceiver(const OpDefBuilderWrapper<true>& wrapper);
    constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&){}
};

它提供的構造函數,以OpDefBuilderWrapper作為輸入參數,也就是說,我們可以通過賦值構造把后者直接賦值給前者,看下REGISTER_OP的宏定義:

//為了忽略不必要的細節,以下代碼做了適當刪減
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
#define REGISTER_OP_UNIQ(ctr, name) \
    static OpDefBuilderReceiver register_op##ctr = OpDefBuilderWrapper<SHOULD_REGISTOR_OP(name)>(name)

我們發現,REGISTER_OP繞了一圈,最終就是先用OpDefBuilderWrapper對操作進行封裝,然后把它作為參數傳遞給OpDefBuilderReceiver的構造函數,而在這個構造函數中,完成了對操作的注冊:

OpDefBuilderReceiver::OpDefBuilderReceiver(const OpDefBuilderWrapper<true>& wrapper) {
    OpRegistry::Global()->Register([wrapper](OpRegistrationData* op_reg_data) -> Status {
        return wrapper.builder().Finalize(op_reg_data);
        });
    }
}

最后我們來解釋下剛才賣的關子,OpDefBuilderWrapper<true>后面的這個true到底代表什么。我們知道,TF為我們准備了很多的操作,但有些時候我們可能用不着這所有的操作,僅需要其中一部分。如果不加限制全部編譯,會給我們的運行時庫帶來很大的負擔。因此,TF允許我們添加一個頭文件,用宏SHOULD_REGISTOR_OP定義我們想要導出的操作,比如:

#define SHOULD_REGISTOR_OP(Add) true
#define SHOULD_REGISTOR_OP(Subtract) false

表示我們希望導出Add操作,但希望屏蔽Subtract操作。這樣就能夠根據需要定制自己的TF運行時庫了。因此源代碼中除了這個OpDefBuilderWrapper<true>類之外,還有一個OpDefBuilderWrapper<false>類,最后,有些操作系統是必須要導出的,比如一些內部操作,TF為此設計了另外一個宏,可以無視SHOULD_REGISTOR_OP的宏定義,感興趣的讀者可以去看下源代碼。

5. op重寫

隨着TF的不斷拓展,操作本身也在不斷的迭代,比如重命名。為了與已有的圖實現向前兼容,TF提出了OpGenOverrides的結構,如下:

message OpGenOverride {
    string name = 1;
    bool skip = 2;//直接廢棄這個操作
    bool hide = 3;//對外隱藏
    string rename_to = 4;
    repeated string alias = 5;//更新API的名稱
    message AttrDefault {
        string name = 1;
        AttrValue value = 2;
    }
    repeated AttrDefault attr_default = 6;//修改參數默認值
    message Rename {
        string from = 1;
        string to = 2;
    }
    repeated Rename attr_rename = 7;
    repeated Rename input_rename = 8;
    repeated Rename output_rename = 9;
}
message OpGenOverrides {
    repeated OpGenOverride op = 1;
}

具體的替換操作是由OpGenOverrideMap這個類實現的,它讀入一系列包含OpGenOverrides proto的文本文件,然后允許你查找針對每個已有操作的迭代:

class OpGenOverrideMap {
  public:
    Status LoadFile(Env* env, const string& filenames);
    const OpGenOverride* ApplyOverride(OpDef* op_def) const;
  private:
    std::unordered_map<string, std::unique_ptr<OpGenOverride>> map_;
};

6. 關系圖

graph TB OpDefBuilder-.包含.->OpRegistrationData OpRegistrationData-.包含.->OpDef OpDefBuilder-.構建.->OpDef OpRegistryInterface-->|派生|OpRegistry OpRegistryInterface-->|派生|OpListOpRegistry OpDefBuilder-.包裹.->OpDefBuilderWrapper OpDefBuilderWrapper-.傳遞給.->OpDefBuilderReceiver OpDefBuilderReceiver-.注冊.->OpRegistry

7. 涉及的文件

  • op
  • op_def_builder
  • op_def
  • op_gen_lib
  • op_gen_overrides

8. 迭代記錄

  • v1.0 2018-08-26 文檔創建
  • v2.0 2018-09-09 文檔重構

github地址


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM