Compiling a TensorFlow model with TVM


Work through the official demo, save the converted TVM model, then reload it and run inference.

This was done in a Jupyter notebook, so the code is spread across cells; in another editor or IDE, just stitch the snippets together into one script and run it.

Official tutorial:

https://tvm.apache.org/docs/tutorials/frontend/from_tensorflow.html#sphx-glr-tutorials-frontend-from-tensorflow-py

1. Imports

# tvm, relay
import tvm
from tvm import te
from tvm import relay

# os and numpy
import numpy as np
import os.path

# Tensorflow imports
import tensorflow as tf

try:
    tf_compat_v1 = tf.compat.v1
except ImportError:
    tf_compat_v1 = tf

# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing
from tensorflow.keras.datasets import mnist
from tensorflow.python.platform import gfile

2. Set the download paths and build configuration (CPU target, compiled by LLVM)

repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'

# Test image
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)

######################################################################
# Tutorials
# ---------
# Please refer to docs/frontend/tensorflow.md for more details on various models
# from TensorFlow.

model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)

# Image label map
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)

# Human readable text for labels
label_map = 'imagenet_synset_to_human_label_map.txt'
label_map_url = os.path.join(repo_base, label_map)

# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'llvm'
target_host = 'llvm'
layout = None
ctx = tvm.context(target, 0)  # equivalent to tvm.cpu(0) for the llvm target

3. Download the required resources. If a download fails, fetch the files manually as indicated by the error message and place them in the corresponding cache directory (the sketch after the code below prints those locations).

from tvm.contrib.download import download_testdata  

img_path = download_testdata(image_url, img_name, module='data')  
model_path = download_testdata(model_url, model_name, module=['tf', 'InceptionV1'])  
map_proto_path = download_testdata(map_proto_url, map_proto, module='data')  
label_path = download_testdata(label_map_url, label_map, module='data') 
print(model_path)
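
download_testdata is only a cache helper: it stores each file under a fixed local directory (by default ~/.tvm_test_data/<module>/) and returns the absolute path, so the returned paths tell you exactly where to place manually downloaded files. A minimal sketch:

# Sketch: print the cache locations so manually downloaded files can be
# dropped into the right directories.
import os
for p in (img_path, model_path, map_proto_path, label_path):
    print(os.path.dirname(p), '->', os.path.basename(p))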

 

4. Load the model

with tf_compat_v1.gfile.FastGFile(model_path, 'rb') as f:
    graph_def = tf_compat_v1.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf_compat_v1.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
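
Before moving on, it can help to confirm that the output node passed to AddShapesToGraphDef actually exists in the imported graph; a minimal sketch:

# Sketch: list some node names and check that the expected output node exists.
node_names = [n.name for n in graph_def.node]
print(node_names[:10])
print('softmax' in node_names)  # InceptionV1's output node; a custom model's will differ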

 

5. Prepare the test image and import the graph into Relay

from PIL import Image
image = Image.open(img_path).resize((299, 299))

x = np.array(image)

######################################################################
# Import the graph to Relay
# -------------------------
# Import tensorflow graph definition to relay frontend.
#
# Results:
#   sym: relay expr for given tensorflow protobuf.
#   params: params converted from tensorflow params (tensor protobuf).
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict)

print("Tensorflow protobuf imported to relay frontend.")
print(mod.astext(show_meta_data=False))

You will see some warnings here; they are harmless.
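
For this demo the output node is inferred automatically, but from_tensorflow also accepts an explicit outputs list, which becomes relevant when converting a custom model. A hedged sketch ('my_output_node' is a hypothetical name, not part of this demo):

# Sketch only: 'my_output_node' is a placeholder for a custom model's output op.
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict,
                                             outputs=['my_output_node'])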

 

6. Compile

with relay.build_config(opt_level=3):
    lib = relay.build(mod,
                      target=target,
                      target_host=target_host,
                      params=params)

Again, a pile of warnings; nothing to worry about.
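
On newer TVM releases relay.build_config is deprecated in favor of PassContext; an equivalent sketch (assuming TVM 0.7 or later):

# Equivalent build using the newer PassContext API (assumption: TVM >= 0.7).
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, target_host=target_host, params=params)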

 

7. Run inference

from tvm.contrib import graph_runtime
dtype = 'uint8'
m = graph_runtime.GraphModule(lib["default"](ctx))
# set inputs
m.set_input("DecodeJpeg/contents", tvm.nd.array(x.astype(dtype)))
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), 'float32'))

predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)

# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                    uid_lookup_path=label_path)

# Print top 5 predictions from TVM output.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))

Inference works as expected, so the conversion succeeded.
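
To get a rough idea of inference latency, the graph module exposes time_evaluator on its underlying runtime module; a minimal sketch:

# Sketch: run the module a few times and report the mean latency.
ftimer = m.module.time_evaluator("run", ctx, number=10)
print("mean inference time: %.2f ms" % (ftimer().mean * 1000))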

 

8. Save the compiled module

from tvm.contrib import utils

# Note: tempdir().relpath() with an absolute path simply returns that path,
# so exporting straight to the absolute path would work just as well.
temp = utils.tempdir()
path_lib = temp.relpath("/home/aiteam/tiwang/tvm_code/inceptionV2.1_lib.tar")
lib.export_library(path_lib)

The module is exported successfully.
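
To double-check the export, the archive contents can be listed; a minimal sketch (for an llvm target it normally contains the compiled object files plus metadata):

# Sketch: inspect the exported archive.
import tarfile
with tarfile.open(path_lib) as tar:
    print(tar.getnames())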

 

9. Reload the module and run inference

loaded_lib = tvm.runtime.load_module(path_lib)
input_data = tvm.nd.array(x.astype(dtype))

mm = graph_runtime.GraphModule(loaded_lib["default"](ctx))
# Use the actual graph input name; a mismatched keyword such as run(data=...)
# may be silently ignored by the runtime.
mm.set_input("DecodeJpeg/contents", input_data)
mm.run()
out_deploy = mm.get_output(0, tvm.nd.empty((1, 1008), 'float32'))

predictions = out_deploy.asnumpy()
predictions = np.squeeze(predictions)

# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                    uid_lookup_path=label_path)

# Print top 10 predictions from TVM output.
top_k = predictions.argsort()[-10:][::-1]
for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))

 

The whole pipeline works end to end.
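
A quick sanity check, assuming tvm_output from step 7 and out_deploy from step 9 are both still in scope, is to verify that the reloaded module reproduces the original scores:

# Sketch: the reloaded module should produce the same output as the in-memory one.
np.testing.assert_allclose(tvm_output.asnumpy(), out_deploy.asnumpy(),
                           rtol=1e-5, atol=1e-5)
print("outputs match")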

 

However, converting my own models still runs into plenty of problems. The first is that a demo model taken from tf-serving hits a protobuf decode error at step 4.

Following the fix in

https://stackoverflow.com/questions/61883290/to-load-pb-file-decodeerror-error-parsing-message

I changed the loading code to parse the file as a SavedModel:

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

with tf_compat_v1.gfile.FastGFile(model_dir + model_name, 'rb') as f:
    # The file is a SavedModel protobuf rather than a frozen GraphDef, so parse
    # it as SavedModel and pull the GraphDef out of the first MetaGraph.
    data = compat.as_bytes(f.read())
    saved_model = saved_model_pb2.SavedModel()
    saved_model.ParseFromString(data)
    graph_def = saved_model.meta_graphs[0].graph_def

    #graph_def = tf_compat_v1.GraphDef()
    #graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf_compat_v1.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')

This decodes correctly, but another problem shows up further down:

AssertionError: softmax is not in graph

The assertion comes from AddShapesToGraphDef being asked to preserve an output node named 'softmax', which exists in InceptionV1 but usually not in a custom graph, so the name has to match your model's actual output op. I still need to read the TensorFlow source here to properly understand graph vs. graph_def. A sketch for locating the output node follows.
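
A minimal sketch for listing candidate output nodes (nodes whose output no other node consumes) in a GraphDef, so the right name can be passed instead of 'softmax':

# Sketch: nodes that nothing else consumes are likely the graph outputs.
consumed = set()
for node in graph_def.node:
    for inp in node.input:
        consumed.add(inp.split(':')[0].lstrip('^'))
print([n.name for n in graph_def.node if n.name not in consumed])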

Separately, I wrote a small MNIST model of my own, which loads without the decode error but still fails at the same softmax step. Based on other blog posts, I think the first error comes down to which function was used to save the model: the saved files look like the same format but use a different encoding, which would explain why the Stack Overflow fix

compat.as_bytes(f.read())

explicitly converts the bytes.
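
If the model came from tf.saved_model (as tf-serving exports do), another route worth trying is to load it into a session and freeze it into a plain GraphDef before importing it into TVM. A hedged sketch, where export_dir and 'my_output_node' are placeholders for the actual export directory and output op name:

# Sketch: freeze a SavedModel into a frozen GraphDef so TVM can import it.
# export_dir and 'my_output_node' are placeholders, not values from this post.
with tf_compat_v1.Session() as sess:
    tf_compat_v1.saved_model.loader.load(
        sess, [tf_compat_v1.saved_model.tag_constants.SERVING], export_dir)
    frozen_graph_def = tf_compat_v1.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['my_output_node'])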

