AI中pass架構設計優化


AI中pass架構設計優化

Relay 和 TVM IR,包含一系列優化passes,可提高模型的性能指標,例如平均推理,內存占用,或特定設備的功耗。有一套標准優化,及特定機器學習的優化,包括常量折疊,死代碼消除,算子布局更改,算子融合,緩沖區處理和循環轉換等。這些passes中的每一個都構造為一個 ir-to -ir 轉換,使用在遍歷期間和/或前收集的分析結果。

隨着 TVM 的快速發展,對管理這些pass的更系統,更有效的方法的需求,變得越來越明顯。此外,管理跨 TVM 堆棧不同層(例如 Relay 和 tir)pass的通用框架,為開發人員快速構建原型,將實現的pass插入系統,鋪平了道路。

本節描述了基礎架構設計,利用產品編譯器,管理優化pass,及構建層的深度學習框架。

例如,許多現有的產品編譯器,如 GCC 和 LLVM,都采用pass管理器,有效管理pass的執行。最初管理 pass 很簡單,因為 pass 的數量很少,成熟的編譯器,將包含數百個單獨的 pass。通常,外部用戶希望正確調度自定義pass,無需修改單個手工制作的pass順序。

現代深度學習框架,如 Pytorch 和 MXNet Gluon,分別通過SequentialBlock,啟用pass-style 層構建方案的趨勢。有了這樣的結構,這些現代框架能夠方便將模塊/層添加到容器中,輕松地構建神經網絡。

Relay pass infra 的設計,很大程度上受到 LLVM 中,使用的分層pass管理器和流行的深度學習框架中,使用的塊式容器的啟發。pass基礎架構的主要目標包括:

  1. 實現更好的優化編程調度。允許用戶靈活地定制和構建優化pass。
  2. 提供一種用戶友好的方式來調試優化pass。
  3. 減輕開發人員手動和分別解決pass之間的依賴關系。
  4. 為開發人員簡化新pass的實施。例如,允許用戶在 Python 中,實現一個 pass,讓 pass infra 操縱執行。

設計

專注於為用戶提供易於擴展的功能,讓用戶可以快速添加新pass,不會失去向后兼容性。該設計包含后端和前端。前者實現了 pass infra 的主要邏輯。后者為用戶提供了簡單的 API 進行交互,允許用戶快速創建優化pass。

C++ 后端

提供一個PassInfo對象,包含pass所需的基本信息。name是傳遞名稱,opt_level指示,在哪個優化級別啟用pass, required表示執行特定pass,所需的pass(有關更多詳細信息,參閱include/tvm/ir/transform.h)。例如,在注冊pass期間,pass開發人員,可以指定pass的名稱,將執行的優化級別和/或所需的pass。在用戶提供的優化級別下運行時,是否需要執行某個 pass, opt_level可用於幫助 pass infra 識別。 required字段,可以被 pass infra 使用,解決 pass 依賴。

class PassInfoNode : public Object {
  String name;
  int opt_level;
  Array<String> required;
};

傳遞上下文

PassContext攜帶用於優化pass的有用信息。例如,包含錯誤報告系統,可以提供有關優化失敗原因的診斷。PassContext旨在替換舊的BuildConfig,用於幫助用戶配置編譯選項,包括優化級別和必需/禁用的pass等。例如,可能有一個配置, opt_level=3使用disabled_pass=xx提供的某些禁用的pass,執行所有PassContext。可以將所有pass,放在opt_level=3,排除禁用pass列表中的那些。PassContext提供了一種檢測所有pass的方法。

PassContext包含優化pass的有用信息。例如,包含錯誤報告系統,可以提供有關優化失敗原因的診斷。PassContext設計用於替換舊的BuildConfig,該配置用於幫助用戶配置編譯選項,包括優化級別和必需/禁用的pass等。例如,可能有一個配置,該配置使用PassContext提供的disabled_pass=xx,在opt_level=3執行所有pass,一些禁用的pass,使用disabled_pass=xx。現在,可以在opt_level=3時,對所有pass,進行全局排序,排除禁用pass列表中的pass。PassContext提供了一種對所有pass,進行檢測的方法。

