When training models we can choose among many different frameworks: some people prefer PyTorch, some prefer TensorFlow, others MXNet, or Caffe, which was popular in the early days of deep learning. Each training framework produces its own kind of model artifact (.pth, .pb, ...), and each artifact needs its own dependency libraries at deployment and inference time; worse, even different versions of the same framework, such as TensorFlow, can differ significantly. To resolve this fragmentation, the LF AI foundation, together with Facebook, Microsoft, and other companies, defined a standard format for machine learning models called ONNX (Open Neural Network Exchange). Models produced by the various frameworks can all be converted into this standard format, after which they can be deployed uniformly with tools such as ONNX Runtime.
This is much like the situation with the JVM:
A Java virtual machine (JVM) is a virtual machine that enables a computer to run Java programs as well as programs written in other languages that are also compiled to Java bytecode. The JVM is detailed by a specification that formally describes what is required in a JVM implementation. Having a specification ensures interoperability of Java programs across different implementations so that program authors using the Java Development Kit (JDK) need not worry about idiosyncrasies of the underlying hardware platform.
In the Java world we have the Java language + .jar packages + the JVM, and other languages such as Scala are also built on the JVM. As long as each language ultimately compiles programs down to a format the JVM can recognize, they can all run on the same cross-platform Java virtual machine. The packages the JVM consumes are binary, however, so their contents are opaque and hard for humans to read directly.
ONNX, by contrast, adopts Google's Protocol Buffers as its format standard, a format in the same data-interchange lineage as XML and JSON, and one whose text form is easy for humans to understand. The ONNX website introduces ONNX as follows:
ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.
The model sources ONNX supports cover essentially every framework we use day to day:

For its file format, ONNX uses Google's Protocol Buffers, the same choice Caffe made.
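This means an .onnx file is just a serialized protobuf message whose fields can be read directly. A minimal sketch, assuming a model file already exists (e.g. the resnet18.onnx exported later in this post):
import onnx

# An .onnx file is a serialized onnx.ModelProto protobuf message.
model = onnx.load('resnet18.onnx')
print(model.ir_version)       # protobuf fields are directly accessible
print(model.graph.input[0])   # prints the protobuf text form, human-readable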

The data types ONNX defines include all the common ones, and they are used to specify the input and output formats of a model.

ONNX also defines many of the nodes we commonly use, such as Conv, ReLU, BN, MaxPool, and so on, about 124 kinds in total, and the set is still growing. When we encounter a node that the built-in operator library lacks, we can also write one ourselves.
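To make the notion of a node concrete, here is a minimal sketch that hand-builds a one-node graph with the official onnx.helper API (the tensor names x/y and the 1x4 shape are arbitrary choices for illustration):
import onnx
from onnx import helper, TensorProto

# A single ReLU node: this is exactly one "node" in an ONNX graph.
node = helper.make_node("Relu", inputs=["x"], outputs=["y"])

# Wrap the node in a graph with typed input/output descriptions.
graph = helper.make_graph(
    nodes=[node],
    name="tiny-relu",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
)

model = helper.make_model(graph)
onnx.checker.check_model(model)  # confirm the hand-built model is well formed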

- With inputs, outputs, and compute nodes in place, PyTorch's forward pass can be traced to record a computation graph from the input image all the way to the output; ONNX simply stores that graph in the standard format. The graph can be visualized with the tool Netron, as shown on the right of the first figure.
- Once a model is saved in the unified ONNX format, a single runtime platform can be used for inference.
PyTorch supports exporting to the ONNX format natively; an example follows:
1. Convert a PyTorch model to ONNX format; it is a single, almost foolproof call: torch.onnx.export(model, input, output_name)
import torch
from torchvision import models

# Load a pretrained ResNet-18 and export it by tracing a dummy input.
net = models.resnet.resnet18(pretrained=True)
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(net, dummy_input, 'resnet18.onnx')
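In practice you often also want to name the graph's inputs and outputs and allow a variable batch size; torch.onnx.export accepts input_names, output_names, and dynamic_axes for this. A sketch (the names 'input', 'output', and 'batch' below are arbitrary choices):
# Optional: name the graph I/O and make the batch dimension dynamic.
torch.onnx.export(
    net,
    dummy_input,
    'resnet18.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
)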
2. Inspect the generated ONNX model
import onnx
# Load the ONNX model
model = onnx.load("resnet18.onnx")
# Check that the IR is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
Output:
The printed graph has three sections: the input, the initializers (the weights), and the operators, which combine the input with the initializer weights much like a forward pass. The final operator returns %191, i.e. the output of the Gemm of the last fully connected layer:
graph torch-jit-export (
%input.1[FLOAT, 1x3x224x224]
) initializers (
%193[FLOAT, 64x3x7x7]
%194[FLOAT, 64]
%196[FLOAT, 64x64x3x3]
%197[FLOAT, 64]
%199[FLOAT, 64x64x3x3]
%200[FLOAT, 64]
%202[FLOAT, 64x64x3x3]
%203[FLOAT, 64]
%205[FLOAT, 64x64x3x3]
%206[FLOAT, 64]
%208[FLOAT, 128x64x3x3]
%209[FLOAT, 128]
%211[FLOAT, 128x128x3x3]
%212[FLOAT, 128]
%214[FLOAT, 128x64x1x1]
%215[FLOAT, 128]
%217[FLOAT, 128x128x3x3]
%218[FLOAT, 128]
%220[FLOAT, 128x128x3x3]
%221[FLOAT, 128]
%223[FLOAT, 256x128x3x3]
%224[FLOAT, 256]
%226[FLOAT, 256x256x3x3]
%227[FLOAT, 256]
%229[FLOAT, 256x128x1x1]
%230[FLOAT, 256]
%232[FLOAT, 256x256x3x3]
%233[FLOAT, 256]
%235[FLOAT, 256x256x3x3]
%236[FLOAT, 256]
%238[FLOAT, 512x256x3x3]
%239[FLOAT, 512]
%241[FLOAT, 512x512x3x3]
%242[FLOAT, 512]
%244[FLOAT, 512x256x1x1]
%245[FLOAT, 512]
%247[FLOAT, 512x512x3x3]
%248[FLOAT, 512]
%250[FLOAT, 512x512x3x3]
%251[FLOAT, 512]
%fc.bias[FLOAT, 1000]
%fc.weight[FLOAT, 1000x512]
) {
%192 = Conv[dilations = [1, 1], group = 1, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]](%input.1, %193, %194)
%125 = Relu(%192)
%126 = MaxPool[kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%125)
%195 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%126, %196, %197)
%129 = Relu(%195)
%198 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%129, %199, %200)
%132 = Add(%198, %126)
%133 = Relu(%132)
%201 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%133, %202, %203)
%136 = Relu(%201)
%204 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%136, %205, %206)
%139 = Add(%204, %133)
%140 = Relu(%139)
%207 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%140, %208, %209)
%143 = Relu(%207)
%210 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%143, %211, %212)
%213 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%140, %214, %215)
%148 = Add(%210, %213)
%149 = Relu(%148)
%216 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%149, %217, %218)
%152 = Relu(%216)
%219 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%152, %220, %221)
%155 = Add(%219, %149)
%156 = Relu(%155)
%222 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%156, %223, %224)
%159 = Relu(%222)
%225 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%159, %226, %227)
%228 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%156, %229, %230)
%164 = Add(%225, %228)
%165 = Relu(%164)
%231 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%165, %232, %233)
%168 = Relu(%231)
%234 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%168, %235, %236)
%171 = Add(%234, %165)
%172 = Relu(%171)
%237 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%172, %238, %239)
%175 = Relu(%237)
%240 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%175, %241, %242)
%243 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%172, %244, %245)
%180 = Add(%240, %243)
%181 = Relu(%180)
%246 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%181, %247, %248)
%184 = Relu(%246)
%249 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%184, %250, %251)
%187 = Add(%249, %181)
%188 = Relu(%187)
%189 = GlobalAveragePool(%188)
%190 = Flatten[axis = 1](%189)
%191 = Gemm[alpha = 1, beta = 1, transB = 1](%190, %fc.weight, %fc.bias)
return %191
}
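Beyond the checker and printable_graph, onnx.shape_inference can annotate every intermediate tensor with its inferred shape, which helps when inspecting or debugging a graph. A minimal sketch:
import onnx
from onnx import shape_inference

model = onnx.load("resnet18.onnx")
# Propagate shapes through the graph; results land in graph.value_info.
inferred = shape_inference.infer_shapes(model)
print(inferred.graph.value_info[0])  # shape of the first intermediate tensor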
3. Visualize the generated ONNX model:
Two tools can visualize ONNX: Netron, and VisualDL developed by Baidu PaddlePaddle.
Here we cover downloading and installing Netron: https://github.com/lutzroeder/Netron. Mac users can simply open the app after installation and graphically pick the path of the ONNX file to open it.
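Netron also ships as a Python package, which can be convenient when working on a remote machine; a minimal sketch (assumes pip install netron):
import netron

# Serves a local web UI and opens the model in the browser.
netron.start('resnet18.onnx')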


4. ONNX Runtime
A runtime that supports ONNX plays a role analogous to the JVM: it takes a model package in the unified ONNX format and runs it, which involves parsing the ONNX model, optimizing it (e.g. fusing conv-bn operations), and executing it.
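In ONNX Runtime these graph optimizations can be controlled through SessionOptions; a minimal sketch that enables all optimizations and dumps the optimized graph to a file for inspection (the output filename is an arbitrary choice):
import onnxruntime as rt

sess_options = rt.SessionOptions()
# Enable all graph optimizations, including operator fusion.
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
# Save the optimized graph so it can be inspected, e.g. in Netron.
sess_options.optimized_model_filepath = 'resnet18_optimized.onnx'
sess = rt.InferenceSession('resnet18.onnx', sess_options)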

Here we introduce the ONNX Runtime developed by Microsoft.
4.1 Installing ONNX Runtime
https://github.com/microsoft/onnxruntime
For CPU inference on macOS you can use:
brew install libomp
pip install onnxruntime
Inference:
import onnxruntime as rt
import numpy as np

data = np.random.randn(1, 3, 224, 224)
sess = rt.InferenceSession('resnet18.onnx')
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
# ONNX Runtime expects float32 input for this model.
pred_onx = sess.run([label_name], {input_name: data.astype(np.float32)})[0]
print(pred_onx)
print(np.argmax(pred_onx))
As you can see, inference done this way no longer needs PyTorch or any of the other assorted framework dependencies, which makes deployment much easier.
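To sanity-check an export, it is common to compare ONNX Runtime's output against the original PyTorch model on the same input; a sketch (the tolerances below are arbitrary but typical):
import numpy as np
import torch
import onnxruntime as rt
from torchvision import models

net = models.resnet.resnet18(pretrained=True).eval()
data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Reference output from the original PyTorch model.
with torch.no_grad():
    torch_out = net(torch.from_numpy(data)).numpy()

# Output from ONNX Runtime on the exported model.
sess = rt.InferenceSession('resnet18.onnx')
onnx_out = sess.run(None, {sess.get_inputs()[0].name: data})[0]

# The two should agree to within floating-point tolerance.
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)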
Two accessible video explanations are recommended:
Everything You Want to Know About ONNX
Microsoft ONNX and ONNX Runtime
