兩種Tensorflow模型保存的方法


在Tensorflow中,有兩種保存模型的方法:一種是Checkpoint,另一種是Protobuf,也就是PB格式;

一. Checkpoint方法:

   1.保存時使用方法:

                  tf.train.Saver()

             生成四個文件:

       checkpoint                 檢查點文件

                   model.ckpt.data-xxx 參數值

                   model.ckpt.index   各個參數

                   model.ckpt.meta   圖的結構

   2.恢復時使用方法:

     saver.restore() :模型文件依賴Tensorflow,只能在其框架下使用,恢復模型之前需要定義下網絡結構

                  saver=tf.train.import_meta_graph('./ckpt/mode..ckpt.meta')  :直接加載網絡結構,不需要重新定義網絡

二. PB方法:

  1. 保存模型為PB文件(谷歌推薦),具有語言獨立性,可獨立運行,序列化的格式,任何語言可解析它,允許其他語言和框架讀取,訓練和遷移;模型變量是固定的,模型大小會大大減少,適合在手機端運行;

  2. 實現創建模型與使用模型的解耦,使得前向推導Inference代碼統一;

  3. PB文件表示MetaGraph的protocol buffer格式的文件;

  4. GraphDef 不保存任何Variable信息,不能從graph_def 來構建圖並恢復訓練. 

       一般情況下,PB可直接生成;

       當然也可以從checkpoint文件中生成,代碼如下:

 1 output_graph = os.path.join('./checkpoint/','frozen_graph.pb')
 2 input_checkpoint = os.path.join('./checkpoint/','model.ckpt-xxxxx')  #[xxxxxx為訓練生成的step號]
 3 saver = tf.train.import_meta_graph(input_checkpoint+'.meta',clear_devices=True)
 4 graph = tf.get_default_graph()
 5 input_graph_def = graph.as_graph_def
 6 
 7 for op in graph.get_operations():
 8     print("checkpoint2pb",op.name,op.values())
 9 
10 variable_names = [v.name for v in tf.trainable_variables()]
11 pirnt("trainalbe_variables:",variable_names)
12 
13 output_node_name=['fc2/add']  #fc2/add 上面的列表里需要存在該操作
14 
15 with tf.Session() as sess:
16     saver.restore(sess,input_checkpoint)
17 
18     output_graph_def = graph_util.convert_variables_to_constants(sess=sess,
19                     input_graph_def = input_graph_def,
20                     output_node_names = output_node_name)
21 
22     with tf.gfile.GFile(output_graph,"wb") as f:
23         f.write(output_graph_def.SerializeToString())
24 
25 
26     
27  
View Code

 

   


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM