"""
@Author: Qiangz
@Date: 2019/7/5
@Description: Print the operations of a frozen TensorFlow graph (.pb file)
    and export the graph so it can be visualized in TensorBoard.
"""
import argparse

import tensorflow as tf
from tensorflow.python.framework import graph_util

tf.reset_default_graph()  # reset the default computation graph


def network_structure(args):
    """Load a frozen graph, print its operations and export it for TensorBoard.

    Args:
        args: Parsed command-line namespace with two attributes:
            model: Path of the frozen graph without the ``.pb`` extension.
                It is also used as the suffix of the log directory
                (``log_graph_<model>``).
            n: Optional number of leading operations to print. When it is
                ``None`` or ``0`` every operation is printed.
    """
    model_path = args.model + '.pb'

    # Deserialize the GraphDef from disk.
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())

    # Import into the default graph. No session and no variable initializer
    # are needed for this; running an initializer on the empty graph would
    # only add a spurious no-op node that then shows up in the listing below.
    graph = tf.get_default_graph()
    tf.import_graph_def(output_graph_def, name="")

    # Report how many nodes the graph definition has, and their names.
    print("%d ops in the final graph." % len(output_graph_def.node))
    node_names = [node.name for node in output_graph_def.node]
    print(node_names)
    print('---------------------------')

    # Write the graph to an event file so the model can be inspected with:
    #     tensorboard --logdir="log_graph_<model>/"
    # The graph is queued when the writer is constructed; closing the writer
    # flushes it to disk and releases the file handle.
    summary_writer = tf.summary.FileWriter('log_graph_' + args.model, graph)
    summary_writer.close()

    # Print the name and output tensors of each operation, stopping after
    # the first ``args.n`` operations when a limit was given.
    for index, op in enumerate(graph.get_operations(), start=1):
        print(op.name, op.values())
        if args.n and index == args.n:
            break


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Without a model name there is nothing to load, so make it mandatory
    # instead of failing later on ``None + '.pb'``.
    parser.add_argument('--model', type=str, required=True,
                        help="model name to look")
    # Useful when the graph has too many tensor names to print them all.
    parser.add_argument('--n', type=int,
                        help='the number of first several tensor name to look')
    args = parser.parse_args()
    network_structure(args)
# Usage:
# python model_structure.py --model facenet --n 10