兩種從 TensorFlow 的 checkpoint生成 frozenpb 的方法


1. 從 ckpt-.data,ckpt-.index 和 .meta 生成 frozenpb

import os
import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路徑
    :return:
    '''
    # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
    output_node_names = "outputs"
    saver = tf.train.import_meta_graph(os.path.join(os.path.split(input_checkpoint)[0], 'graph.meta'), clear_devices=True)
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #恢復圖並得到數據
        output_graph_def = graph_util.convert_variables_to_constants(  
            # 模型持久化,將變量值固定
            sess=sess,
            input_graph_def=sess.graph_def,# 等於:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多個輸出節點,以逗號隔開
 
        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化輸出
        print("%d ops in the final graph." % len(output_graph_def.node)) 
        #得到當前圖有幾個操作節點

if __name__ == "__main__":
    # 輸入ckpt模型路徑
    input_checkpoint='ckpt_path/ckpt-10000'
    # 輸出pb模型的路徑
    out_pb_path="some_path/frozen_model.pb"
    # 調用freeze_graph將ckpt轉為pb
    freeze_graph(input_checkpoint,out_pb_path)

2. 從網絡代碼和 ckpt-.data 文件生成 frozenpb

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph

import network  # 導入網絡結構

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 設置GPU
model_path = "ckpt_path/ckpt-10000"

def main():
    tf.reset_default_graph()
    input_node = tf.placeholder(
        tf.float32, shape=(None,112, 96, 3)
    ) 
    input_node = tf.identity(input_node,name="inputs") # 設置輸入節點的名字,這里可以自定義名稱
    flow = network(input_node)
    flow = tf.identity(flow, name="outs") # 設置輸出類型以及輸出的接口名字,為了之后的調用pb的時候使用
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_path)
        # 保存圖
        tf.train.write_graph(sess.graph_def, "logdir/", "graph.pb")
        # 把圖和參數結構一起
        freeze_graph.freeze_graph(
            "logdir/graph.pb", # 上面保存的圖結構 graph.pb
            "",
            False,
            model_path,
            "outs",
            "save/restore_all", # 默認恢復所有
            "save/Const:0", # 默認常量
            "some_path/frozen.pb", # 保存frozen.pb
            False,
            "",
        )
    print("done")


if __name__ == "__main__":
    main()

3. 打印 網絡中節點的名字

import tensorflow as tf


if __name__ == "__main__":
    checkpoint_path = '../model_fintune/ckpt-1400'  
    reader = tf.train.NewCheckpointReader(checkpoint_path)  
    var_to_shape_map = reader.get_variable_to_shape_map()  
    
    for key in var_to_shape_map:  
        print("tensor name: ", key)  
        # print(reader.get_tensor(key))

或者通過

import tensorflow as tf

def printTensors(pb_file):

    # read pb into graph_def
    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import graph_def
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    # print operations
    for op in graph.get_operations():
        print(op.name)

printTensors("path-to-my-pbfile.pb")

4. 兩種方法對比

如果是自己的代碼訓練的模型,有網絡結構,有 ckpt 文件,最好是使用第二種方法,使用起來很靈活,可以進行各種自定義,比如修改輸入輸出的節點名字,網絡有多個路徑的時候可以自定義輸出路徑。第一種方法,應該也能達到第二種方法的效果,因為它們本來就是等價的,可能會有些麻煩。第一種方法的好處就是快,不要去翻那些雜糅在一起的網絡結構。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM