近期由於業務需要,需要將訓練好的模型轉為ONNX格式,為此頗費了一番功夫,在此總結一下吧。。
1、ONNX是一種神經網絡模型保存的中間格式,支持多種格式的模型轉為ONNX,也支持使用ONNX導入多種格式的模型,具體見https://github.com/onnx/tutorials;目前其實ONNX對於模型的支持還不是太好,主要表現在一些op還不能夠支持;
2、在PyTorch下要將模型保存成ONNX格式需要使用torch.onnx.export()函數,使用該函數的時候需要傳入下面參數:
--model:待保存的model,也就是你在程序中已經訓練好或者初始化好的模型
--input_shape:指定輸入數據的大小,也就是輸入數據的形狀,是一個包含輸入形狀元組的列表;
--name:模型的名稱,即模型的保存路徑;
--verbrose:True或者False,用來指定輸出模型時是否將模型的結構打印出來;
--input_names:輸入數據節點的名稱,數據類型為包含字符串的列表;一般將這個名稱設為['data'];
--output_names:輸出數據節點的名稱,類型與輸入數據的節點名稱相同;
在成功導出模型后,可以使用ONNX再對模型進行檢查:
import onnx # Load the ONNX model
model = onnx.load("alexnet.onnx") # Check that the IR is well formed
onnx.checker.check_model(model) # Print a human readable representation of the graph
onnx.helper.printable_graph(model.graph)
目前PyTorch還不支持導入ONNX格式的模型。
3、使用MXNET導出模型為ONNX時,參考地址:https://cwiki.apache.org/confluence/display/MXNET/ONNX,http://mxnet.incubator.apache.org/versions/master/tutorials/onnx/export_mxnet_to_onnx.html。MXNet模型的保存格式為.json文件+.params文件,.json文件里保存的是模型的結構,.params文件中保存的是模型的參數。使用onnx_mxnet.export_export_model()方法就可以實現將模型從mxnet轉為ONNX格式,該方法需要傳入的參數為:
--sym:.json文件,也就是保存了網絡結構的文件
--params:參數文件
--input_shape:輸入數據的形狀,是一個包含形狀元組的列表
--input_type:輸入數據的類型;
--模型的保存路徑
4、從MXNet導入ONNX格式模型:需要使用mxnet.contrib.onnx.onnx2mx.import_model.
import_model
(model_file),這里返回的是sym, arg_arams,aux_params,也就是網絡結構symbol對象,保存參數的字典, 再將其轉為MXNet的module對象(使用mxnet.module.Module()),即可將模型恢復到mxnet框架下可執行的模型。
最后,好久沒有記錄日常學習積累的東西了,趁着失眠開個好頭吧,晚安。。。