這個類是為用戶設計的,使用語法編寫Python,在特定配置下執行優化。用戶可以通過PassContext::Current(),線程安全的方式,獲得特定程序范圍內,可用的上下文,線程本地存儲PassContextThreadLocalStore,保存創建的pass context對象。將提供示例來說明,如何使用C++和Python API,使用pass context,創建編譯pass。

class PassContextNode: public Object {
 public:
  int opt_level{2};
  tvm::Array<tvm::Expr> required_pass;
  tvm::Array<tvm::Expr> disabled_pass;
  mutable Optional<DiagnosticContext> diag_ctx;
  Map<String, ObjectRef> config;
  Array<instrument::PassInstrument> instruments;
};
 
class PassContext : public NodeRef {
 public:
  TVM_DLL static PassContext Create();
  TVM_DLL static PassContext Current();
  TVM_DLL void InstrumentEnterPassContext();
  TVM_DLL void InstrumentExitPassContext();
  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
  /* Other fields are omitted. */
 
 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();
 
  // Classes to get the Python `with` like syntax.
  friend class tvm::With<PassContext>;
};
 
struct PassContextThreadLocalEntry {
  /*! \brief The default pass context. */
  PassContext default_context;
  /*! \brief The current pass context. */
  std::stack<PassContext> context_stack;
  PassContextThreadLocalEntry() {
    default_context = PassContext(make_node<PassContextNode>());
  }
};
 
/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
     PassContextThreadLocalStore;

pass構建

pass infra以分層方式設計,可以在 Relay/tir 程序的,不同粒度下工作。PassNode引入了一個純虛擬類,作為不同優化pass的基礎。包含幾個必須由子類在模塊,函數或pass序列級別實現的虛擬方法。

class PassNode : Object {
  virtual PassInfo Info() const = 0;
  virtual Module operator()(const IRModule& mod
                            const PassContext& pass_ctx) const = 0;
};

函子顯示必須如何實現pass,始終在 IRModule特定上下文下工作。所有pass都以ModuletoModule方式設計。由 pass infra 控制的優化,將始終更新整個模塊。

已經創建了幾個子類,實現不同類型的優化pass,例如,函數級pass,模塊級pass和順序pass。每個子類本身都可以充當pass管理器。例如,可以收集所需的pass執行,或基於給定的元數據構建,依賴關系圖。完整定義可以在src/relay/ir/transform.ccsrc/ir/transform.cc 中找到

模塊級pass

模塊級pass主要用於全局和pass間優化 (IPO),類似於 LLVM 中使用的模塊pass。Relay 中一些典型的 pass,需要一個模塊的全局圖片,比如 A-normal form 轉換和 lambda 提升等,都屬於這個集合。在此級別,用戶甚至可以在模塊中,添加和/或刪除功能。所有pass

class ModulePassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};

pass_info維護模塊級pass所需的信息。pass_func勾勒出真正的優化。例如,可能需要對模塊執行死代碼消除。可以在pass_func中實現算法,在模塊上運行。刪除死代碼,包括模塊中未使用的函數。該字段被設計為一個打包函數,可以在 C++ 和 Python 中,實現優化。

函數級pass

函數級pass,用於為給定的 Relay/tir 模塊,實現各種函數內級優化。一次從模塊的函數列表中,獲取一個函數,進行優化,生成重寫的 Relay Function或 tir PrimFunc。大多數pass可以歸入這一類,如Relay中,常見子表達式消除和推理簡化,及tir中的矢量化和扁平化存儲等。

此級別的pass范圍是 Relay 函數,或 tir 原始函數。無法通過pass,添加或刪除函數。

class FunctionPassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};

pass_info與剛剛在模塊pass中描述的相同。pass_func需要一個函數,進行優化,需要一個模塊,可能會報告錯誤。一個函數可以用“SkipOptimization”注釋,在優化pass中忽略。

連續passes

SequentialPass與 Pytorch 類似,nn.Sequential包含許多用於執行的pass。

class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};

僅放置了在Relay中的少數pass。例如,FoldScaleAxis要求在內部調度ForwardFoldScaleAxis和BackwardFoldScaleAxis。建議首先完成BackwardFoldScaleAxis。該pass是SequentialPass的理想候選。

下面的代碼顯示了如何調用順序過程中的各個pass。使用pass列表中,在一個順序pass中,執行每個pass。

Module SequentialNode::operator()(const Module& module,
                                  const PassContext& pass_ctx) const {
  Module mod = module;
  for (const Pass& pass : passes) {
    ICHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& pass_info = pass->Info();
    if (!PassEnabled(pass_info))  continue;
    for (const auto& it : pass_info->required) {
      const auto* name = it.as<tvm::ir::StringImm>();
      ICHECK(name);
      mod = GetPass(name->value)(mod, pass_ctx);
    }
    mod = pass(mod, pass_ctx);
  }
  return mod;
}

在調用pass時,先檢查是否啟用了pass。檢查用戶是否明確禁用pass,是否被用戶指定為必需pass完成的。如果不確定,是否啟用了此pass,opt_level將進行檢查。只有當優化級別不低於PassContext中,配置的優化級別時,才會啟用執行pass。

要執行pass,先需要使用pass名稱,在 TVM 打包函數注冊表中,檢索已注冊的pass。這是可能的,因為每個pass,都注冊了一個 API 端點,將在后面展示。

Pass GetPass(const std::string& pass_name) {
  using tvm::runtime::Registry;
  std::string fpass_name = "relay._transform." + pass_name;
  const auto* f = Registry::Get(fpass_name);
  ICHECK(f != nullptr) << "Cannot find " << fpass_name
                      << "to create the pass " << pass_name;
  return (*f)();
}

提供了一些輔助函數,創建上述每種類型的pass。這些幫助程序,暴露給 Python 前端,使用 Python API,創建特定的 pass 對象。

Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);
 
Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);
 
Pass CreateModulePass(
    const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);
 
Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

pass注冊

不同級別pass的概念和用於編譯的context,可以輕松注冊pass,以 const 折疊為例。這個pass已經實現,折疊 Relay 函數中的常量(在 src/relay/transforms/fold_constant.cc 中找到)。

提供了一個 API,執行ExprtoExpr轉換。

Expr FoldConstant(const Expr& expr);

為了將這個pass注冊到pass infra,先需要決定這個pass,在哪個級別執行。由於常量折疊,發生在單個函數上,應該直觀FunctionPass通過 CreateFunctionPass. 將pass_func作為打包函數返回,該函數在IRModule 中的每個函數上調用Exprto ExprAPI。{}表示此pass,不需要先決條件。否則,pass開發人員必須識別列出。

使用名稱 relay._transform.FoldConstant,注冊一個pass API 端點 。這個pass成為注冊表中的一個條目,可以由C++(如GetPass上面的)和Python訪問。

namespace transform {
 
Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
      return Downcast<Function>(FoldConstant(f));
  };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
 
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
 
}  // namespace transform

為了允許其它 C++ 模塊應用這個pass,在include/tvm/relay/transform.h 中聲明了一個自由函數, 如下所示:

TVM_DLL Pass FoldConstant();

pass儀器

Pass Instrument 是一種分析pass本身的機制。例如,可以使用基礎架構,了解一次pass需要多少時間和內存,或者一次pass,如何轉換 IR 模塊。

生命周期中的四個儀器點PassContext。

TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;

當輸入PassContext實例的范圍時,立即調用InstrumentEnterPassContext。

InstrumentExitPassContext在離開PassContext的作用域時被調用,或者在執行過程中發生異常。當tvm.transform.PassContext中的OverrideU instruments重寫儀器時,會調用此方法。

在執行前,調用InstrumentBeforePass。如果運行pass,在執行后調用InstrumentAfterPass。這種行為就像:

if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
  new_ir_module = run_pass(ir_module, pass_ctx);
  pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
  return new_ir_module;
}

該PassInstrument接口允許在上述四種方法中運行任意代碼。多個PassInstrument實例,可以注冊到一個 PassContext。PassInstrument實例按照instruments傳遞給參數序列,依次調用 PassContext。

PassInstrument 提供以下接口:

namespace instrument {
 
class PassInstrumentNode : public Object {
 public:
  String name;
  virtual void EnterPassContext() const = 0;
  virtual void ExitPassContext() const = 0;
  virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  /* Other fields are omitted. */
};
 
class PassInstrument : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};
 
}  // namespace instrument

提供Python前端,以PassInstrument快速實現。

在 PassContext中, PassInstrument實例的調用順序是這樣的:

with PassContext(instruments=[pi]) # pi = a PassInstrument implementation.
    pi.EnterPassContext()
 
    if pi.ShouldRun(Pass1):
        pi.RunBeforePass()
        Pass1()
        pi.RunAfterPass()
 
    if pi.ShouldRun(Pass2):
        pi.RunBeforePass()
        Pass2()
        pi.RunAfterPass()
 
    pi.ExitPassContext()

介紹一下PassInstrument接口和PassContext方法的關系。有關更多詳細信息,參閱 ( src/ir/transform.cc )。

  • InstrumentEnterPassContext
    • EnterPassContext()按instruments傳遞給PassContext 的順序執行。
    • 當異常發生時,PassContext通過清除所有注冊的PassInstrument實例,禁用pass檢測。
    • PassContext執行ExitPassContext(),成功完成的每個PassInstrument實例的方法EnterPassContext()
    • 例如,如果PassInstrumentA,B,C,注冊到 PassContext,A 完成,EnterPassContext(),B 拋出異常, C 永遠不會執行;ExitPassContext()A 的執行。
  • InstrumentExitPassContext
    • 每個PassInstrument的實例ExitPassContext(),執行順序是instruments傳遞給PassContext.
    • 當異常發生時,instruments被清除。
    • PassInstrument拋出異常后,注冊的實例,不執行ExitPassContext。
  • InstrumentBeforePass
    • 如果pass未列為必需pass,執行ShouldRun。
    • RunBeforePass如果傳球沒有被ShouldRun 阻擋,按照instruments的順序執行。
    • InstrumentBeforePass返回一個布爾值,指示是否應該運行傳遞。
    • 當異常發生時,立即拋出。依靠 Python Context ManagerPassContext安全退出(ExitPassContext每個儀器都會運行的含義。對於 C++,參閱include/tvm/support/with.h。)
  • InstrumentAfterPass
    • RunAfterPass按instruments傳遞給 PassContext的順序執行。
    • 當異常發生時,立即拋出。依靠 Python Context Manager 或Withclass( include/tvm/support/with.h ),安全退出PassContext

build儀器

有幾種內置工具,標有TODO 的,沒有實現。

  • PassTimingInstrument(參見src/ir/instrument.cc
    • 分析pass的執行時間。
  • PrintIRBefore(TODO)
    • 在pass轉換前,打印 IR 模塊。如果在pass周圍插入,tvm.transform.PrintIR()可以達到這個目的。但是,使用PassInstrument,不需要修改passes的順序。
  • 打印后(待辦事項)
    • 在pass轉換后,打印 IR 模塊。

Python前端

前端只需要一些簡單的 API。例如,可以為用戶提供以下 API,創建和執行一個 pass(完整的實現在python/tvm/relay/transform/transform.py和 python/tvm/ir/transform.py 中提供)。后端接收信息,決定應該使用哪個函數,創建 Pass 對象。

PassContext

Python 前端為__enter____exit__current 提供了一個包裝器,通過覆蓋和PassContext,啟用with語法。為用戶提供了一種靜態方法,獲取在一定范圍內使用的Context。

@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    def __enter__(self):
        _transform.EnterPassContext(self)
        return self
 
    def __exit__(self, ptype, value, trace, config):
        _transform.ExitPassContext(self)
 
    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()

PassContext用於配置編譯選項,包括優化級別和必需/禁用的pass。可以帶一個配置字典,以便不同的pass,可以方便地獲取pass的數據,如回退設備信息和循環展開的步驟/深度等。為了能夠獲取所需的配置,必須通過 TVM_REGISTER_PASS_CONFIG_OPTION注冊密鑰。例如,使用以下內容,循環展開pass

TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

更多細節,參考src/tir/transforms/unroll_loop.cc

pass對象

Pass是所有pass對象的基類。這里的所有方法,都只是在后端實現的簡單包裝器。為了用戶方便地與 Python 中的基類,進行交互定義的。在 pass 基類中只定義了__call__,使子類成為可調用對象,可以很容易調用(例如,pass_xx(arg))執行。

@register_relay_node
class Pass(RelayNode):
   def __call__(self, mod):
       return _transform.RunPass(self, mod)

提供了一些輔助 API,從 Python 前端,輕松創建pass,讓pass基礎控制執行。例如module_pass,function_pass和sequential提供給用戶,以便可以定制pass。

對於在 C++ 后端實現的所有 pass,分別在python/tvm/ir/transform.py和 python/tvm/relay/transform/transform.py 中,提供了相應的 Python API 。例如,const 折疊有一個 Python API,如下所示:

def FoldConstant():
    return _transform.FoldConstant()

可以構建一個pass through裝飾:

 @relay.transform.module_pass(opt_level=2)
 def transform(mod, ctx):
    tp = relay.TensorType((10,), "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("abs")
    func = relay.Function([x], relay.abs(x))
    new_mod = tvm.IRModule({gv: func})
    new_mod.update(mod)
    return new_mod
 
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2

在transform功能增加了一個abs與輸入模塊的功能,可能是在模塊級的任何定制的優化。創建module_pass后,應用於任何 Relay 模塊。例如,可以構建一個空模塊,應用pass添加一個abs 函數。

mod = tvm.IRModule()
mod = module_pass(mod)

提供function_pass功能,一個示例函數級pass,可以寫成如下:

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
   def __init__(self, new_func):
      self.new_func = new_func
      def transform_function(self, func, mod, ctx):
         # Just for demo purposes
         # Transform func to new_func
         return self.new_func
 
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# Now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)

可以不使用裝飾器,直接注冊pass,調用。有關如何自定義優化pass,調試 Relay 和 tir pass 的更多示例,參閱 use pass infra教程。

pass儀器

可以通過在實現以下方法的類上,使用pass_instrument decorator(python/tvm/ir/instrument.py),實現PassInstrument。建議使用pass_instrument decorator,實現PassInstrument,不是重寫或子類化。

  • enter_pass_ctx
    • 該方法在進入PassContext時運行。
  • exit_pass_ctx
    • 此方法在退出PassContext時運行。
  • should_run
    • 此方法在執行pass前運行,返回一個布爾值,指示是否應運行pass。
  • run_before_pass
    • 如果應該運行一次pass,在pass執行之前運行此方法。
  • run_after_pass
    • 此方法在執行一次pass后,立即運行。

PassInstrument實例可以通過tvm.transform.PassContext中的參數 instruments注冊。

use pass instrument提供了如何使用 Python API,實現PassInstrument的示例。

覆蓋當前 PassContext 中的儀器

提供了current PassContext覆蓋instruments 的override_instruments方法。例如,如果在沒有顯式創建 new 的情況下,運行 pass PassContext,可以通過以下方式注冊PassInstrument到全局中PassContext:

cur_pass_ctx = tvm.transform.PassContext.current()
# override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()

當override_instruments調用時,舊PassInstrument實例的方法exit_pass_ctx會調用。然后new PassInstrument的enter_pass_ctx方法調用。

 

 

參考鏈接:

https://tvm.apache.org/docs/dev/pass_infra.html#pass-infra


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM