Converting between ckpt, pb, and tflite
1. ckpt, pb, and tflite files and their characteristics
ckpt model files
ckpt is TensorFlow's default format for saving and loading models. It consists of four parts:
- checkpoint
- model.ckpt.meta
- model.ckpt.index
- model.ckpt.data*
These four files separate the graph structure from the weight data:
checkpoint:
records the list of checkpoint files in the model directory
*ckpt.meta:
stores the structure of the TensorFlow computation graph
*ckpt.index:
stores the current parameter names
*ckpt.data:
stores the current parameter values
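The index/data split is easy to see by reading a checkpoint back. A minimal sketch (TF 1.x), assuming the checkpoint was saved under a hypothetical directory model_dir:

import tensorflow as tf

ckpt_path = tf.train.latest_checkpoint('model_dir')  # hypothetical save directory
reader = tf.train.NewCheckpointReader(ckpt_path)
# The .index/.data files hold the parameter names and values;
# list every stored variable together with its shape.
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)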
pb model files
A pb model is the serialized form of a graph_def with the parameters frozen into constants, so it can only be used for forward inference. (Even so, the model structure is easy to recover from it, which makes re-implementing the network much simpler.)
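For example, the structure can be recovered simply by parsing the GraphDef back and listing its nodes. A minimal sketch, assuming a hypothetical frozen_model.pb:

import tensorflow as tf

with tf.gfile.GFile('frozen_model.pb', 'rb') as f:  # hypothetical pb path
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# Every node of the network is visible: op type and name.
for node in graph_def.node:
    print(node.op, node.name)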
tflite files
tf-lite is a deployment format optimized for mobile platforms: it re-implements the core operators for mobile, exposes interfaces for hardware acceleration, and ships with a new, optimized interpreter.
2. Saving and restoring models
Saving and restoring ckpt models
# Restore parameters
saver_restore = tf.train.Saver([var for var in tf.trainable_variables()])
ckpt = tf.train.get_checkpoint_state(model_dir)  # model_dir: directory containing the checkpoint files
saver_restore.restore(sess, ckpt.model_checkpoint_path)
# Save parameters
saver = tf.train.Saver(max_to_keep=10)
saver.save(sess, "model.ckpt")
Loading a pb model
Nodes are looked up by tensor_name with get_tensor_by_name():
# Read the file into a graph_def
with tf.gfile.GFile(pb_path, 'rb') as fgraph:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fgraph.read())
    # print(graph_def)
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')  # load graph_def into the default graph
    # Look up tensors with get_tensor_by_name
    input_tensor = graph.get_tensor_by_name('VIDEOSR/Slice:0')
    output_tensor = graph.get_tensor_by_name('%s:0' % out_node_name)
    # Run the graph with sess.run
    with tf.Session(graph=graph) as sess:
        image_out = sess.run(output_tensor, feed_dict={input_tensor: image_in})
        ...
Loading a tf-lite model
Tensors are accessed by index with set_tensor() and get_tensor():
def run_example_single(model_path, input_image, feature2, feature1):
    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=model_path)  # "model/save/converted_model.tflite"
    interpreter.allocate_tensors()
    # Get input/output info
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)
    # Input index
    index_inImg = input_details[0]['index']
    # Output index
    index_outImg = output_details[0]['index']
    # Set inputs
    interpreter.set_tensor(index_inImg, input_image)
    # Invoke
    interpreter.invoke()
    # Get results
    outImg = interpreter.get_tensor(index_outImg)
    return outImg
3. Converting between ckpt, pb, and tf-lite
Converting ckpt to pb
Converting ckpt to pb persists the model with its parameters frozen; the result is generally used only for forward inference. The official code can be used as a reference.
Workflow:
- Load the ckpt model
- Write out the graph with tf.train.write_graph()
- Freeze the model parameters into the graph with freeze_graph.freeze_graph()
import tensorflow as tf
import os
import slim.nets.mobilenet_v1 as mobilenet_v1
import tensorflow.contrib.slim as slim
from tensorflow.python.tools import freeze_graph

def export_eval_pbtxt(MODEL_SAVE_PATH):
    """Export eval.pbtxt and freeze it into frozen_model.pb."""
    with tf.Graph().as_default() as g:
        images = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='input')
        # is_training=False puts the BN layers into inference mode
        with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False, regularize_depthwise=True)):
            _, _ = mobilenet_v1.mobilenet_v1(inputs=images, is_training=False, depth_multiplier=1.0, num_classes=7)
        saver = tf.train.Saver(max_to_keep=5)
        pb_dir = os.path.join(MODEL_SAVE_PATH, 'pb_model')
        if not os.path.exists(pb_dir):
            os.makedirs(pb_dir)
        graph_file = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'mobilenet_v1_eval.pbtxt')
        checkpoint = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        frozen_model = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'frozen_model.pb')
        with tf.Session() as sess:
            if checkpoint and checkpoint.model_checkpoint_path:
                try:
                    saver.restore(sess, checkpoint.model_checkpoint_path)
                    print("Successfully loaded:", checkpoint.model_checkpoint_path)
                except:
                    print("Error on loading old network weights")
            else:
                print("Could not find old network weights")
            # Write the graph structure out as a text pbtxt
            with open(graph_file, 'w') as f:
                f.write(str(g.as_graph_def()))
            freeze_graph.freeze_graph(graph_file,                          # input_graph: the pbtxt written above
                                      '',                                  # input_saver
                                      False,                               # input_binary: pbtxt is text, not binary
                                      checkpoint.model_checkpoint_path,    # input_checkpoint
                                      "MobilenetV1/Predictions/Softmax",   # output_node_names
                                      'save/restore_all',                  # restore_op_name
                                      'save/Const:0',                      # filename_tensor_name
                                      frozen_model,                        # output_graph
                                      True,                                # clear_devices
                                      "")                                  # initializer_nodes
Converting a pb model to a tflite model
- Load the pb model with tf.lite.TFLiteConverter.from_frozen_graph()
- Convert the model with converter.convert()
- Write the converted result to a file
def pb_to_tflite(input_name, output_name):
    # MODEL_SAVE_PATH is assumed to be defined as in the export code above
    graph_def_file = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'frozen_model.pb')
    input_arrays = [input_name]
    output_arrays = [output_name]
    converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, output_arrays)
    tflite_model = converter.convert()
    tflite_file = os.path.join(MODEL_SAVE_PATH, 'tflite_model', 'converted_model.tflite')
    open(tflite_file, "wb").write(tflite_model)
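A minimal usage sketch, assuming MODEL_SAVE_PATH is defined as above and that the node names match the frozen graph exported earlier (the 'input' placeholder and 'MobilenetV1/Predictions/Softmax'):

pb_to_tflite('input', 'MobilenetV1/Predictions/Softmax')

# Quick sanity check: load the converted model back and print its I/O details.
interpreter = tf.lite.Interpreter(
    model_path=os.path.join(MODEL_SAVE_PATH, 'tflite_model', 'converted_model.tflite'))
interpreter.allocate_tensors()
print(interpreter.get_input_details())
print(interpreter.get_output_details())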