目錄
- 什么是形狀推斷
- InferenceContext
- 關系圖
- 涉及的文件
- 迭代記錄
1. 什么是形狀推斷
前面我們講到op的時候,提到了操作的注冊器OpRegistry,並且提到,其中注冊的數據是一個結構OpRegistrationData,這個結構中除了OpDef之外,還包含了一個OpShapeInferenceFn,這個數據是做什么用的呢?
我們知道,op只是定義了操作的輸入輸出和參數,但並沒有定義操作具體的輸入形狀,舉個例子,MatMul操作,代表矩陣乘法,這只是一個抽象的表示,沒有具體說,這個矩陣乘法代表的是[2,3]x[3,4]=[2,4],還是[100,200]x[200,300]=[100,300]。所以在實際應用中,在得到輸入之前,輸出的真實形狀是無法預知的,但在得到輸入之后,我們必須能夠根據輸入的形狀,以及當前op的作用,判斷輸出的具體形狀,才能給它申請對應大小的內存空間。所以,我們需要為每一個操作,配備一個形狀推斷的函數,這就是形狀推斷的由來。
2. InferenceContext
前面提到了OpShapeInferenceFn,我們來看一下它的詳細定義:
typedef std::function<Status(shape_inference::InferenceContext* c)> OpShapeInferenceFn;
可見,OpShapeInferenceFn是一個接收InferenceContext參數的函數,TF為所有op的形狀推斷函數,准備了這樣一個統一的接口。所有跟形狀推斷相關的數據和功能函數,都放在InferenceContext這個類的內部。回想一下前面講過的OpKernelContext,其實它們的功能很像。OpKernelContext是作為OpKernel的核心API Compute函數的參數,所有計算相關的參數都會包含在這個對象中。InferenceContext也是一樣,我們把所有跟形狀推斷相關的數據和功能函數封裝在一個InferenceContext對象中,然后把這個對象傳遞給OpShapeInferenceFn,就可以實現形狀推斷。這種設計實現了數據部分和實現邏輯的解耦。
在具體看ShapeInference類之前,我們先要看一些輔助類:
class Dimension {
private:
//...
const int64 value_;
};
class DimensionHandle {
private:
//...
const Dimension* ptr_ = nullptr;
};
class Shape {
//...
private:
const int32 rank_;
const std::vector<DimensionHandle> dims_;
};
class ShapeHandle {
//...
private:
const Shape* ptr = nullptr;
};
class DimensionOrConstant {
public:
//...
DimensionHandle dim;
int64 val;
};
class ShapeAndType {
ShapeHandle shape;
DataType dtype = DT_INVALID;
};
這幾個類都比較簡單。在下面用到時能夠認得就好了。
下面我們看下InferenceContext這個類:
class InferenceContext {
public:
InferenceContext(int graph_def_version, const NodeDef* node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<ShapeHandle>& input_tensors_as_shapes, std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types);//構造函數
Status Run(const std::function<Status(shape_inference::InferenceContext* c)>& fn);//運行一個以this為參數的函數,沒錯,這里運行的就是OpShapeInferenceFn
bool MergeInput(int idx, ShapeHandle shape);
bool RelaxInput(int idx, ShapeHandle shape);
private:
ShapeManager shape_manager_;
std::vector<ShapeHandle> inputs_;
std::vector<const Tensor*> input_tensors_;
std::vector<bool> requested_input_tensor_;
std::vector<ShapeHandle> outputs_;
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types_;
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> output_handle_shapes_and_types_;
const int graph_def_version_;
const NodeDef& node_def_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
Status construction_status_;
};
前面已經介紹過了這個類的作用,是作為真正的形狀推斷函數的參數,為形狀推斷提供足夠的數據和功能函數支持,那么這個類的成員就比較清晰了,私有數據成員為形狀推斷提供數據支持,而公有API,為形狀推斷提供公用的功能函數,比如上面提到的MergeInput和RelaxOutput,下面我們重點介紹下這兩個函數的功能:
MergeInput函數是將輸入索引idx處的輸入與shape合並,具體的合並規則是:
- 如果ShapeHandles是一樣的,或者shape是未知的,那么輸入維度不變。否則,如果輸入維度是未知的,那么輸出是shape;
- 如果兩個形狀都是已知的,它們必須擁有相同的rank;
- 對於任意一個維度,如果在兩個形狀中這個維度都已知,那么它們必須相等;
- 如果一個形狀在任意維度上的信息都多於另一個形狀,那么擁有更多信息的形狀將被返回。否則,一個新的形狀將被構建並返回,這個新的形狀綜合了輸入的兩個形狀的信息;
- 比如,合並[2,?]和[?,2]將得到[2,2];
- 比如,[2,2]不能被合並到[1,2]
如果說MergeInput函數對輸入形狀是“收縮”的,那么“RelaxInput”函數對輸入形狀就是“擴張”的,它傾向於讓形狀變的更模糊,具體的規則是:
- 如果ShapeHandles是一樣的,那么對應的shape將會被返回;
- 如果任一個ShapeHandle是未知的,那么一個未知的ShapeHandle將會被返回;
- 如果兩個形狀的rank已知,但不同,那么一個未知ShapeHandle將會被返回;
- 對於任一維度,如果任一shape是未知的,那么對應的輸出維度也是未知的;
- 對於任一維度,如果兩個shape對應的維度位置都是已知的,但並不相同,那么對應的輸出維度也是未知的;
- 如果兩個shape的rank和對應維度大小都一樣,那么這個形狀將會被返回;
- 例如,[2,?]和[?,2]會得到[?,?];
- 例如,[2,2]和[3,2]會得到[?,2];
- 例如,[2,2]和[1,2,3]會得到?
3. 關系圖
4. 涉及的文件
- shape_inference
5. 迭代記錄
- v1.0 2018-08-29 文檔創建
- v2.0 2018-09-10 文檔重構