【PyTorch基礎】將pytorch模型轉換為script模型


 

操作步驟:

1. 將PyTorch模型轉換為Torch腳本;

1)通過torch.jit.trace轉換為torch腳本;

2)通過torch.jit.script轉換為torch腳本;

2. 將腳本模型序列化為文件;

要想序列化模型文件,只需在模塊上調用save函數即可;

3. 在c++中加載腳本模塊;

安裝使用LibTorch;

使用torch::jit::load()函數對該模塊進行反序列化,得到一個torch::jit::script::Module對象。

4. 在c++中執行腳本模塊;

注意,生成序列化和調用反序列化模型的輸入必須要保持一致;

code

# -*- coding: utf-8 -*-
# @Time  : 2021.07.27 16:00
# @Author: xxx
# @Email : 
# @File  : torch2script.py
"""
Transform torch model to Script module.
"""
import torch
from unet import UNet
from config import UNetConfig

cfg = UNetConfig()
model_path = './checkpoints/epoch_500.pth'
# model
model = UNet(cfg)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
# an example input.
example = torch.rand(5, 3, 625, 620)  # NCHW.
# Trace to Torch script.
# Use torch.jit.trace to generate a troch.jit.scriptmodule via tracing.
# 將 PyTorch 模型通過跟蹤轉換為 Torch 腳本,必須將模型的實例以及示例輸入傳遞給torch.jit.trace函數。
# 這將產生一個torch.jit.ScriptModule對象,並將模型評估的軌跡嵌入到模塊的forward方法中.
traced_script_module = torch.jit.trace(model, example)
output = traced_script_module(example)
output1= model(example)
traced_script_module.save('./unet_trace_module.pt')
# print('output:  ', output)
# print('output1: ', output1)
print('traced_script_module graph: \n', traced_script_module.graph)
print('traced_script_module code : \n', traced_script_module.code )

# ERROR!!!!!
# # Script module
# model_script = UNet(cfg)
# sm = torch.jit.script(model_script)
# output2 = sm(example)
#
# # Serialize model.
# sm.save('./unet_script_module.pt')

 注意,執行腳本模型文件進行測試的輸入大小必須和生成腳本模型的輸入大小一致,否則執行的時候會出錯;

error

/home/xxx/lib/python3.8/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/xxx/lib/python3.8/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/xxx/lib/python3.8/site-packages/torch/jit/__init__.py(1109): trace_module
/home/xxx/lib/python3.8/site-packages/torch/jit/__init__.py(953): trace
torch2script.py(25): <module>
RuntimeError: Sizes of tensors must match except in dimension 1. Got 78 and 79 in dimension 3 (The offending index is 1)

Aborted (core dumped)

 5. CUDA相關函數

  std::cout <<"torch::cuda::is_available():" << torch::cuda::is_available() << std::endl;
  std::cout <<"torch::cuda::cudnn_is_available():" << torch::cuda::cudnn_is_available() << std::endl;
  std::cout <<"torch::cuda::device_count():" << torch::cuda::device_count() << std::endl;

6. GPU/CPU模式

torch::DeviceType device_type = at::kCPU; // 定義設備類型
if (torch::cuda::is_available())
    device_type = at::kCUDA;
model.to(device_type);
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({ 1, 3, 224, 224 }).to(device_type));

 device

    torch::DeviceType device_type;
    device_type = torch::kCUDA;
    torch::Device device(device_type);
    torch::jit::script::Module module = torch::jit::load(model_path, device);

 7. 注意,需要對ScriptModule的結果進行驗證和評估,使其與常規 PyTorch 模塊的推斷結果相同;

    注意,使用no_grad()進行驗證評估;

with.no_grad():

 8. 在c++中加載torchscript模型的時候,發現輸入尺寸不必和torchscript模型的尺寸一致;

 

參考

1. 在 C++ 中加載 TorchScript 模型

2. 基於C++的PyTorch模型部署

3. torch.jit.trace

4. torch.jit.script

5. 使用C++調用並部署pytorch模型

6. libtorch c++部署-使用GPU


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM