PyTorch/TensorFlow自定義OP導出ONNX


PyTorch

根據PyTorch的官方文檔，自定義算子需要用torch.autograd.Function封裝；為了能夠導出ONNX，還需要添加一個名為symbolic的靜態方法，用來描述該算子在ONNX圖中的表示：

class relu5_func(Function):
    """Autograd Function wrapping the ``relu5_cuda.relu5`` extension op.

    ``forward`` delegates to the compiled extension, and ``symbolic`` is the
    hook ``torch.onnx.export`` uses to emit this call as an ONNX node.
    No ``backward`` is defined, so gradients cannot flow through this op.
    """

    @staticmethod
    def forward(ctx, input):
        # All of the computation happens inside the compiled extension.
        output = relu5_cuda.relu5(input)
        return output

    @staticmethod
    def symbolic(g, *inputs):
        # The first argument ("Relu5") is the op name written into the ONNX
        # graph; the node is fed by the first input only.
        # Keyword attributes carry a type suffix: ``myattr_f`` becomes a
        # float attribute named "myattr" (the attribute name is arbitrary).
        # NOTE(review): an op name without a "domain::" prefix lands in the
        # default ONNX domain, which the ONNX checker may reject on newer
        # PyTorch versions; a namespaced name such as "custom::Relu5" may be
        # required there -- verify against the PyTorch version in use.
        first = inputs[0]
        return g.op("Relu5", first, myattr_f=1.0)


relu5 = relu5_func.apply

定義好後，用以下代碼測試：

import torch
import torch.nn as nn
import relu5_cuda
import onnx
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import netron

class TinyNet(nn.Module):
    """Minimal network used to check that the custom relu5 op exports to ONNX.

    Pipeline: 3 -> 1 channel 3x3 convolution (padding 1), ReLU, flatten the
    whole tensor to 1-D, then the custom ``relu5`` op.
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv1, then relu1) is kept so state_dict keys
        # and the exported graph are unchanged.
        self.conv1 = nn.Conv2d(3, 1, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        activated = self.relu1(self.conv1(x))
        # view(-1) flattens every dimension, batch included, before relu5.
        return relu5(activated.view(-1))

# Build the network on the GPU (relu5 calls the relu5_cuda extension, which
# presumably only runs on CUDA tensors -- confirm), export it with an all-ones
# 2x3x12x12 dummy input, then print and visualise the resulting ONNX model.
net = TinyNet().to('cuda')
ipt = torch.ones(2, 3, 12, 12, device='cuda')
torch.onnx.export(net, (ipt,), 'tinynet.onnx')
exported_model = onnx.load('tinynet.onnx')
print(exported_model)
netron.start('tinynet.onnx')

TensorFlow

首先構建一個簡單的網絡，將變量凍結為常量後導出為pb文件：

import tensorflow as tf 
from tensorflow.python.framework import graph_util

# Build a tiny two-layer convolutional graph in the TF1 default graph.
# tf.nn.conv2d filters are laid out [height, width, in_channels, out_channels]:
# conv1 maps 2 -> 3 channels and conv2 maps 3 -> 1 channel, both with 3x3 kernels.
# The weights are tf.Variables here; they are frozen into constants before export.
conv1_w = tf.Variable(tf.random_normal([3, 3, 2, 3]))
conv1_b = tf.Variable(tf.random_normal([3]))
conv2_w = tf.Variable(tf.random_normal([3, 3, 3, 1]))
conv2_b = tf.Variable(tf.random_normal([1]))
# NHWC input (conv2d's default data_format). Naming the placeholder "input"
# gives it the tensor name "input:0", which the ONNX converter refers to.
xs = tf.placeholder(tf.float32, shape=[1, 12, 12, 2], name="input")
# Stride 1 with 'SAME' padding keeps the 12x12 spatial size at each layer.
conv1 = tf.nn.conv2d(xs, conv1_w, strides=[1,1,1,1], padding='SAME') + conv1_b
conv2 = tf.nn.conv2d(conv1, conv2_w, strides=[1,1,1,1], padding='SAME') + conv2_b
# A named identity op gives the graph a stable output node ("output" / "output:0").
tf.identity(conv2, name='output')

# netron is not imported elsewhere in this snippet but is used at the end.
import netron

init = tf.global_variables_initializer()
with tf.Session() as sess:
    # Variables must hold values before they can be frozen into constants.
    sess.run(init)
    # Keep only the subgraph needed to compute the 'output' node and replace
    # its Variables with Const nodes holding their current values, so the
    # .pb file is self-contained.
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['output'])
    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement.
    with tf.gfile.GFile('tfmodel.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
# Open the frozen graph in the Netron viewer.
netron.start('tfmodel.pb')

轉換前需要先安裝tf2onnx：

# Install the TensorFlow-to-ONNX converter for Python 3.
pip3 install tf2onnx

以下參數中的張量名必須寫成「節點名:輸出索引」的形式（例如output:0）；節點名要與pb文件中的節點名一致，本例中輸入placeholder的名字是input，所以對應的張量名是input:0。

# Tensor names use the "<node_name>:<output_index>" form and must match the
# frozen graph: the placeholder is named "input" and the identity op "output".
# --opset 11 matches the opset used by the Python conversion path.
python3 -m tf2onnx.convert \
--input tfmodel.pb \
--inputs input:0 \
--output tfmodel.onnx \
--outputs output:0 \
--opset 11

或者直接使用Python代碼完成凍結和轉換：

import tensorflow as tf 
import tf2onnx
from tf2onnx import loader

# graph: a tiny two-layer convolutional network built in the TF1 default graph.
# tf.nn.conv2d filters are laid out [height, width, in_channels, out_channels]:
# conv1 maps 2 -> 3 channels and conv2 maps 3 -> 1 channel, both 3x3.
conv1_w = tf.Variable(tf.random_normal([3, 3, 2, 3]))
conv1_b = tf.Variable(tf.random_normal([3]))
conv2_w = tf.Variable(tf.random_normal([3, 3, 3, 1]))
conv2_b = tf.Variable(tf.random_normal([1]))
# NHWC input (conv2d's default data_format); the name gives tensor "input:0".
xs = tf.placeholder(tf.float32, shape=[1, 12, 12, 2], name="input")
# Stride 1 with 'SAME' padding keeps the 12x12 spatial size.
conv1 = tf.nn.conv2d(xs, conv1_w, strides=[1,1,1,1], padding='SAME') + conv1_b
conv2 = tf.nn.conv2d(conv1, conv2_w, strides=[1,1,1,1], padding='SAME') + conv2_b
# Named identity op so the converter can address the output as "output:0".
tf.identity(conv2, name='output')
# get output_graph_def: initialise every variable, then freeze the graph so the
# weights needed for the "output:0" tensor are baked in as constants.
init = tf.global_variables_initializer()
with tf.Session() as session:
    session.run(init)
    output_graph_def = loader.freeze_session(
        session,
        output_names=["output:0"],
    )
# to onnx: load the frozen GraphDef into a fresh graph, run the tf2onnx
# converter on it (targeting ONNX opset 11), and write the model to disk.
tf.reset_default_graph()
with tf.Graph().as_default() as tf_graph:
    # name='' imports the nodes without a name prefix, so the tensor names
    # "input:0" / "output:0" below still resolve.
    tf.import_graph_def(output_graph_def, name='')
    onnx_graph = tf2onnx.tfonnx.process_tf_graph(
        tf_graph,
        input_names=["input:0"],
        output_names=["output:0"],
        opset=11,
    )
    model_proto = onnx_graph.make_model("test")

serialized_model = model_proto.SerializeToString()
with open("tfmodel.onnx", "wb") as f:
    f.write(serialized_model)
# show: print the converted ONNX model and open it in the Netron viewer.
import onnx
import netron
converted_model = onnx.load('tfmodel.onnx')
print(converted_model)
netron.start('tfmodel.onnx')


免責聲明!

本站轉載的文章為個人學習借鑑使用，本站對版權不負任何法律責任。如果侵犯了您的隱私權益，請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM