【tvm解析】PACKFUNC機制


為實現多種語言支持,需要滿足以下幾點:

  • 部署:編譯結果可以從python/javascript/c++調用。
  • Debug: 在python中定義一個函數,在編譯函數中調用。
  • 鏈接:編寫驅動程序以調用設備特定代碼(如CUDA),可以在編譯的host側調用
  • ‎原型:python側定義IR PASS,並從C++后端調用該代碼‎
  • 接口暴露:c++后端代碼暴露到python側
  • ‎實驗:將編譯的函數運送到嵌入式設備,可以直接在嵌入式設備上運行

tvm希望在任何一個語言中定義的函數,可以在其他的語言中都可以調用。同樣希望runtime盡可能的輕量化,以方便在嵌入式設備上部署。

PackedFunc

PackedFunc是解決上述問題的一個優雅的方案。一個PackedFunc對象對應着一個函數調用,即使定義與調用分散在不同語言之間也可以滿足。下面展示一個C++的例子。

#include <tvm/runtime/packed_func.h>

void MyAdd(TVMArgs args, TVMRetValue* rv) {
  // automatically convert arguments to desired type.
  int a = args[0];
  int b = args[1];
  // automatically assign value return to rv
  *rv = a + b;
}

void CallPacked() {
  PackedFunc myadd = PackedFunc(MyAdd);
  // get back 3
  int c = myadd(1, 2);
}

上面的例子中,定義了一個MyAddPackedFunc,接受兩個參數,args表示輸入參數, rv表示返回值。這個參數是類型無關的(type-erased),這意味着函數簽名中對輸入輸出參數的類型沒有限制。這樣,當調用這個函數的時候, 從棧上獲取輸入參數(TVMArgs),通過TVMRetValue返回函數返回值。

通過C++的模板技巧,可以像正常函數一樣調用PackedFunc。由於類型無關的特性,可以在像python這樣的動態類型的語言中調用PackedFunc,而無需插入額外其他的膠水代碼。下面展示了PackedFunc 的注冊及其在python端的調用。

// register a global packed function in c++
TVM_REGISTER_GLOBAL("myadd")
.set_body(MyAdd);
import tvm

myadd = tvm.get_global_func("myadd")
# prints 3
print(myadd(1, 2))

多數的PackedFunc技巧依賴於TVMArgsTVMRetValue,我們限制其中的參數類型,下面是主要用的類型:

  • int, float and string
  • PackedFunc itself
  • Module for compiled modules
  • DLTensor* for tensor object exchange
  • TVM Object to represent any object in IR

這個限制,使得實現及其簡單而且無需序列化操作。雖然增加了限制,但對於DL開發來說,大多數場景下僅僅需要傳遞DLTensor和數字就夠了。

既然PackedFunc可以將另外的PackedFunc作為函數參數,那就可以在python與c++之間傳遞函數。

TVM_REGISTER_GLOBAL("callhello")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  PackedFunc f = args[0];
  f("hello world");
});
import tvm

def callback(msg):
  print(msg)

# convert to PackedFunc
f = tvm.convert(callback)
callhello = tvm.get_global_func("callhello")
# prints hello world
callhello(f)

TVM 提供了極簡的C API,使得將PackedFunc可以方便地嵌入到其他的語言中。除python外,還支持java、JavaScript。

PackFunction不僅用於tvm編譯器中,同樣也用於開發的技術棧中。在tvm中所有的PASS函數都通過PackedFunc暴露給前端的。編譯結果同樣是通過PackedFunc打包的。

為了保證runtime盡可能的小,runtime中隔離了IR對象的支持。這使得runtime大小只有200~600k,具體的大小取決於平台驅動部分。

PackedFunc帶來的調用開銷很小,僅僅是通過棧傳遞了一些參數對象,只要不通過它包裝較小的函數,就是OK的。總之,PackedFunc是tvm中通用的膠水代碼,支持了tvm的編譯部署。

額外的部分:

c++ 注冊,python調用

上文中介紹注冊時,使用到了一個C++宏TVM_REGISTER_GLOBAL,這里介紹中間是如何鏈接起來的。

TVM_REGISTER_GLOBAL("callhello")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  PackedFunc f = args[0];
  f("hello world");
});

//展開就是
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register("callhello").set_body([](TVMArgs args, TVMRetValue* rv) {
  PackedFunc f = args[0];
  f("hello world");
});

這里的::tvm::runtime::Registry::Register

Registry& Registry::Register(const std::string& name, bool can_override) {  // NOLINT(*)
  Manager* m = Manager::Global();//這是個靜態對象,Manager持有一個map來記錄注冊對象
  std::lock_guard<std::mutex> lock(m->mutex);
  if (m->fmap.count(name)) {
    ICHECK(can_override) << "Global PackedFunc " << name << " is already registered";
  }

  Registry* r = new Registry();
  r->name_ = name;
  m->fmap[name] = r;
  return *r;
}

下面看下Registry的實現。

/*! \brief Registry for global function */
class Registry {
 public:
  //設置函數體
  TVM_DLL Registry& set_body(PackedFunc f);  // NOLINT(*)
  Registry& set_body(PackedFunc::FType f) {  // NOLINT(*)
    return set_body(PackedFunc(f));
  }
  
  //給一個任意函數,萃取函數簽名
  template <typename FLambda>
  Registry& set_body_typed(FLambda f) {
    using FType = typename detail::function_signature<FLambda>::FType;
    return set_body(TypedPackedFunc<FType>(std::move(f), name_).packed());
  }
  //給一個類成員函數、返回值、參數,使用lambda包裝
  template <typename T, typename R, typename... Args>
  Registry& set_body_method(R (T::*f)(Args...)) {
    auto fwrap = [f](T target, Args... params) -> R {
      // call method pointer
      return (target.*f)(params...);
    };
    return set_body(TypedPackedFunc<R(T, Args...)>(fwrap, name_));
  }

  template <typename T, typename R, typename... Args>
  Registry& set_body_method(R (T::*f)(Args...) const) {
    auto fwrap = [f](const T target, Args... params) -> R {
      // call method pointer
      return (target.*f)(params...);
    };
    return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap, name_));
  }
  //
  template <typename TObjectRef, typename TNode, typename R, typename... Args,
            typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
  Registry& set_body_method(R (TNode::*f)(Args...)) {
    auto fwrap = [f](TObjectRef ref, Args... params) {
      TNode* target = ref.operator->();
      // call method pointer
      return (target->*f)(params...);
    };
    return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
  }

  template <typename TObjectRef, typename TNode, typename R, typename... Args,
            typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
  Registry& set_body_method(R (TNode::*f)(Args...) const) {
    auto fwrap = [f](TObjectRef ref, Args... params) {
      const TNode* target = ref.operator->();
      // call method pointer
      return (target->*f)(params...);
    };
    return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
  }

  TVM_DLL static Registry& Register(const std::string& name, bool override = false);  // NOLINT(*)
  
  TVM_DLL static bool Remove(const std::string& name);
  
  TVM_DLL static const PackedFunc* Get(const std::string& name); 
  TVM_DLL static std::vector<std::string> ListNames();

  struct Manager;

 protected:
  std::string name_;
  PackedFunc func_;
  friend struct Manager;
};

上面注冊以后是在一個全局對象中,下一部就看python側如何調用的。

python端最終會調用到 _get_global_func函數,具體實現如下。

def _get_global_func(name, allow_missing=False):
    handle = PackedFuncHandle()
    check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))

    if handle.value:
        return _make_packed_func(handle, False)

    if allow_missing:
        return None

    raise ValueError("Cannot find global function %s" % name)

進而會調用到TVMFuncGetGlobal

int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
  API_BEGIN();
  const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name);
  if (fp != nullptr) {
    *out = new tvm::runtime::PackedFunc(*fp);  // NOLINT(*)
  } else {
    *out = nullptr;
  }
  API_END();
}

這里既可以發現tvm::runtime::Registry::Get(name)來查找相關注冊函數的。

python注冊,c++ 調用

如下面的函數,通過裝飾器注冊。

@tvm._ffi.register_func("relay.backend.lower_call")

在c++中調用

static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");

下面介紹以下python的注冊。

def register_func(func_name, f=None, override=False):
    if callable(func_name):
        f = func_name
        func_name = f.__name__

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    ioverride = ctypes.c_int(override)

    def register(myf):
        """internal register function"""
        if not isinstance(myf, PackedFuncBase):
            myf = convert_to_tvm_func(myf) #轉化為packfunc
        #注冊
        check_call(_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride))
        return myf

    if f:
        return register(f)
    return register
def convert_to_tvm_func(pyfunc):
    local_pyfunc = pyfunc

    def cfun(args, type_codes, num_args, ret, _):
        """ ctypes function """
        num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
        pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args))
        # pylint: disable=broad-except
        try:
            rv = local_pyfunc(*pyargs)
        except Exception:
            msg = traceback.format_exc()
            msg = py2cerror(msg)
            _LIB.TVMAPISetLastError(c_str(msg))
            return -1

        if rv is not None:
            if isinstance(rv, tuple):
                raise ValueError("PackedFunction can only support one return value")
            temp_args = []
            values, tcodes, _ = _make_tvm_args((rv,), temp_args)
            if not isinstance(ret, TVMRetValueHandle):
                ret = TVMRetValueHandle(ret)
            if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0:
                raise get_last_ffi_error()
            _ = temp_args
            _ = rv
        return 0

    handle = PackedFuncHandle()
    f = TVMPackedCFunc(cfun)
    # NOTE: We will need to use python-api to increase ref count of the f
    # TVM_FREE_PYOBJ will be called after it is no longer needed.
    pyobj = ctypes.py_object(f)
    ctypes.pythonapi.Py_IncRef(pyobj)
    if _LIB.TVMFuncCreateFromCFunc(f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
        raise get_last_ffi_error()
    return _make_packed_func(handle, False)
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {
  API_BEGIN();
  tvm::runtime::Registry::Register(name, override != 0)
      .set_body(*static_cast<tvm::runtime::PackedFunc*>(f));
  API_END();
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM