TVM,Relay,Pass


TVM,Relay,Pass

Relay介紹

主要結合TVM的文檔(https://tvm.apache.org/docs/dev/relay_intro.html),介紹一下NNVM的第二代Relay。Relay的設計目標有以下幾點:

支持傳統的數據流(DataFlow)風格編程。支持functional-style scoping,並融合了編程語言領域的一些知識,帶了一些新的特性(支持Let表達式,支持遞歸等等)支持數據流風格和函數式風格混合編程。

使用Relay建立一個計算圖

傳統的深度學習框架使用計算圖作為的中間表示。計算圖(或數據流圖)是代表計算過程的有向無環圖(DAG)。盡管由於缺少控制流,數據流圖在計算能力方面受到限制,但簡單性使其易於實現自動微分,並針對異構執行環境進行編譯(例如,在專用硬件上執行計算圖的某些部分,即子圖)。

 

 

 使用Relay構建一個簡單的計算圖示例代碼,對應的文本形式和AST抽象語法樹,可以使用Relay來構建一個計算(DataFlow)圖。具體來說,上面的代碼顯示了如何構造一個簡單的兩個節點的計算圖,可以發現這個示例的代碼和現有的Garph IR如NNVMv1沒有太大區別,唯一的區別是在術語方面:

現有框架通常使用圖和子圖Relay使用函數,例如 – fn(%x),表示圖每個數據流節點,都是Relay中的一個CallNode。通過Relay的Python DSL,可以快速構建計算圖。上面的代碼需要注意,這里顯示構造了一個Add節點,兩個輸入都指向%1。當一個深度學習框架。對上面的計算圖進行推理時,將會按照拓撲序進行計算,並且%1只會被計算一次。雖然這個事實對於深度學習框架的開發者,一件很自然的事情,但這或許會使得只關心算法的研究員困惑。如果實現一個簡單的vistor打印結果,將結果視為嵌套的Call表達式,將是log(%x) + log(%x)。

當DAG中存在共享節點時,這種歧義是由程序語義的解釋不同引起的。在正常的函數式編程IR中,嵌套表達式被視為表達式樹,沒有考慮%1,實際上在%2中被重用了2次的事實。

Relay IR注意到了這個區別。其實深度學習框架用戶,經常使用這種方式構建計算圖,其中經常發生DAG節點重用。然后以文本格式打印Relay程序時,每行打印一個CallNode,並為每個CallNode分配一個臨時ID(%1, %2),以便可以在程序的后續部分中引用每個公共節點。

Module:支持多個函數(Graphs)

上面介紹了如何構建一個數據流圖為一個函數。然后一個很自然的問題是可以做到構建多個函數並相互調用嗎?Relay允許將多個函數組合在一個Module中,下面的代碼展示了一個函數調用另外一個函數的例子。

def @muladd(%x, %y, %z) { %1 = mul(%x, %y) %2 = add(%1, %z) %2}def @myfunc(%x) { %1 = @muladd(%x, 1, 2) %2 = @muladd(%1, 2, 3) %2}

Module可以被看作Map<GlobalVar, Function>,GlobalVar僅僅是一個表示函數名的ID,上面的程序中GlobalVar是@muladd和@myfunc。當一個CallNode調用另外一個函數時,相應的GlobalVar被存在CallNode的OP中。包含了一個間接的等級關系---需要使用相應的GlobalVar,從Module中查找調用函數的主體。也可以直接將引用的函數存儲為CallNode中的OP。為什么需要引入GlobalVar呢?主要原因是為了解耦定義和聲明,並支持了函數的遞歸和延遲聲明。

def @myfunc(%x) { %1 = equal(%x, 1)if (%1) { %x } else { %2 = sub(%x, 1) %3 = @myfunc(%2) %4 = add(%3, %3) %4 }}在上面的例子中,@myfunc遞歸調用。使用GlobalVar @myfunc表示函數,避免了數據結構中的循環依賴性。至此,已經介紹完了Relay中的基本概念。相比NNVM,Relay在如下方面進行了改進:

有文本形式中間表示,便於開發和 debug支持子圖函數、聯合模塊,便於聯合優化前端用戶友好,便於調優0x2.3 Let Binding and Scopes

至此,已經介紹了如何用深度學習框架中的舊方法,構建計算圖。這一節將討論一個Relay的一個新的構造-let bindings。

Let binding被每一種高級的編程語言應用。在Relay中,一個擁有三個字段Let(var, value, body)的數據結構。計算一個Let表達式時,首先計算value部分,然后將其綁定到var,最后在body表達式中返回計算結果。

可以使用一系列的Let綁定,構造一個邏輯上等效於數據流程序的程序,下面的代碼示例顯示了這個用法:

 

 

 Let表達式構造和數據流程序等價的,計算圖嵌套的Let Binding,稱作A-normal形式,作為函數式編程語言中的常用IR。通過上面的圖,可以發現雖然這兩個程序的語義完全等價,文本表示也一樣(除了A-norm形式有let的前綴),但AST抽象語法樹卻不一樣。

由於程序的優化,使用了這些AST數據結構進行了變換,這兩種不同的結構,影響到最終編譯器生成的代碼。比如,想要檢測add(log(x), y)這個模式。在數據流程序中,可以首先進入add節點,然后直接檢查第一個參數是不是log。在A-form的程序中,不能直接檢查任何東西,因為add節點的輸入是%v1-需要維護一個映射表,將變量和綁定的值進行映射,然后查表才知道%v1代表的是log。

為什么可能需要Let Binding

Let Binding的一種關鍵用法,可以指定計算的scope。看一下下面這個沒有使用Let Binding的例子:

 

 

 沒有使用Let Binding編程的一個例子,當嘗試在該在哪里計算%1節點時,問題就來了。特別的是,雖然文本格式似乎建議,應該在if的scope之外,計算節點%1,但AST卻不建議這樣做。實際上數據流圖,永遠不會定義計算scope,這在語義上產生了一些歧義。

當有閉包時,這種歧義更加有趣,考慮下面的程序,該程序返回一個閉包。不知道在哪里計算%1,可以在閉包的內部和外部。

fn (%x) { %1 = log(%x) %2 = fn(%y) { add(%y, %1) } %2}Let Binding解決了這些問題,因為值的計算發生在let節點上。在這兩個程序中,如果將%1 = log(%x)改成let %v1 = log(%x),將計算位置明確指定為if scope和閉包之外。Let Binding為計算端提供了更精確的范圍,在生成后端代碼時會很有用(因為這種范圍在IR中)。

另一方面,沒有指定計算scope的數據流形式,也有其自身的優勢,不需要擔心在生成代碼時,將let放到哪里。數據流格式還為后面決定將計算節點放到哪里的Passes,提供了更大的自由度。因此,在優化的初始階段,如果發現數據流形式,還是挺方便的,那么,使用數據流圖的編碼方法,可能不是一個壞主意。目前在Relay中也實現了很多針對數據流圖的優化方式。

但是,當將IR lower到實際的運行時程序時,需要精確的計算scope。特別是當使用子函數和閉包時,要明確指定計算scope,應在哪里發生。在后期執行特定的優化中,可以使用Let Binding來解決此問題。

對IR轉換的影響

希望到目前為止,已經熟悉兩種表示形式。大多數函數式編程語言都以A-normal形式進行分析,分析人員無需注意表達式是DAG。

Relay選擇同時支持數據流形式和Let Binding。TVM相信讓框架開發者選擇熟悉的表達形式很重要。但是這確實對寫通用的Passes產生了一些影響。這里還沒介紹Passes,對Passes理解不深,沒有使用過Let表達式來構建網絡,就不繼續介紹具體有哪些影響了。

詳細內容可以參考:https://tvm.apache.org/docs/dev/relay_intro.html#let-binding-and-scopes

基於Relay構建一個自定義的神經網絡示例

基於Relay的接口定義一個Conv+BN+ReLU的小網絡,展示一下Relay接口應該如何使用,這里TVM版本是0.8.0.dev,代碼如下:

#coding=utf-8import tvmfrom tvm import relayimport numpy as npfrom tvm.contrib import graph_executor# 構造BNdefbatch_norm(data, gamma=None, beta=None, moving_mean=None, moving_var=None, **kwargs): name = kwargs.get("name") kwargs.pop("name")ifnot gamma: gamma = relay.var(name + "_gamma")ifnot beta: beta = relay.var(name + "_beta")ifnot moving_mean: moving_mean = relay.var(name + "_moving_mean")ifnot moving_var: moving_var = relay.var(name + "_moving_var")return relay.nn.batch_norm(data, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, **kwargs)[0]# 構造卷積defconv2d(data, weight=None, **kwargs): name = kwargs.get("name") kwargs.pop("name")ifnot weight: weight = relay.var(name + "_weight")return relay.nn.conv2d(data, weight, **kwargs)# 構造卷積+BN+ReLU的simpleNetdefsimplenet(data, name, channels, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), epsilon=1e-5): conv = conv2d( data=data, channels=channels, kernel_size=kernel_size, strides=strides, padding=padding, data_layout='NCHW', name=name+'_conv') bn = batch_norm(data=conv, epsilon=epsilon, name=name + '_bn') act = relay.nn.relu(data=bn)return actdata_shape = (1, 3, 224, 224)kernel_shape = (32, 3, 3, 3)dtype = "float32"data = relay.var("data", shape=data_shape, dtype=dtype)act = simplenet(data, "graph", 32, strides=(2, 2))func = relay.Function(relay.analysis.free_vars(act), act)print(func)np_data = np.random.uniform(-1, 1, (1, 3, 224, 224))params = {"graph_conv_weight": tvm.nd.array(np.random.uniform(-1, 1, (32, 3, 3, 3)).astype(dtype)),"graph_bn_gamma": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),"graph_bn_beta": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),"graph_bn_moving_mean": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),"graph_bn_moving_var": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),}with tvm.transform.PassContext(opt_level=3): lib = relay.build(func, "llvm", params=params)dev = tvm.cpu(0)dtype = "float32"m = graph_executor.GraphModule(lib["default"](dev))# set inputsm.set_input("data", tvm.nd.array(np_data.astype(dtype)))# executem.run()# get outputstvm_output = m.get_output(0)

就是一個很常規的過程,創建Relay Function,然后將所有的OP的權重信息用params這個字典存起來,注意這里的權重信息是隨機初始化的。在編譯Relay IR之前可以先看一下優化前的IR長什么樣:

fn (%data: Tensor[(1, 3, 224, 224), float32], %graph_conv_weight, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var) { %0 = nn.conv2d(%data, %graph_conv_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]); %1 = nn.batch_norm(%0, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var); %2 = %1.0; nn.relu(%2)}符合第二節介紹的規則,Relay IR時一個函數。

初識Pass

上面構造simplenet的代碼中,relay.build外部包了一層tvm.transform.PassContext,如下:

with tvm.transform.PassContext(opt_level=3): lib = relay.build(func, "llvm", params=params)實際上tvm.transform.PassContext這個接口就定義了Pass,如文檔所示:

 

 

 tvm.transform.PassContext用來控制對relay IR使用哪些Pass進行優化,Pass是TVM中基於Relay IR進行的一系列優化,類似於onnx-simplifier里面用到的onnxoptimizer,可以簡化計算圖,去除一些冗余的算子,提高模型的推理效率。TVM將所有的pass都抽象到了tvm/include/tvm/ir/transform.h這個文件中,主要包含PassContext,PassInfo,Pass,以及Sequential。

這里的PassContext是上面Python接口對應的C++實現,包含了Pass執行依賴的一些參數,如優化level,依賴特定Pass,以及設置不使用某種指定Pass等。PassInfo是用來記錄Pass信息的類,包含Pass的opy_level,name,以及當前Pass需要哪些前置Pass。而Pass這個類,就執行pass的主體,這是一個基類,每種Pass具體的C++代碼實現在tvm/src/relay/transforms中,都會繼承Pass這個基類。最后,Sequential是一個container,裝載所有Pass。

需要說明一下,不是所有的Pass都定義在tvm/src/relay/transforms,比如下面的第一個例子,就在tvm/src/relay/backend/vm文件夾里。接下來將幾個Pass的例子,到底對Relay IR做了什么?

RemoveUnusedFunctions首先來看一下定義在tvm/src/relay/backend/vm/removed_unused_funcs.cc這里的RemoveUnusedFunctions 這個pass,核心的代碼實現如下:

voidVisitExpr_(const FunctionNode* func_node)final{auto func = GetRef<Function>(func_node);if (visiting_.find(func) == visiting_.end()) { visiting_.insert(func);for (auto param : func_node->params) { ExprVisitor::VisitExpr(param); } ExprVisitor::VisitExpr(func_node->body); } }IRModule RemoveUnusedFunctions(const IRModule& module, Array<runtime::String> entry_funcs){std::unordered_set<std::string> called_funcs{};for (auto entry : entry_funcs) {auto funcs = CallTracer(module).Trace(entry); called_funcs.insert(funcs.cbegin(), funcs.cend()); }auto existing_functions = module->functions;for (auto f : existing_functions) {auto it = called_funcs.find(f.first->name_hint);if (it == called_funcs.end()) {module->Remove(f.first); } }returnmodule;}

這個pass就是去除Relay IR中的冗余節點,VisitExpr_這個函數就是完成了一個圖的遍歷,然后把沒有遍歷到的節點刪掉。刪除發生在RemoveUnusedFunctions這個函數中。

ToBasicBlockNormalForm這個Pass實現在tvm/src/relay/transforms/to_basic_block_normal_form.cc,代碼實現如下:

Expr ToBasicBlockNormalFormAux(const Expr& e){// calculate all the dependency between nodes. support::Arena arena; DependencyGraph dg = DependencyGraph::Create(&arena, e);/* The scope of the whole expr is global. * The scope of any subexpr, is the lowest common ancestor of all incoming edge. * We also record the set of expressions whose scope is lifted. */std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);}IRModule ToBasicBlockNormalForm(const IRModule& mod){ DLOG(INFO) << "ToBBlock:" << std::endl << mod; tvm::Map<GlobalVar, Function> updates;auto funcs = mod->functions;for (constauto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables";if (constauto* n = it.second.as<FunctionNode>()) {if (n->GetAttr<String>(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second); updates.Set(it.first, Downcast<Function>(ret)); }for (auto pair : updates) { mod->Add(pair.first, pair.second, true); } DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod;return mod;}boolBasicBlockNormalFormCheck(const Expr& e){// calculate all the dependency between nodes. support::Arena arena; DependencyGraph dg = DependencyGraph::Create(&arena, e);std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);for (auto expr : scopes.second) { LOG(FATAL) << "The expression below violates the basic block normal form in that " << "its scope should be lifted:\n" << expr; }return scopes.second.size() == 0;}ToBasicBlockNormalForm

這個函數通過遍歷Relay IR中的function,將每個function轉換為基本塊形式(即ToBasicBlockNormalFormAux這個函數),ToBasicBlockNormalFormAux這個函數分成以下幾個部分:

調用DependencyGraph dg = DependencyGraph::Create(&arena, e)創建一個DependencyGraph,這個數據結構是一個表達式相互依賴的圖結構。通過std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg)計算每個節點的scope,這個scope可以簡單理解為由跳轉指令如Ifnode,FunctionNode,LetNode等隔開的那些子圖,因為一旦碰到這些節點在上面通過Relay Function創建DependencyGraph就會為這種節點分配一個new_scope標志。然后CalcScope這個函數具體做了哪些事情,需要跟進去看一下:std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg){ NodeScopeMap expr_scope; ExprSet lifted_exprs;std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;// 首先讓每個節點都屬於一個單獨的scopefor (auto expr_node : dg.expr_node) { node_to_expr[expr_node.second] = expr_node.first; }bool global_scope_used = false; Scope global_scope = std::make_shared<ScopeNode>();// 使用LCA算法來更新每個節點的真正scopefor (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { DependencyGraph::Node* n = *it;auto iit = n->parents.head; Scope s;if (iit == nullptr) { ICHECK(!global_scope_used); s = global_scope; global_scope_used = true; } else { s = expr_scope.at(iit->value);constauto original_s = s; iit = iit->next;for (; iit != nullptr; iit = iit->next) { s = LCA(s, expr_scope.at(iit->value)); }if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {// filter out exprs whose scope do not matter Expr expr = node_to_expr[n];if (!expr.as<OpNode>()) { lifted_exprs.insert(expr); } } }if (n->new_scope) {auto child_scope = std::make_shared<ScopeNode>(s); expr_scope.insert({n, child_scope}); } else { expr_scope.insert({n, s}); } } ICHECK(global_scope_used);returnstd::make_pair(expr_scope, lifted_exprs);}

這個函數首先讓每個節點都屬於一個單獨的scope,然后使用LCA算法來更新每個節點的真正scope。這里簡單介紹一下LCA算法以及這里具體是如何求取每個節點的scope的。

最近公共祖先簡稱 LCA(Lowest Common Ancestor)。兩個節點的最近公共祖先,就是這兩個點的公共祖先里面,離根最遠的那個。為了方便,記某點集 的最近公共祖先為 或 。LCA有以下性質,引自OI-wiki:

 

 

 其實不看這個性質也沒關系,了解LCA,可以求圖中兩個節點的最近公共祖先即可。然后CalcScope這個函數的具體思路,先將每個節點初始化為一個單獨的scope,然后按照后DFS序遍歷這些節點,對於每一個遍歷到的節點(這里記作n),看一下它的父親節點iit是否存在,如果不存在則說明當前節點是根節點,scope應該為global_scope。如果iit存在,那么遍歷iit的子節點,看一下這些節點的scope的LCA表達式,如果這個通過LCA求出來的表達式和iit節點的表達式完全相同,說明這個子圖和當前節點是屬於同一個scope的,否則就將當前節點插入到lifted_exprs,lifted_exprs是一個集合用來保存這個DependencyGraph里面的那些跳轉指令節點,這也是為什么上面再插入節點到lifted_exprs之前,需要判斷一下這個節點的類型是否為OpNode。另外如果當前枚舉的節點有new_scope標志,說明當前節點屬於一個新的scope,需要為當前節點分配新的類型為ScopeNode的一個智能指針。

通過上面的算法,DependencyGraph中的節點和scope節點的關系就被映射到了一個map中,並且scope節點也被建立起了一個樹結構。最后調用這個Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);來創建一個Fill類,這個類包含了DependencyGraph以及scope相關的信息,通過ToBasicBlockNormalForm成員函數實現基本塊轉換。實現在tvm/src/relay/transforms/to_a_normal_form.cc這個文件中,知乎對這個Pass也做了解釋,這里引用一下:

它(ToBasicBlockNormalForm)的基本邏輯通過VisitExpr函數遍歷dependency節點,將具有相同scope的節點壓入到同一個let_list中。Let_list文檔中是這樣解釋的:

/*! * \file let_list.h * \brief LetList record let binding and insert let expression implicitly. * using it, one can treat AST as value instead of expression, * and pass them around freely without fear of AST explosion (or effect duplication). * for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'. * if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);', * the AST will contain 2 'a', as b and c are now variables.

Let_list使得抽象語法樹簡潔化,不會因為變量的復制導致樹的爆炸。具有相同的scope的expr被約束到相同的let_list中,用一個var來表達,這樣就將表達式轉化為var的形式。一個var也就對應了一個基本塊。

EliminateCommonSubexpr最后再看一個消除公共子表達式的Pass,所謂公共子表達式指的就是具有相同的OP類型以及相同的參數,參數的順序都是完全相同的,這些表達式就可以合成一個公共子表達式。舉個例子:

a = b + cd = b + c

可以看到這兩個表達式時完全一致的,經過這個Pass之后,計算圖就會消除其中一個表達式。代碼實現在:tvm/src/relay/transforms/eliminate_common_subexpr.cc。這里定義了一個CommonSubexprEliminator類,這個類重載了兩個Rewrite_函數,對expr進行遍歷和重寫。代碼實現如下:

Expr Rewrite_(const CallNode* call, const Expr& post)final{staticauto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful"); Expr new_expr = post;const CallNode* new_call = new_expr.as<CallNode>(); ICHECK(new_call);const OpNode* op = new_call->op.as<OpNode>(); StructuralEqual attrs_equal;if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {return new_expr; }if (fskip_ != nullptr && fskip_(new_expr)) {return new_expr; }auto it = expr_map_.find(new_call->op);if (it != expr_map_.end()) {for (const Expr& candidate_expr : it->second) {if (const CallNode* candidate = candidate_expr.as<CallNode>()) {bool is_equivalent = true;// attrs匹配if (!attrs_equal(new_call->attrs, candidate->attrs)) {continue; }// args匹配for (size_t i = 0; i < new_call->args.size(); i++) {if (!new_call->args[i].same_as(candidate->args[i]) && !IsEqualScalar(new_call->args[i], candidate->args[i])) { is_equivalent = false;break; } }if (!is_equivalent) continue;return GetRef<Call>(candidate); } } } expr_map_[new_call->op].push_back(new_expr);return new_expr; }可以看到大概的思路就是利用expr_map_這個std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;

映射遍歷過的具有相同op的expr,然后每次碰到相同op的表達式,都會對已經記錄的expr進行匹配,匹配不僅包含OP的attrs屬性,還包含參數列表,如果完全一樣,說明這兩個表達式就是公共表達式,就不返回新的表達式。這樣就可以去掉Relay Function中的公共表達式了。

到這里可能還不是特別清楚最開始加載的那個simplenet的Relay Function,經過一些Pass之后,具體變成什么樣,其實目前也還沒搞清楚這個問題,這個問題應該就需要留到后面再解答了。

小結

本文介紹了一下TVM的Relay,介紹了如何基於Relay構建一個Conv+BN+ReLU的小網絡,介紹了一下TVM中的Pass的工作機制,詳細的介紹了RemoveUnusedFunctions,ToBasicBlockNormalForm,EliminateCommonSubexpr三種Pass。其中Relay部分的詳細介紹大部分引用自官方文檔:https://tvm.apache.org/docs/tutorials/get_started/introduction.html。

0x6. 參考資料

https://zhuanlan.zhihu.com/p/358437531https://zhuanlan.zhihu.com/p/91283238https://tvm.apache.org/docs/tutorials/get_started/introduction.html

 

https://baijiahao.baidu.com/s?id=1700872402469787364&wfr=spider&for=pc

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM