TensorFlow拾遺(一) 打印網絡結構與變量


  使用tensorflow搭建網絡之后,如果可視化一下網絡的結構與變量,會對網絡結構有一個更直觀的了解。

  另外,這種方式也可以獲得網絡輸出節點名稱,便於pb文件的生成。

  在許多源碼中都會包含這一操作,只不過大多可能並沒有打印出來

variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # type:list
 1 import tensorflow as tf
 2 import os
 3 def txt_save(data, output_file):
 4     file = open(output_file, 'a')
 5     for i in data:
 6         s = str(i) + '\n'
 7         file.write(s)
 8     file.close()
 9 
10 
11 def network_param(input_checkpoint, output_file=None):
12     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
13     with tf.Session() as sess:
14         saver.restore(sess, input_checkpoint)
15         variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
16         for i in variable:
17             print(i)     # 打印
18         txt_save(variables, output_file)  # 保存txt   二選一
19 
20 if __name__ == '__main__':
21     checkpoint_path = 'ckpt/model.ckpt'
22     output_file = 'network_param.txt'
23     if not os.path.exists(output_file):
24         network_param(checkpoint_path, output_file)

  獲得的txt文件部分如下所示,詳細的txt文件猛戳這里

<tf.Variable 'Layer1/conv2d/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d/kernel:0' shape=(1, 1, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/conv2d/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d/kernel:0' shape=(1, 1, 256, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
網絡結構輸出部分展示

  其實,在打印網絡結構與變量獲得的結果中,也可以獲得輸出節點的名稱,如下所示:

1 Score/Head/conv2d_1/BiasAdd 
2 BBox/Head/conv2d_1/BiasAdd

 

 

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM