在构建一个推理模型时(如NanoDet,一个目标检测模型),需要继承 BasicOrtHandler。BasicOrtHandler 的初始化函数中会调用 initialize_handler() 方法,该方法会对 Ort::Env ort_env(构建在栈上)、Ort::Session ort_session(构建在堆上)等属性进行初始化。接着深入到 Ort::Env 中,该类就定义在文件 onnxruntime/onnxruntime/core/session/onnxruntime_cxx_api.h 中,值得注意的是在这个文件中存在一个很重要的模板类 Base:
template <typename T>
struct Base {
using contained_type = T;
Base() = default;
Base(T* p) : p_{p} {
if (!p)
ORT_CXX_API_THROW("Allocation failure", ORT_FAIL);
}
~Base() { OrtRelease(p_); }
operator T*() { return p_; }
operator const T*() const { return p_; }
/// \brief Releases ownership of the contained pointer
T* release() {
T* p = p_;
p_ = nullptr;
return p;
}
protected:
Base(const Base&) = delete; // 拷贝构造:删除
Base& operator=(const Base&) = delete; // 拷贝赋值:删除
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } // 支持移动构造
void operator=(Base&& v) noexcept {
OrtRelease(p_);
p_ = v.release();
}
T* p_{};
template <typename>
friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
};
可以看出,这里实现了类似于 unique_ptr 的功能。在该文件中,有许多类继承被 Base 包装后的基类,如 Env 继承自 Base
在 Env 初始化过程中,
- 会调用 GetApi() 中的 CreateEnv()函数,在 CreateEnv()函数中会构建 LoggingManagerConstructionInfo类的实例,并传输到 OrtEnv::GetInstance() 函数;
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); // p_ 就是 Base<OrtEnv> 中的成员变量 OrtEnv
if (strcmp(logid, "onnxruntime-node") == 0) {
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
} else {
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
}
}
ORT_API_STATUS_IMPL(OrtApis::CreateEnv, OrtLoggingLevel logging_level,
_In_ const char* logid, _Outptr_ OrtEnv** out) {
API_IMPL_BEGIN
OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, logging_level, logid};
Status status;
*out = OrtEnv::GetInstance(lm_info, status);
return ToOrtStatus(status);
API_IMPL_END
}
- 在 GetInstance() 函数中,创建 onnxruntime::Environment 对象,并通过移动构造的方式构建OrtEnv对象,赋值给 OrtEnv 中的成员变量 p_instance_ 中,并返回该变量,注意:在返回变量之前,会对OrtEnv::ref_count加1,即有关OrtEnv的引用计数加1;
OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info,
onnxruntime::common::Status& status,
const OrtThreadingOptions* tp_options) {
std::lock_guard<onnxruntime::OrtMutex> lock(m_);
if (!p_instance_) {
std::unique_ptr<LoggingManager> lmgr;
std::string name = lm_info.logid;
if (lm_info.logging_function) {
std::unique_ptr<ISink> logger = std::make_unique<LoggingWrapper>(lm_info.logging_function,
lm_info.logger_param);
lmgr.reset(new LoggingManager(std::move(logger),
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&name));
} else {
auto sink = MakePlatformDefaultLogSink();
lmgr.reset(new LoggingManager(std::move(sink),
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&name));
}
std::unique_ptr<onnxruntime::Environment> env;
if (!tp_options) {
status = onnxruntime::Environment::Create(std::move(lmgr), env);
} else {
status = onnxruntime::Environment::Create(std::move(lmgr), env, tp_options, true);
}
if (!status.IsOK()) {
return nullptr;
}
p_instance_ = new OrtEnv(std::move(env));
}
++ref_count_;
return p_instance_;
}