目錄
- 什么是op
- op_def定義
- op注冊
- op構建與注冊輔助結構
- op重寫
- 關系圖
- 涉及的文件
- 迭代記錄
1. 什么是op
op和kernel是TF框架中最重要的兩個概念,如果一定要做一個類比的話,可以認為op相當於函數聲明,kernel相當於函數實現。舉個例子,對於矩陣相乘,我可以聲明一個op叫做MatMul,指明它的名稱,輸入,輸出,參數,以及對參數的限制等。op只是告訴我們,這個操作的目的是什么,操作內部有哪些可定制的東西,但不會提供具體實現。操作在某種設備上的具體實現方法,是由kernel決定的。TF的計算圖由節點構成,而每個節點對應了一個op,在構建計算圖時,我們只知道不同節點對應的操作是什么,而不知道運行時這個操作是怎樣實現的。也就是說,op是編譯期概念,而kernel是運行期概念。
那為什么要把操作和它的實現分離呢?是為了實現TF代碼的可移植性。我們可以把TF構建的計算圖想象為Java的字節碼,而計算圖在執行的時候,需要考慮可用的設備資源,相當於我們在運行Java字節碼的時候,需要考慮當前所在的操作系統,選擇合適的字節碼實現。因為TF的目標是在多設備上運行,但我們在編碼的時候,是無法預先知道某一個操作具體是在哪種設備上運行的,因此,將操作和它的實現分離,可以讓我們在設計計算圖的時候,更專注於它的結構,而不是具體實現。當我們構建完成一個計算圖之后,在一個包含GPU的設備上,它可以利用對應操作在GPU上的kernel,充分利用GPU的高計算性能,在一個僅包含CPU的設備上,它也可以利用對應操作在CPU上的kenrel,完成計算功能。這就提高了TF代碼在不同設備之間的可移植性。
2. op_def定義
由於僅是操作的聲明,OpDef不需要包含太多的API,它被定義在一個proto中。由於這個概念極端重要,我們在這里完整列出它的代碼:
message OpDef {
string name = 1;//操作的名稱
message ArgDef { //對輸入輸出的定義
string name = 1;
string description = 2;
DataType type = 3;//以下4個字段說明了數據的類型,詳見正文
string type_attr = 4;
string number_attr = 5;
string type_list_attr = 6;
bool is_ref = 16;//輸入或輸出是否為引用
};
repeated ArgDef input_arg = 2;//輸入描述
repeated ArgDef output_arg = 3;//輸出描述
message AttrDef {
string name = 1;
string type = 2;
AttrValue default_value = 3;
string description = 4;
bool has_minimum = 5;
int64 minumum = 6;
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;
OpDeprecation deprecation = 8;
string summary = 5;
string description = 6;
bool is_commutative = 18;//是否可交換,即op(a,b) == op(b,a)
bool is_aggregate = 16;//是否可聚集
bool is_stateful = 17;//是否帶有狀態
bool allows_uninitialized_input = 19;//針對賦值操作
};
message OpDeprecation {
int32 version = 1;
string explanation = 2;
};
message OpList {
repeated OpDef op = 1;
};
我們看到,OpDef中最核心的數據成員是操作名稱、輸入、輸出、參數。其中的參數怎樣理解呢?我們之前提到op相當於函數聲明,這個函數是帶參數的,具體使用該操作時,我們需要給參數賦予實際的數值,這個在接下來分析node_def時會詳細講到。
對於其中的幾個難理解的點,作出說明:
- ArgDef中的3-6四個字段,是為了描述輸入或輸出的類型。當輸入或輸出是一個張量時,type或type_attr被設置為這個張量的數據類型,當輸入或輸出是一個由相同數據類型的張量構成的序列時,number_attr被設置為int對應的標識,當輸入或輸出是一個由張量構成的列表時,type_list_attr被設置為list(type)對應的標識;
- AttrDef中的has_minimum字段,表明這個屬性是否有最小值,如果數據類型是int,那么minimum就是允許的最小值,如果數據類型是列表,那么minimum就是列表的最短長度;
- is_aggregate這個字段,表明當前的操作是否是可聚集的,一個可聚集的操作是,能接受任意數量相同類型和形狀的輸入,並且保持輸出與每個輸入的類型和形狀相同,這個字段對於操作的優化非常重要,如果一個操作是可聚集的,並且其輸入來自多個不同的設備,那么我們就可以把聚集優化成一個樹形的操作,先在設備內部對輸入做聚集,最后在操作所在的設備集中,這樣可以提高效率。這種優化對於分布式的機器學習模型訓練非常有幫助,Spark ML中的TreeAggregate就實現了這樣的優化。可惜截止筆者看到的TF1.2版本,還沒有實現這個優化;
- is_stateful這個字段,表明當前的操作是否是帶有狀態的,什么操作會帶有狀態呢?比如Variable;
為了方便進行OpDef的構建,TF還設計了OpDefBuilder類,它的私有數據成員如下:
class OpDefBuilder {
//...
private:
OpRegistrationData op_reg_data_;
std::vector<string> attrs_;
std::vector<string> inputs_;
std::vector<string> outputs_;
string doc_;
std::vector<string> errors_;
}
可以看到,除了errors_字段之外,其它內容幾乎就是把OpDef的結構原封不動的搬了過來。這里面我們發現了一個新的結構,OpRegistrationData,它的結構如下:
struct OpRegistrationData {
public:
//...
OpDef op_def;
OpShapeInferenceFn shape_inference_fn;
}
在這個結構中,除了我們熟知的OpDef之外,還包含了一個OpShapeInferenceFn結構,它的定義如下:
typedef std::function<Status(shape_inference::InferenceContext* c)> OpShapeInferenceFn;
這個結構的定義中,涉及到了我們后面要講到的形狀推斷的內容,這里我們只需要知道,OpShapeInferenceFn是一個幫助操作根據輸入形狀對輸出形狀進行推斷的函數即可。
3. op注冊
為了方便對操作進行統一管理,TF提出了操作注冊器的概念。對於核心數據的統一管理類型,我們並不陌生,回想之前介紹的ResourceMgr和AllocatorRegistry,原理如出一轍。因此,這個操作注冊器的作用,就是為各種操作提供一個統一的管理接口。
操作注冊類的繼承結構如下:
其中,OpRegistryInterface是一個接口類,它提供了注冊類最基礎的查找功能:
class OpRegistryInterface {
public:
//...
//操作查找方法
virtual Status LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const = 0;
Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
}
OpRegistry就是操作注冊器,它的核心接口和數據如下:
class OpRegistry : public OpRegistryInterface {
public:
typedef std::function<Status(OpRegistrationData*)> OpRegistrationFactory;
void Register(const OpRegistrationDataFactory& op_data_factory);//操作注冊
static OpRegistry* Global();//返回一個全局靜態對象
typedef std::function<Status<const Status&, const OpDef&)> Watcher;
Status SetWatcher(const Watcher& watcher);
private:
mutable mutex mu_;
mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
mutable std::unordered_map<string, const OpRegistrationData*> registry_ GUARDED_BY(mu_);
mutable bool initialized_ GUARDED_BY(mu_);
mutable Watcher watcher_ GUARDED_BY(mu_);
}
這里面有幾個有意思的地方:
- 注冊函數Register的輸入,是一個函數引用,這個函數接收一個OpRegistrationData指針作為輸入,那么這個函數引用的作用究竟是什么呢?它的源代碼如下,原來,我們先建立了一個OpRegistrationData的對象,然后將它作為參數傳入op_data_factory函數,這個函數會幫我們填充對象的內容,然后再用這個對象的信息進行注冊;
Status OpRegistry::RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) const {
std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
Status s = op_data_factory(op_reg_data.get())
//...
}
- Watcher是一個監視器,每次當我們注冊了一個操作的時候,在注冊步驟的最后都要調用一下這個Watcher函數,它可以方便我們對注冊的操作進行監控,所有的操作注冊動作都逃不過它的眼睛,我們可以根據自己的需要定制Watcher;
- registry_是已注冊的操作真正存放的位置,它的結構很簡單,是一個操作名到操作數據的映射;
- initialized_和deferred_是與注冊模式相關的兩個數據成員,注冊模式的概念接下來將會詳細闡述;
注冊器在注冊操作時,分為兩種模式,一種是即時注冊模式,一種是懶惰注冊模式。注冊模式通過initialized_字段區分,true表示即時注冊模式,false表示懶惰注冊模式。在懶惰注冊模式中,帶注冊的操作先被保存在deferred_向量中,在特定的函數調用時再將deferred_中的操作注冊到registry_,而即時注冊模式下,待注冊的操作不用經過deferred_,直接注冊到registry_。設計懶惰注冊模式的原因是,我們希望部分操作組合的注冊是原子的,即要么全部注冊,要么全部不注冊,因為這些操作之間可能會有相互依賴關系。
為了更加透徹的理解注冊模式的轉換,我們繪制了OpRegistry類中,與注冊相關的函數的調用關系,以及對initialized_的修改如下:
構造函數將initialized_設置為false,進入懶惰注冊模式,隨后一旦調用了MustCallDeferred或者CallDeferred中的任意一個,都會將initialized_設置為true,進入即時注冊模式。想要重新返回懶惰注冊模式也很簡單,只需要調用DeferRegistrations即可。
最后簡單介紹一下OpListRegistry,它允許我們用OpList初始化一個注冊器,請注意,OpList僅僅是OpDef的列表,它並不包含形狀推斷函數這個信息,因此這個注冊器中的操作,是不包含形狀推斷函數的。如果我們要查找的操作不需要形狀推斷函數,就可以使用這個注冊器。它的私有數據如下:
class OpListOpRegistry : public OpRegistryInterface {
public:
//...
private:
std::unordered_map<string, const OpRegistrationData*> index_;
}
4. op構建與注冊輔助結構
為了方便對操作的注冊,TF提出了專為注冊操作的宏,舉例如下:
REGISTER_OP("my_op_name")
.Attr("<name>:<type>")
.Attr("<name>:<type>=<default>")
.Input("<name>:<type-expr>")
.Output("<name>:<type-expr>")
.Doc(R"(
<1-line summary>
<rest of the description (potensitally many lines)>
...
)");
這種寫法大大方便了注冊操作的過程。但想要實現這種宏操作,目前的類還滿足不了。TF設計了兩個類來實現這個功能,一個類為op的構建提供鏈式語法支持,另外一個類接受op構建結果,提供操作注冊功能。這兩個類分別是OpDefBuilderWrapper和OpDefBuilderReceiver。我們先來看前者:
class OpDefBuilderWrapper<true> {
public:
OpDefBuilderWrapper(const char name[]) : builder_(name){}
OpDefBuilderWrapper<true>& Attr(StringPiece spec){
builder_.Attr(spec);
return *this;
}
//...
private:
mutable ::tensorflow::OpDefBuilder builder_;
}
有兩點比較有意思,首先顧名思義這個類基本上是對OpDefBuilder的一個封裝,提供了幾乎完全一致的API;其次,它的API都是設置型,且都返回對象本身,這就為鏈式的屬性設置提供了可能。值得注意的是,這個類名后面跟着一個true,它的含義我們待會兒揭曉。
再來看看OpDefBuilderReceiver:
struct OpDefBuilderReceiver {
OpDefBuilderReceiver(const OpDefBuilderWrapper<true>& wrapper);
constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&){}
};
它提供的構造函數,以OpDefBuilderWrapper作為輸入參數,也就是說,我們可以通過賦值構造把后者直接賦值給前者,看下REGISTER_OP的宏定義:
//為了忽略不必要的細節,以下代碼做了適當刪減
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
#define REGISTER_OP_UNIQ(ctr, name) \
static OpDefBuilderReceiver register_op##ctr = OpDefBuilderWrapper<SHOULD_REGISTOR_OP(name)>(name)
我們發現,REGISTER_OP繞了一圈,最終就是先用OpDefBuilderWrapper對操作進行封裝,然后把它作為參數傳遞給OpDefBuilderReceiver的構造函數,而在這個構造函數中,完成了對操作的注冊:
OpDefBuilderReceiver::OpDefBuilderReceiver(const OpDefBuilderWrapper<true>& wrapper) {
OpRegistry::Global()->Register([wrapper](OpRegistrationData* op_reg_data) -> Status {
return wrapper.builder().Finalize(op_reg_data);
});
}
}
最后我們來解釋下剛才賣的關子,OpDefBuilderWrapper<true>
后面的這個true到底代表什么。我們知道,TF為我們准備了很多的操作,但有些時候我們可能用不着這所有的操作,僅需要其中一部分。如果不加限制全部編譯,會給我們的運行時庫帶來很大的負擔。因此,TF允許我們添加一個頭文件,用宏SHOULD_REGISTOR_OP定義我們想要導出的操作,比如:
#define SHOULD_REGISTOR_OP(Add) true
#define SHOULD_REGISTOR_OP(Subtract) false
表示我們希望導出Add操作,但希望屏蔽Subtract操作。這樣就能夠根據需要定制自己的TF運行時庫了。因此源代碼中除了這個OpDefBuilderWrapper<true>
類之外,還有一個OpDefBuilderWrapper<false>
類,最后,有些操作系統是必須要導出的,比如一些內部操作,TF為此設計了另外一個宏,可以無視SHOULD_REGISTOR_OP的宏定義,感興趣的讀者可以去看下源代碼。
5. op重寫
隨着TF的不斷拓展,操作本身也在不斷的迭代,比如重命名。為了與已有的圖實現向前兼容,TF提出了OpGenOverrides的結構,如下:
message OpGenOverride {
string name = 1;
bool skip = 2;//直接廢棄這個操作
bool hide = 3;//對外隱藏
string rename_to = 4;
repeated string alias = 5;//更新API的名稱
message AttrDefault {
string name = 1;
AttrValue value = 2;
}
repeated AttrDefault attr_default = 6;//修改參數默認值
message Rename {
string from = 1;
string to = 2;
}
repeated Rename attr_rename = 7;
repeated Rename input_rename = 8;
repeated Rename output_rename = 9;
}
message OpGenOverrides {
repeated OpGenOverride op = 1;
}
具體的替換操作是由OpGenOverrideMap這個類實現的,它讀入一系列包含OpGenOverrides proto的文本文件,然后允許你查找針對每個已有操作的迭代:
class OpGenOverrideMap {
public:
Status LoadFile(Env* env, const string& filenames);
const OpGenOverride* ApplyOverride(OpDef* op_def) const;
private:
std::unordered_map<string, std::unique_ptr<OpGenOverride>> map_;
};
6. 關系圖
7. 涉及的文件
- op
- op_def_builder
- op_def
- op_gen_lib
- op_gen_overrides
8. 迭代記錄
- v1.0 2018-08-26 文檔創建
- v2.0 2018-09-09 文檔重構