import tensorflow as tf from tensorflow.python.tools import freeze_graph #os.environ['CUDA_VISIBLE_DEVICES']='2' #設置GPU model_path = "D:\\JupyterWorkSpace\\Tensorflow\\Fine-tuning\\tensorflow-resnet-pretrained-20160509\\ResNet-L152.ckpt" #設置model的路徑 def main(): tf.reset_default_graph() saver = tf.train.import_meta_graph("D:\\JupyterWorkSpace\\Tensorflow\\Fine-tuning\\tensorflow-resnet-pretrained-20160509\\ResNet-L152.meta") #flow = tf.cast(flow, tf.uint8, 'out') #設置輸出類型以及輸出的接口名字,為了之后的調用pb的時候使用 with tf.Session() as sess: saver.restore(sess, model_path) #保存圖 tf.train.write_graph(sess.graph_def, './ResNet_L152_retrain/pb_model', 'model_ResNet_L152.pb') #把圖和參數結構一起 freeze_graph.freeze_graph('ResNet_L152_retrain/pb_model/model_ResNet_L152.pb', '', False, model_path, 'fc/xw_plus_b', 'save/restore_all', 'save/Const:0', 'ResNet_L152_retrain/pb_model/frozen_model_ResNet_L152.pb', False, "") print("done") if __name__ == '__main__': main()
總共有11個參數,一個個介紹下(必選: 表示必須有值;可選: 表示可以為空):
1、input_graph:(必選)模型文件,可以是二進制的pb文件,或文本的meta文件,用input_binary來指定區分(見下面說明)
2、input_saver:(可選)Saver解析器。保存模型和權限時,Saver也可以自身序列化保存,以便在加載時應用合適的版本。主要用於版本不兼容時使用。可以為空,為空時用當前版本的Saver。
3、input_binary:(可選)配合input_graph用,為true時,input_graph為二進制,為false時,input_graph為文件。默認False
4、input_checkpoint:(必選)檢查點數據文件。訓練時,給Saver用於保存權重、偏置等變量值。這時用於模型恢復變量值。
5、output_node_names:(必選)輸出節點的名字,有多個時用逗號分開。用於指定輸出節點,將沒有在輸出線上的其它節點剔除。
6、restore_op_name:(可選)從模型恢復節點的名字。升級版中已棄用。默認:save/restore_all
7、filename_tensor_name:(可選)已棄用。默認:save/Const:0
8、output_graph:(必選)用來保存整合后的模型輸出文件。
9、clear_devices:(可選),默認True。指定是否清除訓練時節點指定的運算設備(如cpu、gpu、tpu。cpu是默認)
10、initializer_nodes:(可選)默認空。權限加載后,可通過此參數來指定需要初始化的節點,用逗號分隔多個節點名字。
11、variable_names_blacklist:(可先)默認空。變量黑名單,用於指定不用恢復值的變量,用逗號分隔多個變量名字。