最近在做一個文本多分類的模型,非常常規的BERT+finetune的套路,考慮到運行成本,打算GPU訓練后用CPU做推斷。
在小破本上試了試,發現推斷速度異常感人,尤其是序列長度增加之后,一條4-5秒不是夢。
於是只能尋找加速手段,早先聽過很多人提到過ONNX,但從來沒試過,於是就學習了一下,發現效果還挺不錯的,手法其實也很簡單,就是有幾個小坑。
第1步 - 保存模型
首先得從torch中將模型導出成ONNX格式,可以在cross-validation的eval階段進行這一步驟:
def eval_fn(data_loader, model, device): '此處省略其他代碼' onnx_path = 'inference_model.onnx' # 指定保存路徑 torch.onnx._export( model, # BERT fintune model (instance) (ids, mask, token_type_ids), # model的輸入參數,裝入tuple onnx_path, # 保存路徑 opset_version=10, # 此處有坑,必須指定≥10,否則會報錯 do_constant_folding=True, input_names=['ids', 'mask', 'token_type_ids'], # model輸入參數的名稱 output_names=['output'], export_params=True, dynamic_axes={ 'ids': {0: 'batch_size', 1: 'seq_length'}, # 0, 1分別代表axis 0和axis 1 'mask': {0: 'batch_size', 1: 'seq_length'}, 'token_type_ids': {0: 'batch_size', 1: 'seq_length'}, 'output': {0: 'batch_size', 1: 'seq_length'} } # 用於變長序列(比如dynamic padding)和可能改變batch size的情況 ) return '此處省略返回值'
這里需要注意的幾個點:
- torch自帶了導出ONNX的方法,直接用就行
- 你的模型可以有1個輸入參數,也可以有多個,如果有多個,得裝在tuple里
- 相應的input_names要與你的參數一一對應,放在list里
- opset_version建議設成10,默認不設的話可能會報錯(ONNX export of Slice with dynamic inputs)
- 如果你在data loader里設置了collate func來進行dynamic padding的話(不同batch的文本長度可能不一樣),一定要設置dynamic_axes,否則之后加載推斷時會出錯(因為它會要求你推斷時輸入的各個維度與你保存ONNX模型時的輸入緯度完全一致)。
第2步 - 加載模型與推斷
接下來是推斷環節,首先別忘了用 pip install onnx 和 pip install onnxruntime 來安裝必需的庫,之后通過以下代碼導入使用:
import onnxruntime as ort
接下來你可以照常寫你的dataset和data loader,但需要注意的是,data loader返回的得是numpy.array,而不是torch.tensor(collate_fn里改改就行),否則報錯伺候。
然后就是導入模型:
import onnxruntime as ort onnx_model_path = 'inference_model.onnx' session = ort.InferenceSession(onnx_model_path)
再把data loader的輸出分別接入對應的三個參數就好了:
session.run(ids, mask, token_type_ids)
用%%timeit看一下運行時間(CPU):
4條長度為10的文本
torch:4.77s
torch+ONNX:39.7ms
4條長度為50的文本
torch:21.2s
torch+ONNX:246ms
差不多快了百倍有余,效果相當不錯啦。