python的易上手和pytorch的動態圖特性,使得pytorch在學術研究中越來越受歡迎,但在生產環境,礙於python的GIL等特性,可能達不到高並發、低延遲的要求,存在需要用c++接口的情況。除了將模型導出為ONNX外,pytorch1.0給出了新的解決方案:pytorch 訓練模型 - 通過torch script中間腳本保存模型 -- C++加載模型。最近工作需要嘗試做了轉換,總結一下步驟和遇到的坑。
用torch script把torch模型轉成c++接口可讀的模型有兩種方式:trace && script. trace比script簡單,但只適合結構固定的網絡模型,即forward中沒有控制流的情況,因為trace只會保存運行時實際走的路徑。如果forward函數中有控制流,需要用script方式實現。
trace顧名思義,就是沿着數據運算的路徑走一遍,官方例子:
import torch def foo(x, y): return 2*x + y traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) |
script稍復雜,主要改三處:
1. Model由之前繼承 nn.Model 改為繼承 torch.jit.ScriptModule
2. forward函數前加 @torch.jit.script_method
3. 其他需要調用的函數前加 @torch.jit.script
踩過的坑&&解決方法:
A. torch script默認函數或方法的參數都是Tensor類型的,如果不是需要說明,不然調用非Tensor參數時會報類型不符的編譯錯誤。
python3可以直接:
| def example_func(param_1: Tensor, param_2: int, param_3: List[int]): |
python2需要用type注釋:
| def example_func(param_1, param_2, param_3): #type: (Tensor, int, List[int]) -> Tensor |
B. model的方法中forward加@torch.jit.script_method, __init__函數不用
C. 前面說過,torch scrip支持的函數是pytorch的子集,意味着有一部分函數不支持,例如: not boolean,pass, List的切片賦值,CPU和GPU切換的value.to( ), 需要想辦法繞過去。看github上討論區說新版好像已經支持not操作了,沒有驗證。
結論:pytorch 1.0目前的預覽版還有比較多優化的空間,至少是在torch script支持的函數集合上,不建議使用,等穩定版發布再看看吧。
原創內容,轉載請注明出處。
參考資料:
https://pytorch.org/docs/master/jit.html
https://pytorch.org/tutorials/beginner/deploy_seq2seq_hybrid_frontend_tutorial.html
