將模型移植到諸如Android,FPGA等移動端時,需要模型的pb文件。深度學習框架會將模型權重保存為自身的格式,如.ckpt(tensorflow) .h5(tf.keras/Keras) .pt(pytorch)。此時,便需要對權重文件的格式進行轉換。
(一)獲取ckpt各節點名稱
在進行ckpt權重文件轉化為pb文件時,需要獲取到模型的輸出節點名稱,便於從相應的op中獲得op的輸出。
方式一
執行下列腳本,可以獲得模型各節點變量的txt文件。需要將checkpoint_path更改為自己的模型路徑,本篇使用的是SiamMask算法的模型。
1 # -*- coding: utf-8 -*- 2 # @Time : 2020/7/14 下午3:11 3 # @Author : monologuesmw 4 # @Email : monologuesmw@163.com 5 # @File : node_name.py 6 # @Software: PyCharm 7 from tensorflow.python import pywrap_tensorflow 8 import os 9 def get_node_name(checkpoint_path): 10 reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path) 11 var_to_shape_map=reader.get_variable_to_shape_map() 12 node_name = [] 13 for key in var_to_shape_map: 14 node_name.append(key) 15 return node_name 16 17 18 def txt_save(filename, data): 19 file = open(filename, 'a') 20 for i in data: 21 s = str(i) + '\n' 22 file.write(s) 23 file.close() 24 25 26 if __name__ == '__main__': 27 checkpoint_path = 'ckpt/model.ckpt' 28 file_name = 'node_name.txt' 29 node_name = get_node_name(checkpoint_path) 30 txt_save(file_name, node_name)
對這種方式獲得的節點名稱感興趣的可以猛戳這里。不過,從名稱上觀察,這種方式獲得的名稱是亂序的,不便於查找需要的輸出節點。
這種方式獲得的節點名稱如下所示
1 Score/Head/conv2d_1/bias 2 BBox/Head/conv2d_1/bias
方式二
還可以使用更簡潔的方式,獲取按順序排列的各種op:
1 import tensorflow as tf 2 def get_op(input_checkpoint): 3 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 4 graph = tf.get_default_graph() 5 txt_save_op(graph.get_operations(), "op_name") # get node name 6 7 8 def txt_save_op(data, output_file): 9 file = open(output_file, 'a') 10 for op in data: 11 s = str(op.name) + '\n' 12 file.write(s) 13 file.close() 14 15 if __name__ == '__main__': 16 checkpoint_path = 'ckpt/model.ckpt' 17 get_op(checkpoint_path)
這種方式會獲得tensorflow的所有op。由於op較多,感興趣的可以猛戳這里。
可以獲得兩個輸出節點的名稱:
1 BBox/Head/conv2d_1/BiasAdd 2 Score/Head/conv2d_1/BiasAdd
方式三
另外,通過網絡結構變量的方式也可以獲取到網絡輸出節點的名稱,詳情猛戳這里。
(二)將ckpt轉化為pb文件
基於獲得的網絡輸出節點的名稱,便可以將ckpt轉換為pb。
核心代碼集中在freeze_graph()函數中。
如下述所示,值得注意的是代碼第15行,多個輸出節點需要用逗號隔開,並且逗號兩側不能有空格。
1 import tensorflow as tf 2 from tensorflow import graph_util 3 import os 4 5 def freeze_graph(input_checkpoint, output_graph): 6 ''' 7 :param input_checkpoint: 8 :param output_graph: PB模型保存路徑 9 :return: 10 ''' 11 # checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt文件狀態是否可用 12 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路徑 13 14 # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點 15 output_node_names = "BBox/Head/conv2d_1/BiasAdd,Score/Head/conv2d_1/BiasAdd" # character , would't have space 16 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 17 18 # graph = tf.get_default_graph() 19 # txt_save_op(graph.get_operations(), "op_name_1") # get node name 20 21 with tf.Session() as sess: 22 saver.restore(sess, input_checkpoint) # 恢復圖並得到數據 23 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定 24 sess=sess, 25 input_graph_def=sess.graph_def, # 等於:sess.graph_def 26 output_node_names=output_node_names.split(",")) # 如果有多個輸出節點,以逗號隔開 27 28 with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型 29 f.write(output_graph_def.SerializeToString()) # 序列化輸出 30 31 32 def txt_save(data, output_file): 33 file = open(output_file, 'a') 34 for i in data: 35 s = str(i) + '\n' 36 file.write(s) 37 file.close() 38 39 40 def txt_save_op(data, output_file): 41 file = open(output_file, 'a') 42 for op in data: 43 # s1 =str(op.name) 44 # s2 = str(op.values) 45 # s = s1+s2+'\n' 46 s = str(op.name) + '\n' 47 file.write(s) 48 file.close() 49 50 # 獲取網絡結構變量 可省略 51 def network_param(input_checkpoint, output_file=None): 52 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 53 with tf.Session() as sess: 54 saver.restore(sess, input_checkpoint) 55 variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 56 txt_save(variables, output_file) 57 58 59 if __name__ == '__main__': 60 checkpoint_path = 'ckpt/model.ckpt' 61 output_path = 'pb/frozen_moedl.pb' 62 a = os.path.split(output_path)[0] 63 if not os.path.exists(a): 64 os.makedirs(a) 65 # 獲取網絡結構與變量 可省略 66 #output_file = 'network_param.txt' 67 #if not os.path.exists(output_file): 68 # network_param(checkpoint_path, output_file) 69 70 freeze_graph(checkpoint_path, output_path)
其中,注釋掉的部分代碼是打印網絡結構與變量的,也可以獲得網絡輸出節點名稱,詳情猛戳這里。
附錄:ckpt文件形式