1. 從 ckpt-.data,ckpt-.index 和 .meta 生成 frozenpb
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路徑
:return:
'''
# 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
output_node_names = "outputs"
saver = tf.train.import_meta_graph(os.path.join(os.path.split(input_checkpoint)[0], 'graph.meta'), clear_devices=True)
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢復圖並得到數據
output_graph_def = graph_util.convert_variables_to_constants(
# 模型持久化,將變量值固定
sess=sess,
input_graph_def=sess.graph_def,# 等於:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多個輸出節點,以逗號隔開
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化輸出
print("%d ops in the final graph." % len(output_graph_def.node))
#得到當前圖有幾個操作節點
if __name__ == "__main__":
# 輸入ckpt模型路徑
input_checkpoint='ckpt_path/ckpt-10000'
# 輸出pb模型的路徑
out_pb_path="some_path/frozen_model.pb"
# 調用freeze_graph將ckpt轉為pb
freeze_graph(input_checkpoint,out_pb_path)
2. 從網絡代碼和 ckpt-.data 文件生成 frozenpb
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
import network # 導入網絡結構
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 設置GPU
model_path = "ckpt_path/ckpt-10000"
def main():
tf.reset_default_graph()
input_node = tf.placeholder(
tf.float32, shape=(None,112, 96, 3)
)
input_node = tf.identity(input_node,name="inputs") # 設置輸入節點的名字,這里可以自定義名稱
flow = network(input_node)
flow = tf.identity(flow, name="outs") # 設置輸出類型以及輸出的接口名字,為了之后的調用pb的時候使用
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, model_path)
# 保存圖
tf.train.write_graph(sess.graph_def, "logdir/", "graph.pb")
# 把圖和參數結構一起
freeze_graph.freeze_graph(
"logdir/graph.pb", # 上面保存的圖結構 graph.pb
"",
False,
model_path,
"outs",
"save/restore_all", # 默認恢復所有
"save/Const:0", # 默認常量
"some_path/frozen.pb", # 保存frozen.pb
False,
"",
)
print("done")
if __name__ == "__main__":
main()
3. 打印 網絡中節點的名字
import tensorflow as tf
if __name__ == "__main__":
checkpoint_path = '../model_fintune/ckpt-1400'
reader = tf.train.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor name: ", key)
# print(reader.get_tensor(key))
或者通過
import tensorflow as tf
def printTensors(pb_file):
# read pb into graph_def
with tf.gfile.GFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# import graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
# print operations
for op in graph.get_operations():
print(op.name)
printTensors("path-to-my-pbfile.pb")
4. 兩種方法對比
如果是自己的代碼訓練的模型,有網絡結構,有 ckpt 文件,最好是使用第二種方法,使用起來很靈活,可以進行各種自定義,比如修改輸入輸出的節點名字,網絡有多個路徑的時候可以自定義輸出路徑。第一種方法,應該也能達到第二種方法的效果,因為它們本來就是等價的,可能會有些麻煩。第一種方法的好處就是快,不要去翻那些雜糅在一起的網絡結構。