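1. Background
The figure above shows the approaches currently used to run deep learning models in production. This article only discusses how to deploy PyTorch models.
Why call a PyTorch model from C++? Because C++ combined with multithreading can speed up model inference.
As for training, there are two approaches. One is to write the training code directly in C++. This lets you build a complete network, but transfer learning is hard to use this way, and transfer learning is used in nearly all training today. The other is to train the model in Python and then use JIT to export it as a model that C++ can call. This article covers the second approach. (Personally, I think a third option is also viable: exposing the PyTorch model as a web service that any client can call.)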
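The official documentation introduces TorchScript as follows (https://pytorch.org/docs/master/jit.html#creating-torchscript-code):
TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.
We provide tools to incrementally transition a model from a pure Python program to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools and then export the model via TorchScript to a production environment, where running the model as a Python program may be a poor choice for performance and multithreading reasons.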
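First, download the libtorch build for Windows from the official website. Since a stable release is now available, it can be used as-is. There are CPU and GPU versions; I tested both and both work out of the box. This article uses the CPU version as the example.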
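2. Experiments
1. Running the model's inference code in Python
Using the ESRGAN inference code (https://github.com/xinntao/ESRGAN) as an example:
Environment: Windows 10 + Python 3.5.2 + PyTorch 1.1
Python packages: pip install numpy opencv-python
Run test directly. The result is as follows (my version includes a few changes, such as an added FPS calculation):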
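As an aside on the FPS calculation mentioned above: the modified script is not shown in this article, so the following is only a minimal sketch of one way such a measurement could be done. The names `measure_fps` and `run_inference` are hypothetical, and `run_inference` is a stand-in that sleeps instead of calling a real model.

```python
import time


def measure_fps(run_inference, num_frames=20, warmup=2):
    """Return frames per second for a callable that processes one frame."""
    # Warm-up calls are excluded so one-time setup cost does not skew the result.
    for _ in range(warmup):
        run_inference()
    start = time.perf_counter()
    for _ in range(num_frames):
        run_inference()
    elapsed = time.perf_counter() - start
    return num_frames / elapsed


def run_inference():
    # Hypothetical stand-in for one forward pass of the model;
    # it sleeps about 5 ms to simulate work.
    time.sleep(0.005)


print("FPS: %.1f" % measure_fps(run_inference))
```

In a real script, `run_inference` would wrap the model's forward call on one image, and the warm-up count would likely need to be larger on a GPU.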
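2. Converting the PyTorch model to Torch Script
There are two ways to convert a PyTorch model to Torch Script.
The first is tracing. It captures the structure of the model by running an example input through it once and recording the flow of that input through the model. This suits models that make little use of control flow.
The second is to add explicit annotations to the model, telling the Torch Script compiler that it may parse and compile the model code directly, subject to the constraints imposed by the Torch Script language.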
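To see why tracing suits models with little control flow, here is a toy illustration in plain Python. It is not how torch.jit.trace is implemented; `Traced`, `trace`, and `model` are hypothetical names invented for this sketch. The recorder only sees the arithmetic performed for one example input, so the branch taken for that input is baked into the recording.

```python
class Traced:
    """Wraps a number and records every arithmetic operation applied to it."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __mul__(self, k):
        self.tape.append(("mul", k))
        return Traced(self.value * k, self.tape)

    def __add__(self, k):
        self.tape.append(("add", k))
        return Traced(self.value + k, self.tape)

    def __gt__(self, k):
        # The comparison returns a plain bool, so the branch decision
        # itself never reaches the tape.
        return self.value > k


def model(x):
    # Data-dependent control flow.
    if x > 0:
        return x * 2
    return x + 100


def trace(fn, example):
    """Run fn once on the example and return a function replaying the tape."""
    tape = []
    fn(Traced(example, tape))

    def replay(x):
        for op, k in tape:
            x = x * k if op == "mul" else x + k
        return x

    return replay


traced = trace(model, 3)      # the example input takes the x > 0 branch
print(traced(5), model(5))    # 10 10: same branch, results agree
print(traced(-1), model(-1))  # -2 99: the replay follows the recorded branch
```

A model whose forward pass branches on its input in this way is the kind of case where the annotation-based method is the safer choice.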
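- Converting the model to Torch Script via tracing
To convert a PyTorch model to Torch Script via tracing, pass an instance of the model together with an example input to the torch.jit.trace function.
This produces a torch.jit.ScriptModule object with the trace of the model evaluation embedded in the module's forward method: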
    import torch
    import architecture as arch

    # An instance of your model.
    model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None,
                          act_type='leakyrelu', mode='CNA', res_scale=1,
                          upsample_mode='upconv')
    model.load_state_dict(torch.load('./models/RRDB_ESRGAN_x4.pth'), strict=True)
    model.eval()

    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(64, 3, 3, 3)

    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model, example)

    # The traced ScriptModule can now be evaluated identically to a regular PyTorch module.
    output = traced_script_module(torch.ones(64, 3, 3, 3))

    # Serialize the ScriptModule to a file.
    traced_script_module.save("./models/RRDB_ESRGAN_x4_000.pt")

    print(output)
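The traced ScriptModule computes the same results as a regular PyTorch module. The output is shown below. (Note that the last step serializes the ScriptModule to a file. C++ can then run the PyTorch model described by that script without depending on any Python code.)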
    (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN$ python model_jit_converter.py
    tensor([[[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
              [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
              [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
              ...,
              [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
              [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
              [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],

             [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
              [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
              [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
              ...,
              [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
              [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
              [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],

             [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
              [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
              [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
              ...,
              [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
              [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
              [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]],

            [... the remaining batch entries in the printed output repeat the same values ...]],
           grad_fn=<MkldnnConvolutionBackward>)
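3. Loading your Script Module in C++
To load a serialized PyTorch model in C++, your application must depend on the PyTorch C++ API, also known as LibTorch. The LibTorch distribution contains a set of shared libraries, header files, and CMake build configuration files. CMake is not a requirement for depending on LibTorch, but it is the recommended approach and will be well supported in the future. In this tutorial, we use CMake and LibTorch to build a minimal C++ application that simply loads and executes a serialized PyTorch model.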
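Code to load the module:

    #include <torch/script.h> // One-stop header.

    #include <cassert>
    #include <iostream>
    #include <memory>

    int main(int argc, const char* argv[]) {
      if (argc != 2) {
        std::cerr << "usage: example-app <path-to-exported-script-module>\n";
        return -1;
      }

      // Deserialize the ScriptModule from a file using torch::jit::load().
      std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

      assert(module != nullptr);
      std::cout << "ok\n";
    }

The <torch/script.h> header pulls in everything from the LibTorch library needed to run the example. Our application accepts the file path of a serialized PyTorch ScriptModule as its only command-line argument and then deserializes the module with torch::jit::load(), which takes this file path as input. In return we receive a shared pointer to a torch::jit::script::Module, the C++ equivalent of torch.jit.ScriptModule. (This matches the PyTorch 1.1 API used here; later LibTorch releases return the module by value instead of a shared pointer.) For now, we only verify that this pointer is not null. We look at how to execute the module next.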
LibTorch and building the application
Suppose we save the code above to a file named example-app.cpp. A minimal CMakeLists.txt to build it looks like this:
    cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
    project(custom_ops)

    find_package(Torch REQUIRED)

    add_executable(example-app example-app.cpp)
    target_link_libraries(example-app "${TORCH_LIBRARIES}")
    set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
To build the application, assume the example directory is laid out as follows:
    example-app/
      CMakeLists.txt
      example-app.cpp
We can now run the following commands from inside the example-app/ folder to build the application:
    cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch
    make
If all goes well, the output looks like this:
    (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch
    -- The C compiler identification is GNU 5.4.0
    -- The CXX compiler identification is GNU 5.4.0
    -- Check for working C compiler: /usr/bin/cc
    -- Check for working C compiler: /usr/bin/cc -- works
    -- Detecting C compiler ABI info
    -- Detecting C compiler ABI info - done
    -- Detecting C compile features
    -- Detecting C compile features - done
    -- Check for working CXX compiler: /usr/bin/c++
    -- Check for working CXX compiler: /usr/bin/c++ -- works
    -- Detecting CXX compiler ABI info
    -- Detecting CXX compiler ABI info - done
    -- Detecting CXX compile features
    -- Detecting CXX compile features - done
    -- Looking for pthread.h
    -- Looking for pthread.h - found
    -- Looking for pthread_create
    -- Looking for pthread_create - not found
    -- Looking for pthread_create in pthreads
    -- Looking for pthread_create in pthreads - not found
    -- Looking for pthread_create in pthread
    -- Looking for pthread_create in pthread - found
    -- Found Threads: TRUE
    -- Found CUDA: /usr/local/cuda (found version "9.0")
    -- Caffe2: CUDA detected: 9.0
    -- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
    -- Caffe2: CUDA toolkit directory: /usr/local/cuda
    -- Caffe2: Header version is: 9.0
    -- Found CUDNN: /usr/include
    -- Found cuDNN: v7.4.1 (include: /usr/include, library: /usr/lib/x86_64-linux-gnu/libcudnn.so)
    -- Autodetected CUDA architecture(s): 6.1
    -- Added CUDA NVCC flags for: -gencode;arch=compute_61,code=sm_61
    -- Found torch: /home/anpi-cn/workspace_min/libtorch/lib/libtorch.so
    -- Configuring done
    CMake Warning at CMakeLists.txt:6 (add_executable):
      Cannot generate a safe runtime search path for target example-app because
      there is a cycle in the constraint graph:

        dir 0 is [/home/anpi-cn/workspace_min/libtorch/lib]
        dir 1 is [/usr/local/cuda/lib64/stubs]
        dir 2 is [/home/anpi-cn/.conda/envs/surper-resolution-pytorch/lib]
          dir 3 must precede it due to runtime library [libcudart.so.9.0]
        dir 3 is [/usr/local/cuda/lib64]
          dir 2 must precede it due to runtime library [libnvrtc.so.9.0]

      Some of these libraries may not be found correctly.

    -- Generating done
    -- Build files have been written to: /home/anpi-cn/workspace_min/Super-Resolution/ESRGAN/example-app
    (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ make
    Scanning dependencies of target example-app
    [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
    [100%] Linking CXX executable example-app
    [100%] Built target example-app
    (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ ./example-app ../models/RRDB_ESRGAN_x4_000.pt
    ok
4. Executing the Script Module in C++
Having successfully loaded our serialized model in C++, add the following code to the main() function of the C++ application:
    // Create a vector of inputs.
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({64, 3, 3, 3}));

    // Execute the model and turn its output into a tensor.
    auto output = module->forward(inputs).toTensor();

    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
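The first two lines set up the inputs to our model. We create a vector of torch::jit::IValue and add a single input. To create the input tensor we use torch::ones(), the C++ API equivalent of torch.ones. We then call the forward method of the script::Module, passing it the input vector we created. In return we get a new IValue, which we convert to a tensor by calling toTensor().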
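In the last line, we print a slice of the output: up to the first five entries along dimension 1. Because the Python code earlier fed the model the same input (a tensor of ones), we should ideally see the same output. Recompile the application and run it with the same serialized model to try it. Comparing the two shows that the C++ output matches the Python output, so the experiment succeeded.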
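The comparison above was done by reading the two printouts. When the values are copied into a script instead, a tolerance-based check is a little more robust than comparing by eye, since printed floats are rounded. The following is a minimal sketch; `outputs_match` and its tolerances are assumptions made for illustration and are not part of the original experiment.

```python
import math


def outputs_match(a, b, rel_tol=1e-4, abs_tol=1e-5):
    """Return True if two flat lists of floats agree within tolerance."""
    if len(a) != len(b):
        return False
    return all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol)
        for x, y in zip(a, b)
    )


# First three values of the Python output printed earlier in this article.
python_row = [0.9618, 1.0375, 1.0242]

# Placeholder: paste the values printed by the C++ program here.
# It is set to a copy of python_row only so that the sketch runs.
cpp_row = list(python_row)

print(outputs_match(python_row, cpp_row))
```

The tolerances are deliberately loose because the printed tensor shows only four decimal places.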
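References:
https://pytorch.org/tutorials/advanced/cpp_export.html
PyTorch 1.0 official tutorial (Chinese translation): Using the PyTorch C++ Frontend
Loading pretrained weights and running prediction with PyTorch's C++ frontend (libtorch)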