如何修改已有的ONNX模型


簡單來說,我們只需要學習一下把大象如何放進冰箱的就行了:

1、把冰箱門打開

使用onnx的原生接口:

onnx_model = onnx.load(onnx_path) graph = onnx_model.graph

這樣我們就可以將模型load出來,並且到到graph信息。

2、把大象放進去

這一步相對來說選擇就比較多了,比如你可以選擇刪除一些節點,修改一下節點,增加一些節點。

刪除:這個是最容易的,直接一句話 graph.node.remove(xxx_node)

修改:舉個例修改一下input的名稱

for input_node in onnx_model.graph.input:
    if 'input_xxx' == input_node.name:
        print("change input data name")
        input_node.name = 'data'

就是拿到某個屬性或者信息,然后直接修改就行了。

增加:舉個例增加一組圖像預處理操作(減均值,除方差)

這一步稍微復雜一點,我們首先要生成一個node或者tensor,然后將這個node或者tensor加入graph中,然后選擇性的增加一個node來操作剛剛加入graph的node或者tensor。

首先我們生成一個tensor,就是需要減去的均值

sub_const_node = onnx.helper.make_tensor(name='const_sub',
                      data_type=onnx.TensorProto.FLOAT,
                      dims=[1],
                      vals=[-127.5])

然后我們將剛剛生成的tensor插入graph中

graph.initializer.append(sub_const_node)

然后我們再增加一個減均值的node

sub_node = onnx.helper.make_node(
                'Add',
                name='pre_sub',
                inputs=['data', 'const_sub'],
                outputs=['pre_sub'])

然后將node加入graph中

graph.node.insert(0, sub_node)

仿造這樣的流程我們繼續加入除以方差的操作

# 插入mul
mul_const_node = onnx.helper.make_tensor(name='const_mul',
                      data_type=onnx.TensorProto.FLOAT,
                      dims=[1],
                      vals=[1.0 / 127.5])
 
graph.initializer.append(mul_const_node)

sub_node = onnx.helper.make_node(
               'Mul',
               name='pre_mul',
               inputs=['pre_sub', 'const_mul'],
               outputs=['pre_mul'])
graph.node.insert(1, sub_node)

這樣操作之后,我們還需要一步,就是將第一個卷積層的輸入改動一下:

# 第一層卷積的輸入修改
 for id, node in enumerate(graph.node):
     for i, input_node in enumerate(node.input):
         if 'data' == input_node:
             node.input[i] = 'pre_mul'

這樣能我們加入node或者tensor的過程基本就結束了

3、把冰箱門關上

這一步我們就可以簡單的重組一下graph,然后save模型就行了

graph = onnx.helper.make_graph(graph.node, graph.name, graph.input, graph.output, graph.initializer)
info_model = onnx.helper.make_model(graph)
onnx_model = onnx.shape_inference.infer_shapes(info_model)
 
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, onnx_path.replace('nopre', 'fix'))

完事我們就拿到一個修改過的模型了,找一個人臉的模型來示例一下:

紅圈就是增加的節點,當然你也可以增加其它節點。注意按照onnx的op要求實現就行了,給一個參考的路徑:

出處:如何修改已有的ONNX模型 - 知乎 (zhihu.com)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2026 CODEPRJ.COM