将模型移植到诸如Android,FPGA等移动端时,需要模型的pb文件。深度学习框架会将模型权重保存为自身的格式,如.ckpt(tensorflow) .h5(tf.keras/Keras) .pt(pytorch)。此时,便需要对权重文件的格式进行转换。
(一)获取ckpt各节点名称
在进行ckpt权重文件转化为pb文件时,需要获取到模型的输出节点名称,便于从相应的op中获得op的输出。
方式一
执行下列脚本,可以获得模型各节点变量的txt文件。需要将checkpoint_path更改为自己的模型路径,本篇使用的是SiamMask算法的模型。
1 # -*- coding: utf-8 -*- 2 # @Time : 2020/7/14 下午3:11 3 # @Author : monologuesmw 4 # @Email : monologuesmw@163.com 5 # @File : node_name.py 6 # @Software: PyCharm 7 from tensorflow.python import pywrap_tensorflow 8 import os 9 def get_node_name(checkpoint_path): 10 reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path) 11 var_to_shape_map=reader.get_variable_to_shape_map() 12 node_name = [] 13 for key in var_to_shape_map: 14 node_name.append(key) 15 return node_name 16 17 18 def txt_save(filename, data): 19 file = open(filename, 'a') 20 for i in data: 21 s = str(i) + '\n' 22 file.write(s) 23 file.close() 24 25 26 if __name__ == '__main__': 27 checkpoint_path = 'ckpt/model.ckpt' 28 file_name = 'node_name.txt' 29 node_name = get_node_name(checkpoint_path) 30 txt_save(file_name, node_name)
对这种方式获得的节点名称感兴趣的可以猛戳这里。不过,从名称上观察,这种方式获得的名称是乱序的,不便于查找需要的输出节点。
这种方式获得的节点名称如下所示
1 Score/Head/conv2d_1/bias 2 BBox/Head/conv2d_1/bias
方式二
还可以使用更简洁的方式,获取按顺序排列的各种op:
1 import tensorflow as tf 2 def get_op(input_checkpoint): 3 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 4 graph = tf.get_default_graph() 5 txt_save_op(graph.get_operations(), "op_name") # get node name 6 7 8 def txt_save_op(data, output_file): 9 file = open(output_file, 'a') 10 for op in data: 11 s = str(op.name) + '\n' 12 file.write(s) 13 file.close() 14 15 if __name__ == '__main__': 16 checkpoint_path = 'ckpt/model.ckpt' 17 get_op(checkpoint_path)
这种方式会获得tensorflow的所有op。由于op较多,感兴趣的可以猛戳这里。
可以获得两个输出节点的名称:
1 BBox/Head/conv2d_1/BiasAdd 2 Score/Head/conv2d_1/BiasAdd
方式三
另外,通过网络结构变量的方式也可以获取到网络输出节点的名称,详情猛戳这里。
(二)将ckpt转化为pb文件
基于获得的网络输出节点的名称,便可以将ckpt转换为pb。
核心代码集中在freeze_graph()函数中。
如下述所示,值得注意的是代码第15行,多个输出节点需要用逗号隔开,并且逗号两侧不能有空格。
1 import tensorflow as tf 2 from tensorflow import graph_util 3 import os 4 5 def freeze_graph(input_checkpoint, output_graph): 6 ''' 7 :param input_checkpoint: 8 :param output_graph: PB模型保存路径 9 :return: 10 ''' 11 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用 12 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径 13 14 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点 15 output_node_names = "BBox/Head/conv2d_1/BiasAdd,Score/Head/conv2d_1/BiasAdd" # character , would't have space 16 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 17 18 # graph = tf.get_default_graph() 19 # txt_save_op(graph.get_operations(), "op_name_1") # get node name 20 21 with tf.Session() as sess: 22 saver.restore(sess, input_checkpoint) # 恢复图并得到数据 23 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定 24 sess=sess, 25 input_graph_def=sess.graph_def, # 等于:sess.graph_def 26 output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开 27 28 with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型 29 f.write(output_graph_def.SerializeToString()) # 序列化输出 30 31 32 def txt_save(data, output_file): 33 file = open(output_file, 'a') 34 for i in data: 35 s = str(i) + '\n' 36 file.write(s) 37 file.close() 38 39 40 def txt_save_op(data, output_file): 41 file = open(output_file, 'a') 42 for op in data: 43 # s1 =str(op.name) 44 # s2 = str(op.values) 45 # s = s1+s2+'\n' 46 s = str(op.name) + '\n' 47 file.write(s) 48 file.close() 49 50 # 获取网络结构变量 可省略 51 def network_param(input_checkpoint, output_file=None): 52 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 53 with tf.Session() as sess: 54 saver.restore(sess, input_checkpoint) 55 variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 56 txt_save(variables, output_file) 57 58 59 if __name__ == '__main__': 60 checkpoint_path = 'ckpt/model.ckpt' 61 output_path = 'pb/frozen_moedl.pb' 62 a = os.path.split(output_path)[0] 63 if not os.path.exists(a): 64 os.makedirs(a) 65 # 获取网络结构与变量 可省略 66 #output_file = 'network_param.txt' 67 #if not os.path.exists(output_file): 68 # network_param(checkpoint_path, output_file) 69 70 freeze_graph(checkpoint_path, output_path)
其中,注释掉的部分代码是打印网络结构与变量的,也可以获得网络输出节点名称,详情猛戳这里。
附录:ckpt文件形式