TensorFlow拾遺(二) pb文件的生成


 


(一)獲取ckpt各節點名稱

方式一

方式二

(二)將ckpt轉化為pb文件

附錄:ckpt文件形式    


   將模型移植到諸如Android,FPGA等移動端時,需要模型的pb文件。深度學習框架會將模型權重保存為自身的格式,如.ckpt(tensorflow) .h5(tf.keras/Keras) .pt(pytorch)。此時,便需要對權重文件的格式進行轉換。

(一)獲取ckpt各節點名稱

  在進行ckpt權重文件轉化為pb文件時,需要獲取到模型的輸出節點名稱,便於從相應的op中獲得op的輸出。

方式一

  執行下列腳本,可以獲得模型各節點變量的txt文件。需要將checkpoint_path更改為自己的模型路徑,本篇使用的是SiamMask算法的模型。

 1 # -*- coding: utf-8 -*-
 2 # @Time    : 2020/7/14 下午3:11
 3 # @Author  : monologuesmw   
 4 # @Email   : monologuesmw@163.com
 5 # @File    : node_name.py
 6 # @Software: PyCharm
 7 from tensorflow.python import pywrap_tensorflow
 8 import os
 9 def get_node_name(checkpoint_path):
10     reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
11     var_to_shape_map=reader.get_variable_to_shape_map()
12     node_name = []
13     for key in var_to_shape_map:
14         node_name.append(key)
15     return node_name
16 
17 
18 def txt_save(filename, data):
19     file = open(filename, 'a')
20     for i in data:
21         s = str(i) + '\n'
22         file.write(s)
23     file.close()
24 
25 
26 if __name__ == '__main__':
27     checkpoint_path = 'ckpt/model.ckpt'
28     file_name = 'node_name.txt'
29     node_name = get_node_name(checkpoint_path)
30     txt_save(file_name, node_name)

  對這種方式獲得的節點名稱感興趣的可以猛戳這里。不過,從名稱上觀察,這種方式獲得的名稱是亂序的,不便於查找需要的輸出節點。

  這種方式獲得的節點名稱如下所示

1 Score/Head/conv2d_1/bias
2 BBox/Head/conv2d_1/bias

方式二

  還可以使用更簡潔的方式,獲取按順序排列的各種op

 1 import tensorflow as tf
 2 def get_op(input_checkpoint):
 3     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 4     graph = tf.get_default_graph()
 5     txt_save_op(graph.get_operations(), "op_name")   # get node name
 6 
 7         
 8 def txt_save_op(data, output_file):
 9     file = open(output_file, 'a')
10     for op in data:
11         s = str(op.name) + '\n'
12         file.write(s)
13     file.close()
14 
15 if __name__ == '__main__':
16     checkpoint_path = 'ckpt/model.ckpt'
17     get_op(checkpoint_path)

  這種方式會獲得tensorflow的所有op。由於op較多,感興趣的可以猛戳這里

  可以獲得兩個輸出節點的名稱:

1 BBox/Head/conv2d_1/BiasAdd
2 Score/Head/conv2d_1/BiasAdd

方式三  

        另外,通過網絡結構變量的方式也可以獲取到網絡輸出節點的名稱,詳情猛戳這里

(二)將ckpt轉化為pb文件

  基於獲得的網絡輸出節點的名稱,便可以將ckpt轉換為pb。

  核心代碼集中在freeze_graph()函數中。

  如下述所示,值得注意的是代碼第15行,多個輸出節點需要用逗號隔開,並且逗號兩側不能有空格。

 1 import tensorflow as tf
 2 from tensorflow import graph_util
 3 import os
 4 
 5 def freeze_graph(input_checkpoint, output_graph):
 6     '''
 7     :param input_checkpoint:
 8     :param output_graph: PB模型保存路徑
 9     :return:
10     '''
11     # checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt文件狀態是否可用
12     # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路徑
13 
14     # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
15     output_node_names = "BBox/Head/conv2d_1/BiasAdd,Score/Head/conv2d_1/BiasAdd"   # character , would't have space
16     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
17 
18     # graph = tf.get_default_graph()
19     # txt_save_op(graph.get_operations(), "op_name_1")   # get node name
20 
21     with tf.Session() as sess:
22         saver.restore(sess, input_checkpoint)  # 恢復圖並得到數據
23         output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
24             sess=sess,
25             input_graph_def=sess.graph_def,  # 等於:sess.graph_def
26             output_node_names=output_node_names.split(","))  # 如果有多個輸出節點,以逗號隔開
27 
28         with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
29             f.write(output_graph_def.SerializeToString())  # 序列化輸出
30 
31 
32 def txt_save(data, output_file):
33     file = open(output_file, 'a')
34     for i in data:
35         s = str(i) + '\n'
36         file.write(s)
37     file.close()
38 
39 
40 def txt_save_op(data, output_file):
41     file = open(output_file, 'a')
42     for op in data:
43         # s1 =str(op.name)
44         # s2 = str(op.values)
45         # s = s1+s2+'\n'
46         s = str(op.name) + '\n'
47         file.write(s)
48     file.close()
49 
50 # 獲取網絡結構變量  可省略
51 def network_param(input_checkpoint, output_file=None):
52     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
53     with tf.Session() as sess:
54         saver.restore(sess, input_checkpoint)
55         variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
56         txt_save(variables, output_file)
57 
58 
59 if __name__ == '__main__':
60     checkpoint_path = 'ckpt/model.ckpt'
61     output_path = 'pb/frozen_moedl.pb'
62     a = os.path.split(output_path)[0]
63     if not os.path.exists(a):
64         os.makedirs(a)
65     # 獲取網絡結構與變量  可省略
66     #output_file = 'network_param.txt'
67     #if not os.path.exists(output_file):
68     #    network_param(checkpoint_path, output_file)
69         
70     freeze_graph(checkpoint_path, output_path)

  其中,注釋掉的部分代碼是打印網絡結構與變量的,也可以獲得網絡輸出節點名稱,詳情猛戳這里

附錄:ckpt文件形式

 

 

參考鏈接:https://blog.csdn.net/guyuealian/article/details/82218092#%C2%A0%20%C2%A0%20%C2%A01%E3%80%81%E8%AE%AD%E7%BB%83%E6%96%B9%E6%B3%95

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM