野路子碼農系列(9)利用ONNX加速Pytorch模型推斷


最近在做一個文本多分類的模型,非常常規的BERT+finetune的套路,考慮到運行成本,打算GPU訓練后用CPU做推斷。

在小破本上試了試,發現推斷速度異常感人,尤其是序列長度增加之后,一條4-5秒不是夢。

於是只能尋找加速手段,早先聽過很多人提到過ONNX,但從來沒試過,於是就學習了一下,發現效果還挺不錯的,手法其實也很簡單,就是有幾個小坑。

 

第1步 - 保存模型

首先得從torch中將模型導出成ONNX格式,可以在cross-validation的eval階段進行這一步驟:

def eval_fn(data_loader, model, device):
    '此處省略其他代碼'
    
    onnx_path = 'inference_model.onnx' # 指定保存路徑
    torch.onnx._export(
        model, # BERT fintune model (instance)
        (ids, mask, token_type_ids), # model的輸入參數,裝入tuple
        onnx_path, # 保存路徑
        opset_version=10, # 此處有坑,必須指定≥10,否則會報錯
        do_constant_folding=True,
        input_names=['ids', 'mask', 'token_type_ids'], # model輸入參數的名稱
        output_names=['output'],
        export_params=True,
        dynamic_axes={
            'ids': {0: 'batch_size', 1: 'seq_length'}, # 0, 1分別代表axis 0和axis 1
            'mask': {0: 'batch_size', 1: 'seq_length'},
            'token_type_ids': {0: 'batch_size', 1: 'seq_length'},
            'output': {0: 'batch_size', 1: 'seq_length'}
        } # 用於變長序列(比如dynamic padding)和可能改變batch size的情況
    )
    
    
    return '此處省略返回值'

  

這里需要注意的幾個點:

  •  torch自帶了導出ONNX的方法,直接用就行
  • 你的模型可以有1個輸入參數,也可以有多個,如果有多個,得裝在tuple里
  • 相應的input_names要與你的參數一一對應,放在list里
  • opset_version建議設成10,默認不設的話可能會報錯(ONNX export of Slice with dynamic inputs)
  • 如果你在data loader里設置了collate func來進行dynamic padding的話(不同batch的文本長度可能不一樣),一定要設置dynamic_axes,否則之后加載推斷時會出錯(因為它會要求你推斷時輸入的各個維度與你保存ONNX模型時的輸入緯度完全一致)。

 

第2步 - 加載模型與推斷

接下來是推斷環節,首先別忘了用 pip install onnxpip install onnxruntime 來安裝必需的庫,之后通過以下代碼導入使用:

import onnxruntime as ort

接下來你可以照常寫你的dataset和data loader,但需要注意的是,data loader返回的得是numpy.array,而不是torch.tensor(collate_fn里改改就行),否則報錯伺候。

然后就是導入模型:

import onnxruntime as ort

onnx_model_path = 'inference_model.onnx' 
session = ort.InferenceSession(onnx_model_path)

再把data loader的輸出分別接入對應的三個參數就好了:

session.run(ids, mask, token_type_ids)

 

%%timeit看一下運行時間(CPU):

4條長度為10的文本

torch:4.77s

torch+ONNX:39.7ms

 

4條長度為50的文本

torch:21.2s

torch+ONNX:246ms

差不多快了百倍有余,效果相當不錯啦。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM