在構建一個推理模型時(如NanoDet,一個目標檢測模型),需要繼承 BasicOrtHandler。BasicOrtHandler 的初始化函數中會調用 initialize_handler() 方法,該方法會對 Ort::Env ort_env(構建在棧上)、Ort::Session ort_session(構建在堆上)等屬性進行初始化。接着深入到 Ort::Env 中,該類就定義在文件 onnxruntime/onnxruntime/core/session/onnxruntime_cxx_api.h 中,值得注意的是在這個文件中存在一個很重要的模板類 Base:
template <typename T>
struct Base {
using contained_type = T;
Base() = default;
Base(T* p) : p_{p} {
if (!p)
ORT_CXX_API_THROW("Allocation failure", ORT_FAIL);
}
~Base() { OrtRelease(p_); }
operator T*() { return p_; }
operator const T*() const { return p_; }
/// \brief Releases ownership of the contained pointer
T* release() {
T* p = p_;
p_ = nullptr;
return p;
}
protected:
Base(const Base&) = delete; // 拷貝構造:刪除
Base& operator=(const Base&) = delete; // 拷貝賦值:刪除
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } // 支持移動構造
void operator=(Base&& v) noexcept {
OrtRelease(p_);
p_ = v.release();
}
T* p_{};
template <typename>
friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
};
可以看出,這里實現了類似於 unique_ptr 的功能。在該文件中,有許多類繼承被 Base 包裝后的基類,如 Env 繼承自 Base
在 Env 初始化過程中,
- 會調用 GetApi() 中的 CreateEnv()函數,在 CreateEnv()函數中會構建 LoggingManagerConstructionInfo類的實例,並傳輸到 OrtEnv::GetInstance() 函數;
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); // p_ 就是 Base<OrtEnv> 中的成員變量 OrtEnv
if (strcmp(logid, "onnxruntime-node") == 0) {
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
} else {
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
}
}
ORT_API_STATUS_IMPL(OrtApis::CreateEnv, OrtLoggingLevel logging_level,
_In_ const char* logid, _Outptr_ OrtEnv** out) {
API_IMPL_BEGIN
OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, logging_level, logid};
Status status;
*out = OrtEnv::GetInstance(lm_info, status);
return ToOrtStatus(status);
API_IMPL_END
}
- 在 GetInstance() 函數中,創建 onnxruntime::Environment 對象,並通過移動構造的方式構建OrtEnv對象,賦值給 OrtEnv 中的成員變量 p_instance_ 中,並返回該變量,注意:在返回變量之前,會對OrtEnv::ref_count加1,即有關OrtEnv的引用計數加1;
OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info,
onnxruntime::common::Status& status,
const OrtThreadingOptions* tp_options) {
std::lock_guard<onnxruntime::OrtMutex> lock(m_);
if (!p_instance_) {
std::unique_ptr<LoggingManager> lmgr;
std::string name = lm_info.logid;
if (lm_info.logging_function) {
std::unique_ptr<ISink> logger = std::make_unique<LoggingWrapper>(lm_info.logging_function,
lm_info.logger_param);
lmgr.reset(new LoggingManager(std::move(logger),
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&name));
} else {
auto sink = MakePlatformDefaultLogSink();
lmgr.reset(new LoggingManager(std::move(sink),
static_cast<Severity>(lm_info.default_warning_level),
false,
LoggingManager::InstanceType::Default,
&name));
}
std::unique_ptr<onnxruntime::Environment> env;
if (!tp_options) {
status = onnxruntime::Environment::Create(std::move(lmgr), env);
} else {
status = onnxruntime::Environment::Create(std::move(lmgr), env, tp_options, true);
}
if (!status.IsOK()) {
return nullptr;
}
p_instance_ = new OrtEnv(std::move(env));
}
++ref_count_;
return p_instance_;
}