寫在前面
在對Tensorflow的后端源碼進行了拆解(參見tensorflow源碼解析系列文章索引)之后,很想跟其它深度學習框架的實現進行對比,根據框架的流行程度,先選擇了Pytorch。Pytorch的后端核心是直接復用了Caffe2,因此本文針對Caffe2源碼的core模塊進行了簡單拆解。
目錄
- 數據存儲與表示
- storage
- tensor
- blob
- qtensor
- 操作
- observer observable
- operator
- 操作求導
- operator_schema
- context
- 計算圖
- graph
- net
- transform
- 運行時
- allocator
- db
- registry
- module
- scope_guard
- workspace
- init
1. 數據存儲與表示
1.1 storage
Caffe2中對數據存儲的最底層的描述是Storage,它實際上是指向StorageImpl的共享指針,后者包含數據類型、數據指針、容量、數據所在設備等信息。Storage的定義如下:
using Storage = std::shared_ptr<StorageImpl>;
class StorageImpl {
public:
//...
protected:
using DataPtr = std::shared_ptr<void>;
int64_t capacity_ = 0;
DataType data_type_;
DataPtr data_ptr_;
DeviceType device_type_ = CPU;
};
1.2 tensor
Caffe2中的數據統一使用Tensor表示,Tensor由TensorImpl實現,后者包含一個Storage。
TensorImpl的定義如下:
class TensorImpl {
public:
//...
protected:
using DimVector = std::vector<TIndex>;
DimVector dims_; //張量的維度
TIndex size_ = -1; //張量中包含的元素數量
Storage storage_; //底層存儲
};
Tensor並非繼承自TensorImpl,而是在內部包含了一個指向TensorImpl的指針,如下:
class Tensor final {
protected:
using TensorImplPtr = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
TensorImplPtr impl_;
//...
};
對Tensor的方法調用,通過重定向給TensorImpl實現。
1.3 blob
Blob是一個容器,包含了一個指針和這個指針指向內存的數據類型,在Caffe2中,大部分情況下Blob都包含一個指向Tensor的指針。
class Blob {
public:
//...
private:
TypeMeta meta_;
void* pointer_ = nullptr;
DestroyCall destroy_ = nullptr;
};
為了方便對Blob進行傳輸,定義了其序列化和反序列化的類,分別是BlobSerializerBase和BlobDeserializerBase,以及對應的為Tensor准備的序列化和反序列化類。
1.4 qtensor
低精度的張量,為了便於快速進行低精度的整數乘法計算。具體的做法是,用更低的位數來表示整數,比如,用3個bit表示無符號整數,用4個bit表示有符號整數。低精度張量可以在略微損失模型精度的情況下,大大降低計算復雜度和存儲空間大小。
操作
2.1 Observer Observable
Caffe2使用ObserverBase和Observable兩個類實現了觀察者模式。ObserverBase是基礎觀察器,用戶可以通過繼承此類創建新的觀察器,而Observable是可被觀察屬性,用戶可以通過繼承此類獲得可觀察屬性。
ObserverBase提供了觀察器的統一接口,比較簡單:
class ObserverBase {
public:
virtual void Start() {}
virtual void Stop() {}
T* subject() const {
return subject_;
}
protected:
T* subject_;
};
其中,subject_表示被觀察對象的指針。
Observable封裝了可被觀察屬性,內部包含了一個觀察器的列表,結構如下:
class Observable {
public:
using Observer = ObserverBase<T>;
const Observer* AttachObserver(std::unique_ptr<Observer> observer){} //添加觀察器
std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr){} //解除觀察器
virtual size_t NumObservers() {
return num_observers_;
} //觀察器的數量
void StartAllObservers(){} //啟動所有觀察器
void StopAllObservers(){} //關閉所有觀察器
private:
Observer* observer_cache_;
size_t num_observers_ = 0;
protected:
std::vector<std::unique_ptr<Observer>> observer_list_; //觀察器列表
};
2.2 Operator
Operator代表操作的具體實現,相當於Tensorflow中的kernel。Operator繼承自OperatorBase,而后者繼承自Observable,所以在Caffe2中,“操作”本質上是一個可觀察的對象。
OperatorBase類包含了操作需要的基本數據元素和接口:
class OperatorBase {
private:
Workspace* operator_ws_;
std::shared_ptr<const OperatorDef> operator_def_;
DeviceOption device_option_;
std::string engine_;
std::string type_;
vector<const Blob*> inputs_;
vector<Blob*> outputs_;
};
OperatorBase中包含了輸入和輸出的內存指針,可見,在Caffe2中,Operator本質上是一個運行時的對象,這與Tensorflow中Op的設計理念不同,在Tensorflow中,Op是一個編譯時對象,僅規定了操作的類型和目標,並不包含具體數據,具體的計算實際上是通過Kernel完成的。
Operator繼承自OperatorBase類:
class Operator : public OperatorBase {
public:
bool Run(int stream_id = 0) final {...}
bool RunAsync(int stream_id = 0) final {...}
virtual bool RunOnDevice() = 0;
};
實際上,Run和RunAsync最終都調用了RunOnDevice,完成實際的計算。
如果我們需要使用一些c10中定義的操作,需要將其轉換為在Caffe2中可以調用的操作,可以通過如下的宏進行轉換:
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(C10Add, C2MyAddOpName)
上述例子中,我們把一個C10Add操作,包裝成C2MyAddOpName操作,供我們使用。為了實現這個功能,Caffe2還提供了一個包裝類,C10OperatorWrapper。
2.3 操作求導
為了對操作求導,Caffe2推出了一個導數操作生成類,GradientMakerBase,方便用戶定義對於某個操作的導數。類包含的數據成員如下:
//為密集和稀疏的blob提供統一的接口
struct GradientWrapper {
string dense_;
string indices_;
string values_;
inline bool IsDense(){}
inline bool IsSparse(){}
inline bool IsEmpty(){}
};
class GradientMakerBase {
protected:
const OperatorDef& def_;
const vector<GradientWrapper>& g_output_;
vector<GradientWrapper> g_input_;
};
可見,GradientMakerBase僅提供了輸入輸出,以及原操作。用戶可以根據原操作,定制導數。
2.3 operator_schema
OpSchema是對操作的靜態描述,相當於Tensorflow中的Op,包含的信息如下:
class OpSchema {
private:
string type_;
string file_;
string doc_;
string onnx_schema_;
std::vector<Argument> args_{};
std::vector<std::pair<const char*, const char*>> input_desc_{};
std::vector<std::pair<const char*, const char*>> output_desc_{};
int line_ = 0;
int min_input_ = 0;
int max_input_ = std::numeric_limits<int>::max();
int min_output_ = 0;
int max_output_ = std::numeric_limits<int>::max();
bool private_ = false;
bool inputs_can_cross_devices_ = false;
std::function<bool(int)> num_inputs_allowed = [](int) { return true; }
std::function<bool(int)> num_outputs_allowed = [](int) { return true; }
std::function<bool(int,int)> num_inputs_outputs_allowed_ = [](int,int) { return true; }
std::function<int(int)> calculate_output_;
std::function<bool(int,int)> inplace_allowed_ = [](int,int){}
std::function<bool(int,int)> inplace_enforced_ = [](int,int){}
TensorInferenceFunctionType tensor_inference_function_ = {...}
std::unique_ptr<CostInferenceFunctionType> cost_inference_function_ = nullptr;
DeviceInferenceFunctionType device_inference_function_ = {...}
};
另外Caffe2也提供了一個對於OpSchema的注冊類OpSchemaRegistry,如下:
class OpSchemaRegistry {
private:
static CaffeMap<string, OpSchema>& map();
};
2.4 context
Caffe2中的context,其實就是Tensorflow中的OpKernelContext,為操作的實際計算提供通用的支持,主要包含內存拷貝的接口。所有實際的Context類必須繼承自BaseContext,而Caffe2為我們准備了一個標准的Context接口,CPUContext類。另外,也同樣為GPU准備了一個CUDAContext類。
3. 計算圖
3.1 graph
Graph表示圖的結構,圖包含節點,節點包含操作。
Node包含的數據成員:
class Node {
public:
OperatorDef op;
bool active = true; //操作是否被transformation刪除
std::map<int, std::vector<string>> parents;
std::vector<int, std::vector<string>> children;
}
Graph包含的私有數據成員:
class Graph {
private:
NetDef netdef_;
std::set<string> external_input_;
std::set<string> external_output_;
std::vector<Node> nodes_;
}
3.2 net
Net是一個可運行的Graph,包含了一個圖的所有“操作”,以及它們的上下文。它繼承自Observable,本質上是一個可觀察的對象。數據成員如下:
class NetBase : public Observable<NetBase>{
public:
virtual bool Run(){...}
virtual bool RunAsync();
protected:
vector<string> external_input_;
vector<string> external_output_;
string name_;
vector<const Event*> events_;
std::shared_ptr<const NetDef> net_def_;
};
NetBase派生出了三種子類,第一種是AsyncNetBase,它包含了異步執行網絡所必須的數據和接口:
class AsyncNetBase : public NetBase {
public:
bool RunAsync() override;
protected:
bool canSchedule(...);
std::vector<OperatorBase*> operators_;
std::vector<dag_utils::OperatorNode> operator_nodes_;
std::vector<std::vector<int>> chains_;
std::vector<dag_utils::OpGraphNode> chain_nodes_;
dag_utils::ExecutionChains execution_chains_;
};
第二種是SimpleNet,它表示了一種對圖的單線程的順序執行模式。
第三種是DAGNetBase,它表示了一種對圖的多線程的dag執行模式。
相關的net類形成了一個繼承體系:
3.3 transform
transform是一種針對Caffe2的NetDef結構的操作,它將NetDef作為輸入,輸出新的經過變換的NetDef。它的工作步驟包括:
- 從舊的NetDef中構建一張圖,這張圖中保存了節點的連接信息;
- 在圖中匹配指定的模式,找到它想要更改的子圖;
- 用新的操作替換匹配到的子圖;
- 根據圖構建一個新的NetDef並返回;
Transform功能的實現,依賴於三個功能函數,如下:
- PatternRule(模式規則),它決定了對於一張子圖和一個節點,是否可以將這個節點加入這個子圖中;
- ValidatorRule(驗證規則),它決定了一張子圖是否是匹配的;
- ReplaceRule(替換規則),它對一張匹配的子圖進行替換;
常用的模式如下:
- CONNECTED_SUBGRAPH,連接子圖,它只能匹配連接的子圖。比如對於圖(1)-->(2)-->(3)-->(4),它能夠匹配到[2,3]和[4,3],但不能匹配到[2,4];
- SORTED_WRT_EXECUTION_ORDER,執行序模式,它只能匹配符合執行順序的子圖,節點之間不一定需要有連接,它比General模式要快,例如對於圖(1)-->(2)-->(3)-->(4),它可以匹配到[2,4],[3,4],但不能匹配到[3,1],[4,3];
- GENERAL,它可以匹配到任何子圖,比如,對於圖(1)-->(2)-->(3)-->(4)來說,它可以匹配到子圖[2,4],[3,4],[4,2,1]等;
4. 運行時
4.1 allocator
內存分配器。
4.2 db
DB類是對kv存儲的抽象。包含了用於讀取DB數據的Cursor類,用於寫DB數據的Transaction類,DB讀取的包裹類DBReader,對DBReader進行序列化和反序列化的DBReaderSerializer和DBReaderDeserializer類。
4.3 registry
注冊類,key為字符串,value可以為任意的類。結構如下:
class Registry {
private:
CaffeMap<SrcType, Creator> registry_;
CaffeMap<SrcType, string> help_message_;
};
4.4 module
查看Caffe2已載入的模塊,以及載入指定模塊。模塊指的是動態鏈接庫。
4.5 scope_guard
是“初始化即資源獲取”原語的實現,它保證了,如果不顯式說明,函數的執行就會離開當前的scope。
4.6 workspace
Workspace包含了所有的運行時對象,包括blob和net,它是所有這些對象的擁有者,負責對這些對象進行管理。
class Workspace {
private:
typedef CaffeMap<string, unique_ptr<Blob>> BlobMap;
BlobMap blob_map_;
typedef CaffeMap<string, unique_ptr<NetBase>> NetMap;
NetMap net_map_;
const string root_folder_;
const Workspace* shared_;
std::unordered_map<string, std::pair<const Workspace*, string>> forwarded_blobs_;
std::unique_ptr<ThreadPool> thread_pool_;
std::mutex thread_pool_creation_mutex_;
std::shared_ptr<Bookkeeper> bookkeeper_;
};
4.7 init
初始化整個Caffe2的運行環境,運行機制是,把需要在環境初始化中運行的函數注冊到注冊器中,初始化時,會在不同時期運行不同注冊器中的函數。核心的函數如下:
CAFFE2_API bool GlobalInit(int* pargc, char*** argv);
整個初始化過程分為三步:
- 先運行通過REGISTER_CAFFE2_EARLY_INIT_FUNCTION注冊的函數;
- 再解析Caffe的命令行參數,並啟動日志記錄系統;
- 最后運行通過REGISTER_CAFFE2_INIT_FUNCTION注冊的函數;
寫在后面
我在github上新建了一個repo,pytorch_notes,歡迎大家點星星。