在Tensorflow中,有兩種保存模型的方法:一種是Checkpoint,另一種是Protobuf,也就是PB格式;
一. Checkpoint方法:
1.保存時使用方法:
tf.train.Saver()
生成四個文件:
checkpoint 檢查點文件
model.ckpt.data-xxx 參數值
model.ckpt.index 各個參數
model.ckpt.meta 圖的結構
2.恢復時使用方法:
saver.restore() :模型文件依賴Tensorflow,只能在其框架下使用,恢復模型之前需要定義下網絡結構
saver=tf.train.import_meta_graph('./ckpt/mode..ckpt.meta') :直接加載網絡結構,不需要重新定義網絡
二. PB方法:
1. 保存模型為PB文件(谷歌推薦),具有語言獨立性,可獨立運行,序列化的格式,任何語言可解析它,允許其他語言和框架讀取,訓練和遷移;模型變量是固定的,模型大小會大大減少,適合在手機端運行;
2. 實現創建模型與使用模型的解耦,使得前向推導Inference代碼統一;
3. PB文件表示MetaGraph的protocol buffer格式的文件;
4. GraphDef 不保存任何Variable信息,不能從graph_def 來構建圖並恢復訓練.
一般情況下,PB可直接生成;
當然也可以從checkpoint文件中生成,代碼如下:

1 output_graph = os.path.join('./checkpoint/','frozen_graph.pb') 2 input_checkpoint = os.path.join('./checkpoint/','model.ckpt-xxxxx') #[xxxxxx為訓練生成的step號] 3 saver = tf.train.import_meta_graph(input_checkpoint+'.meta',clear_devices=True) 4 graph = tf.get_default_graph() 5 input_graph_def = graph.as_graph_def 6 7 for op in graph.get_operations(): 8 print("checkpoint2pb",op.name,op.values()) 9 10 variable_names = [v.name for v in tf.trainable_variables()] 11 pirnt("trainalbe_variables:",variable_names) 12 13 output_node_name=['fc2/add'] #fc2/add 上面的列表里需要存在該操作 14 15 with tf.Session() as sess: 16 saver.restore(sess,input_checkpoint) 17 18 output_graph_def = graph_util.convert_variables_to_constants(sess=sess, 19 input_graph_def = input_graph_def, 20 output_node_names = output_node_name) 21 22 with tf.gfile.GFile(output_graph,"wb") as f: 23 f.write(output_graph_def.SerializeToString()) 24 25 26 27