把ResNet-L152模型的ckpt文件轉化為pb文件


import tensorflow as tf
from tensorflow.python.tools import freeze_graph

#os.environ['CUDA_VISIBLE_DEVICES']='2'  #設置GPU
model_path  = "D:\\JupyterWorkSpace\\Tensorflow\\Fine-tuning\\tensorflow-resnet-pretrained-20160509\\ResNet-L152.ckpt" #設置model的路徑

def main():
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph("D:\\JupyterWorkSpace\\Tensorflow\\Fine-tuning\\tensorflow-resnet-pretrained-20160509\\ResNet-L152.meta")
    #flow = tf.cast(flow, tf.uint8, 'out') #設置輸出類型以及輸出的接口名字,為了之后的調用pb的時候使用
    with tf.Session() as sess:
        saver.restore(sess, model_path)
        #保存圖
        tf.train.write_graph(sess.graph_def, './ResNet_L152_retrain/pb_model', 'model_ResNet_L152.pb')
        #把圖和參數結構一起
        freeze_graph.freeze_graph('ResNet_L152_retrain/pb_model/model_ResNet_L152.pb',
                                  '',
                                  False,
                                  model_path,
                                  'fc/xw_plus_b',
                                  'save/restore_all',
                                  'save/Const:0',
                                  'ResNet_L152_retrain/pb_model/frozen_model_ResNet_L152.pb',
                                  False,
                                  "")
    print("done")
    
if __name__ == '__main__':
    main()

  總共有11個參數,一個個介紹下(必選: 表示必須有值;可選: 表示可以為空): 
1、input_graph:(必選)模型文件,可以是二進制的pb文件,或文本的meta文件,用input_binary來指定區分(見下面說明) 
2、input_saver:(可選)Saver解析器。保存模型和權限時,Saver也可以自身序列化保存,以便在加載時應用合適的版本。主要用於版本不兼容時使用。可以為空,為空時用當前版本的Saver。 
3、input_binary:(可選)配合input_graph用,為true時,input_graph為二進制,為false時,input_graph為文件。默認False 
4、input_checkpoint:(必選)檢查點數據文件。訓練時,給Saver用於保存權重、偏置等變量值。這時用於模型恢復變量值。 
5、output_node_names:(必選)輸出節點的名字,有多個時用逗號分開。用於指定輸出節點,將沒有在輸出線上的其它節點剔除。 
6、restore_op_name:(可選)從模型恢復節點的名字。升級版中已棄用。默認:save/restore_all 
7、filename_tensor_name:(可選)已棄用。默認:save/Const:0 
8、output_graph:(必選)用來保存整合后的模型輸出文件。 
9、clear_devices:(可選),默認True。指定是否清除訓練時節點指定的運算設備(如cpu、gpu、tpu。cpu是默認) 
10、initializer_nodes:(可選)默認空。權限加載后,可通過此參數來指定需要初始化的節點,用逗號分隔多個節點名字。 
11、variable_names_blacklist:(可先)默認空。變量黑名單,用於指定不用恢復值的變量,用逗號分隔多個變量名字。 

參考:http://blog.csdn.net/yjl9122/article/details/78341689


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM