TensorFlow拾遗(一) 打印网络结构与变量


  使用tensorflow搭建网络之后,如果可视化一下网络的结构与变量,会对网络结构有一个更直观的了解。

  另外,这种方式也可以获得网络输出节点名称,便于pb文件的生成。

  在许多源码中都会包含这一操作,只不过大多可能并没有打印出来

variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # type:list
 1 import tensorflow as tf
 2 import os
 3 def txt_save(data, output_file):
 4     file = open(output_file, 'a')
 5     for i in data:
 6         s = str(i) + '\n'
 7         file.write(s)
 8     file.close()
 9 
10 
11 def network_param(input_checkpoint, output_file=None):
12     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
13     with tf.Session() as sess:
14         saver.restore(sess, input_checkpoint)
15         variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
16         for i in variable:
17             print(i)     # 打印
18         txt_save(variables, output_file)  # 保存txt   二选一
19 
20 if __name__ == '__main__':
21     checkpoint_path = 'ckpt/model.ckpt'
22     output_file = 'network_param.txt'
23     if not os.path.exists(output_file):
24         network_param(checkpoint_path, output_file)

  获得的txt文件部分如下所示,详细的txt文件猛戳这里

<tf.Variable 'Layer1/conv2d/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d/kernel:0' shape=(1, 1, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/conv2d/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d/kernel:0' shape=(1, 1, 256, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
网络结构输出部分展示

  其实,在打印网络结构与变量获得的结果中,也可以获得输出节点的名称,如下所示:

1 Score/Head/conv2d_1/BiasAdd 
2 BBox/Head/conv2d_1/BiasAdd

 

 

  


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM