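After building a network with TensorFlow, visualizing its structure and variables gives a more intuitive understanding of how the network is put together.
This approach also reveals the names of the network's output nodes, which you need when generating a pb file.
Many codebases contain this operation, but most of them never print the result: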
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # type:list
import os
import tensorflow as tf  # TensorFlow 1.x API (tf.train, tf.Session, tf.get_collection)


def txt_save(data, output_file):
    # Append each item to output_file, one item per line
    with open(output_file, 'a') as f:
        for item in data:
            f.write(str(item) + '\n')


def network_param(input_checkpoint, output_file=None):
    # Rebuild the graph from the .meta file, ignoring saved device placement
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        # Print the variables, or save them to a txt file (either one is enough)
        for v in variables:
            print(v)
        if output_file is not None:
            txt_save(variables, output_file)


if __name__ == '__main__':
    checkpoint_path = 'ckpt/model.ckpt'
    output_file = 'network_param.txt'
    if not os.path.exists(output_file):
        network_param(checkpoint_path, output_file)
Part of the resulting txt file is shown below:

<tf.Variable 'Layer1/conv2d/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d/kernel:0' shape=(1, 1, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/batch_normalization_2/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/conv2d/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/gamma:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_0/Downsample/batch_normalization/beta:0' shape=(256,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d/kernel:0' shape=(1, 1, 256, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/gamma:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/batch_normalization_1/beta:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'Layer2/Block_1/conv2d_2/kernel:0' shape=(1, 1, 64, 256) dtype=float32_ref>
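Every printed line follows the same `<tf.Variable 'NAME' shape=(...) dtype=...>` pattern, so the txt file can be post-processed to summarize the network. The following is a minimal sketch. It assumes the file holds one variable per line, as written by `txt_save` above. The helpers `parse_variable_line` and `count_params` are hypothetical names, not TensorFlow APIs. The sketch multiplies out each shape to get a variable's element count and totals those counts per top-level scope (e.g. `Layer1`, `Layer2`).

```python
import re
from collections import OrderedDict

# Matches printed lines such as (hypothetical helper regex):
# <tf.Variable 'Layer1/conv2d/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>
VAR_RE = re.compile(r"<tf\.Variable '([^']+)' shape=\(([^)]*)\) dtype=(\w+)")


def parse_variable_line(line):
    """Return (name, shape_tuple) for one printed variable, or None if no match."""
    m = VAR_RE.search(line)
    if m is None:
        return None
    name, shape_str = m.group(1), m.group(2)
    # "64," -> (64,); "7, 7, 3, 64" -> (7, 7, 3, 64); "" (scalar) -> ()
    shape = tuple(int(d) for d in shape_str.split(',') if d.strip())
    return name, shape


def count_params(lines):
    """Sum element counts per top-level scope, keeping first-seen order."""
    per_scope = OrderedDict()
    for line in lines:
        parsed = parse_variable_line(line)
        if parsed is None:
            continue  # skip blank or unrelated lines
        name, shape = parsed
        n = 1
        for d in shape:
            n *= d
        scope = name.split('/')[0]
        per_scope[scope] = per_scope.get(scope, 0) + n
    return per_scope


if __name__ == '__main__':
    sample = [
        "<tf.Variable 'Layer1/conv2d/kernel:0' shape=(7, 7, 3, 64) dtype=float32_ref>",
        "<tf.Variable 'Layer1/batch_normalization/gamma:0' shape=(64,) dtype=float32_ref>",
        "<tf.Variable 'Layer1/batch_normalization/beta:0' shape=(64,) dtype=float32_ref>",
        "<tf.Variable 'Layer2/Block_0/conv2d/kernel:0' shape=(1, 1, 64, 64) dtype=float32_ref>",
    ]
    for scope, n in count_params(sample).items():
        print(scope, n)  # prints "Layer1 9536" then "Layer2 4096"
```

To run it on the real file, pass the open file object directly. For example, `with open('network_param.txt') as f: print(count_params(f))`. Grouping by the first path component is only one choice. Splitting on the first two components would give per-block totals such as `Layer2/Block_0`.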
The names of the output nodes can also be read from the printed network structure and variables, as shown below:
Score/Head/conv2d_1/BiasAdd
BBox/Head/conv2d_1/BiasAdd
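When the head names are not known in advance, a common heuristic is to look for nodes whose outputs no other node consumes. The following is a minimal sketch. It assumes you already have a `(name, inputs)` pair for every node, for example from `[(n.name, list(n.input)) for n in graph_def.node]` on a TF 1.x `GraphDef`. The helpers `op_name` and `find_output_nodes` and the toy graph are hypothetical. In a `GraphDef`, input strings may carry a `^` prefix (control dependency) or a `:N` output-index suffix, so the sketch strips both before matching. The same stripping turns a tensor name such as `Score/Head/conv2d_1/BiasAdd:0` into the bare op name that pb-freezing utilities such as `tf.graph_util.convert_variables_to_constants` expect in `output_node_names`.

```python
def op_name(tensor_or_input):
    """Strip a leading '^' (control input) and a trailing ':N' (output index)."""
    name = tensor_or_input.lstrip('^')
    head, sep, tail = name.rpartition(':')
    if sep and tail.isdigit():
        return head
    return name


def find_output_nodes(nodes, exclude_prefixes=()):
    """Return names of nodes that no other node uses as an input (graph sinks).

    nodes: iterable of (name, inputs) pairs, in graph order.
    exclude_prefixes: name prefixes to skip, e.g. saver or initializer ops,
    which are also sinks in a typical training graph.
    """
    nodes = list(nodes)
    consumed = set()
    for _, inputs in nodes:
        for inp in inputs:
            consumed.add(op_name(inp))
    prefixes = tuple(exclude_prefixes)
    return [
        name for name, _ in nodes
        if name not in consumed and not name.startswith(prefixes)
    ]


if __name__ == '__main__':
    # Hypothetical toy graph: a shared backbone feeding two heads, plus a saver op
    toy = [
        ('input', []),
        ('Backbone/conv2d/Conv2D', ['input']),
        ('Score/Head/conv2d_1/BiasAdd', ['Backbone/conv2d/Conv2D']),
        ('BBox/Head/conv2d_1/BiasAdd', ['Backbone/conv2d/Conv2D:0']),
        ('save/restore_all', ['^Backbone/conv2d/Conv2D']),
    ]
    print(find_output_nodes(toy, exclude_prefixes=('save/',)))
    # prints ['Score/Head/conv2d_1/BiasAdd', 'BBox/Head/conv2d_1/BiasAdd']
```

The result is only a list of candidates. A real training graph also has sinks for savers, initializers, and summaries, so check the list against the printed structure before passing it on as output node names.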