深度學習模型轉換,以pytorch轉tensorflow為例


這里以onnx為中介進行轉換。主要用到

STEP1. 將pytorch 模型轉換成onnx模型

注意這里關鍵是要構造一個模型的輸入輸入,這里假設模型接受兩個輸入。

pmodel = PytorchModel()
dummy_input = (np.zeros((1, 30), dtype=np.float32), np.zeros((1, 2), dtype=np.float32))
torch.onnx.export(pmodel, (torch.as_tensor(dummy_input[0]), torch.as_tensor(dummy_input[1])), "/tmp/xx.onnx",
                  verbose=True, input_names=['input1', 'input2'], output_names=['output1', 'output2'])

參數 input_names表示模型的輸入參數(隨便起名字),output_names表示輸出名字

STEP 2. 將onnx模型轉成tf

這里需要借助onnx_tf這個庫

import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load("/tmp/xx.onnx")  # load onnx model
tf_model = prepare(onnx_model)
tf_model.export_graph("/tmp/xxpb/")  # export the model

STEP 3 使用tensorflow模型

import tensorflow as tf
import io
import numpy as np

model_path = '/tmp/xxpb/'

sess = tf.compat.v1.Session()
metagraph = tf.compat.v1.saved_model.loader.load(sess, [tf.compat.v1.saved_model.tag_constants.SERVING], model_path)
sig = metagraph.signature_def["serving_default"]
input_dict = dict(sig.inputs)
output_dict = dict(sig.outputs)
print(input_dict, output_dict)
output_stochastic_act_label_0 = output_dict["output_0"].name
output_stochastic_act_label_1 = output_dict["output_1"].name

input_state_label = None
initial_state = None
state = None
if "state" in input_dict.keys():
    input_state_label = input_dict["state"].name
    strfile = io.StringIO()
    print(input_dict["state"].tensor_shape, file=strfile)
    lines = strfile.getvalue().split("\n")
    dim_1 = int(lines[1].split(":")[1].strip(" "))
    dim_2 = int(lines[4].split(":")[1].strip(" "))
    initial_state = np.zeros((dim_1, dim_2), dtype=np.float32)
    state = np.zeros((dim_1, dim_2), dtype=np.float32)
input_obs_label_1 = input_dict["input1"].name
input_obs_label_0 = input_dict["input2"].name
input_dict = {input_obs_label_0: np.zeros((1, 2), dtype=np.float32), input_obs_label_1:np.zeros((1, 30), dtype=np.float32)}
out = sess.run((output_stochastic_act_label_0, output_stochastic_act_label_1), feed_dict=input_dict)
print(out)

注意這里的name需要重新設置一遍。






免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM