libtorch踩坑記錄



一、Linux CMakeLists鏈接版本

官網下載Linux版本libtorch的時候會發現有(Pre-cxx11 ABI)(cxx11 ABI)兩個版本。

如果鏈接(cxx11 ABI)版本需要在CMakeLists.txt中加入

add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)

原因是舊版(c++03規范)的libstdc++.so,和新版(c++11規范)的libstdc++.so兩個庫同時存在,如果不加,編譯過程會報類似以下錯誤:

undefined references to `c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string

參考鏈接:https://www.codeleading.com/article/29853511199/


二、多個輸入和多個輸出

  • 輸入
    先定義一個std::vector<torch::jit::IValue>變量,然后逐個添加
// inputs
std::vector<torch::jit::IValue> inputs;

torch::jit::IValue cate = torch::ones(( 1, 1, 1 ));
std::cout << "cate: " << cate << std::endl;

torch::Tensor temp = torch::randn({ 1, 3, 11705 });
std::cout << "temp size: " << temp.sizes() << " temp 0: " << temp[0][0][0] << std::endl;
inputs.push_back(temp);
inputs.push_back(cate);
  • 輸出
    先輸出torch::jit::IValue結果,然后根據pytorch端的輸出相應得做轉變
// forward
torch::jit::IValue output = module.forward(inputs);
std::vector<torch::Tensor> outList = output.toTensorVector();   // it used to return [x1, x2, ...] from pytorch

auto tpl = output.toTuple();                                    // it used to return (x1, x2, ...) from pytorch
auto arm_loc = tpl->elements()[i].toTensor();

三、二維vector轉tensor

我的數據是二維數組,可通過以下方式轉換,其中torch::Tensor input_tensor = torch::from_blob(points.data(), { n, cols }).toType(torch::kDouble).clone();轉出來數據不對,有大佬懂的可以指點一下。

// std::vector<std::vector<double>> to torch::jit::IValue
int cols = points[0].size();
int n = points.size();

torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32);
torch::Tensor input_tensor = torch::zeros({ n, cols }, options);
for (int i = 0; i < n; i++) {
    input_tensor.slice(0, i, i + 1) = torch::from_blob(points[i].data(), { cols }, options).clone();
}

注意: 一定要有clone()復制一份, 否則數組釋放后相應數據也會釋放,共享的是同一內存


四、mask索引轉換

python C++
A[mask] torch::masked_select(A, mask)

問題
python某個算子轉換過程中需要改寫成C++,中間遇到有個轉換是mask索引賦值相關的,python下類似A[mask] = B[mask],我在C++中改寫為torch::masked_select(A, mask) = torch::masked_select(B, mask);,運行通過了,但是結果不對,排查后發現是A的值沒有變化,最終采用A = torch::_s_where(mask, B, A);替換得到解決,相關代碼如下:

  • failed

  • successed


五、常見切片操作

官方文檔==>index

參考鏈接

官網API
libtorch 常用api函數示例
LibTorch對tensor的索引/切片操作:對比PyTorch
Tensor Index API


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM