TensorFlow拾遗(二) pb文件的生成


 


(一)获取ckpt各节点名称

方式一

方式二

(二)将ckpt转化为pb文件

附录:ckpt文件形式    


   将模型移植到诸如Android,FPGA等移动端时,需要模型的pb文件。深度学习框架会将模型权重保存为自身的格式,如.ckpt(tensorflow) .h5(tf.keras/Keras) .pt(pytorch)。此时,便需要对权重文件的格式进行转换。

(一)获取ckpt各节点名称

  在进行ckpt权重文件转化为pb文件时,需要获取到模型的输出节点名称,便于从相应的op中获得op的输出。

方式一

  执行下列脚本,可以获得模型各节点变量的txt文件。需要将checkpoint_path更改为自己的模型路径,本篇使用的是SiamMask算法的模型。

 1 # -*- coding: utf-8 -*-
 2 # @Time    : 2020/7/14 下午3:11
 3 # @Author  : monologuesmw   
 4 # @Email   : monologuesmw@163.com
 5 # @File    : node_name.py
 6 # @Software: PyCharm
 7 from tensorflow.python import pywrap_tensorflow
 8 import os
 9 def get_node_name(checkpoint_path):
10     reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
11     var_to_shape_map=reader.get_variable_to_shape_map()
12     node_name = []
13     for key in var_to_shape_map:
14         node_name.append(key)
15     return node_name
16 
17 
18 def txt_save(filename, data):
19     file = open(filename, 'a')
20     for i in data:
21         s = str(i) + '\n'
22         file.write(s)
23     file.close()
24 
25 
26 if __name__ == '__main__':
27     checkpoint_path = 'ckpt/model.ckpt'
28     file_name = 'node_name.txt'
29     node_name = get_node_name(checkpoint_path)
30     txt_save(file_name, node_name)

  对这种方式获得的节点名称感兴趣的可以猛戳这里。不过,从名称上观察,这种方式获得的名称是乱序的,不便于查找需要的输出节点。

  这种方式获得的节点名称如下所示

1 Score/Head/conv2d_1/bias
2 BBox/Head/conv2d_1/bias

方式二

  还可以使用更简洁的方式,获取按顺序排列的各种op

 1 import tensorflow as tf
 2 def get_op(input_checkpoint):
 3     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 4     graph = tf.get_default_graph()
 5     txt_save_op(graph.get_operations(), "op_name")   # get node name
 6 
 7         
 8 def txt_save_op(data, output_file):
 9     file = open(output_file, 'a')
10     for op in data:
11         s = str(op.name) + '\n'
12         file.write(s)
13     file.close()
14 
15 if __name__ == '__main__':
16     checkpoint_path = 'ckpt/model.ckpt'
17     get_op(checkpoint_path)

  这种方式会获得tensorflow的所有op。由于op较多,感兴趣的可以猛戳这里

  可以获得两个输出节点的名称:

1 BBox/Head/conv2d_1/BiasAdd
2 Score/Head/conv2d_1/BiasAdd

方式三  

        另外,通过网络结构变量的方式也可以获取到网络输出节点的名称,详情猛戳这里

(二)将ckpt转化为pb文件

  基于获得的网络输出节点的名称,便可以将ckpt转换为pb。

  核心代码集中在freeze_graph()函数中。

  如下述所示,值得注意的是代码第15行,多个输出节点需要用逗号隔开,并且逗号两侧不能有空格。

 1 import tensorflow as tf
 2 from tensorflow import graph_util
 3 import os
 4 
 5 def freeze_graph(input_checkpoint, output_graph):
 6     '''
 7     :param input_checkpoint:
 8     :param output_graph: PB模型保存路径
 9     :return:
10     '''
11     # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
12     # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
13 
14     # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
15     output_node_names = "BBox/Head/conv2d_1/BiasAdd,Score/Head/conv2d_1/BiasAdd"   # character , would't have space
16     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
17 
18     # graph = tf.get_default_graph()
19     # txt_save_op(graph.get_operations(), "op_name_1")   # get node name
20 
21     with tf.Session() as sess:
22         saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
23         output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
24             sess=sess,
25             input_graph_def=sess.graph_def,  # 等于:sess.graph_def
26             output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开
27 
28         with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
29             f.write(output_graph_def.SerializeToString())  # 序列化输出
30 
31 
32 def txt_save(data, output_file):
33     file = open(output_file, 'a')
34     for i in data:
35         s = str(i) + '\n'
36         file.write(s)
37     file.close()
38 
39 
40 def txt_save_op(data, output_file):
41     file = open(output_file, 'a')
42     for op in data:
43         # s1 =str(op.name)
44         # s2 = str(op.values)
45         # s = s1+s2+'\n'
46         s = str(op.name) + '\n'
47         file.write(s)
48     file.close()
49 
50 # 获取网络结构变量  可省略
51 def network_param(input_checkpoint, output_file=None):
52     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
53     with tf.Session() as sess:
54         saver.restore(sess, input_checkpoint)
55         variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
56         txt_save(variables, output_file)
57 
58 
59 if __name__ == '__main__':
60     checkpoint_path = 'ckpt/model.ckpt'
61     output_path = 'pb/frozen_moedl.pb'
62     a = os.path.split(output_path)[0]
63     if not os.path.exists(a):
64         os.makedirs(a)
65     # 获取网络结构与变量  可省略
66     #output_file = 'network_param.txt'
67     #if not os.path.exists(output_file):
68     #    network_param(checkpoint_path, output_file)
69         
70     freeze_graph(checkpoint_path, output_path)

  其中,注释掉的部分代码是打印网络结构与变量的,也可以获得网络输出节点名称,详情猛戳这里

附录:ckpt文件形式

 

 

参考链接:https://blog.csdn.net/guyuealian/article/details/82218092#%C2%A0%20%C2%A0%20%C2%A01%E3%80%81%E8%AE%AD%E7%BB%83%E6%96%B9%E6%B3%95

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM