一、模型的保存
使用tensorflow訓練模型的過程中,需要適時對模型進行保存,以及對保存的模型進行restore,以便后續對模型進行處理。如:測試、部署、拿別的模型進行fine-tune等。
保存模型是整個內容的第一步,操作十分簡單,只需要創建一個saver,並在一個Session里完成保存。
saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess, model_name)
以上代碼在0.11以下版本的tensorflow里會保存與下面類似的3個文件
checkpoint model.ckpt-1000.meta model.ckpt-1000.ckpt
在0.11及以上版本的tensorflow里則會保存與下類似的4個文件
checkpoint model.ckpt-1000.index model.ckpt-1000.data-00000-of-00001 model.ckpt-1000.meta
其中checkpoint列出保存的所有模型以及最近的模型;meta文件是模型定義的內容;ckpt(或data和index)文件是保存的模型數據。
除了上面最簡單的保存方式,也可以指定保存的步數,多長時間保存一次,磁盤上最多保存幾個模型(將前面的刪除以保持固定個數),需要做的是在創建saver時指定參數
saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)
其中,savable_variables指定待保存的變量,比如指定為tf.global_variables()保存所有global變量;指定為[v1, v2]保存v1和v2兩個變量,如果省略,則保存所有。
max_to_keep指定磁盤上最多保存有幾個模型。
keep_checkpoint_every_n_hours指定多少小時保存一次。
保存模型時指定參數
saver.save(sess, 'model_name', global_step=step, write_meta_graph=False)
其中,可以指定模型文件名,步數,write_meta_graph則用來指定是否保存meta文件記錄graph,等等。
二、模型的恢復及查看模型參數
with tf.Session() as sess: # 加載模型定義的graph saver = tf.train.import_meta_graph('model.ckpt-1000.meta') # 方式一:加載指定文件夾下最近保存的一個模型的數據 saver.restore(sess, tf.train.latest_checkpoint('./')) # 方式二:指定具體某個數據,需要注意的是,指定的文件不要包含后綴 # saver.restore(sess, os.path.join(path, 'model.ckpt-1000')) # 查看模型中的trainable variables tvs = [v for v in tf.trainable_variables()] for v in tvs: print(v.name) print(sess.run(v)) # 查看模型中的所有tensor或者operations gv = [v for v in tf.global_variables()] for v in gv: print(v.name) # 獲得幾乎所有的operations相關的tensor ops = [o for o in sess.graph.get_operations()] for o in ops: print(o.name)
說明:
1、global_variables()比trainable_variables()多了一些非trainable的變量,比如定義時指定為trainable=False的變量,或Optimizer相關的變量。
2、sess.graph.get_operations()可以換為tf.get_default_graph().get_operations(),二者區別無非是graph明確的時候可以直接使用前者,否則需要使用后者。
三、將ckpt轉化為pb
freeze_graph就是將模型固化,具體說就是將訓練數據和模型固化成pb文件。
參數: (必選: 表示必須有值;可選: 表示可以為空):
1、input_graph:(必選)模型文件,可以是二進制的pb文件,或文本的meta文件,用input_binary來指定區分(見下面說明)
2、input_saver:(可選)Saver解析器。保存模型和權限時,Saver也可以自身序列化保存,以便在加載時應用合適的版本。主要用於版本不兼容時使用。可以為空,為空時用當前版本的Saver。
3、input_binary:(可選)配合input_graph用,為true時,input_graph為二進制,為false時,input_graph為文件。默認False
4、input_checkpoint:(必選)檢查點數據文件。訓練時,給Saver用於保存權重、偏置等變量值。這時用於模型恢復變量值。
5、output_node_names:(必選)輸出節點的名字,有多個時用逗號分開。用於指定輸出節點,將沒有在輸出線上的其它節點剔除。
6、restore_op_name:(可選)從模型恢復節點的名字。升級版中已棄用。默認:save/restore_all
7、filename_tensor_name:(可選)已棄用。默認:save/Const:0
8、output_graph:(必選)用來保存整合后的模型輸出文件。
9、clear_devices:(可選),默認True。指定是否清除訓練時節點指定的運算設備(如cpu、gpu、tpu。cpu是默認)
10、initializer_nodes:(可選)默認空。權限加載后,可通過此參數來指定需要初始化的節點,用逗號分隔多個節點名字。
11、variable_names_blacklist:(可先)默認空。變量黑名單,用於指定不用恢復值的變量,用逗號分隔多個變量名字。
if __name__ == '__main__': args = parse_args() # model path demonet = args.demo_net dataset = args.dataset tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0]) if not os.path.isfile(tfmodel + '.meta'): print(tfmodel) raise IOError(('{:s} not found.\nDid you download the proper networks from ' 'our server and place them properly?').format(tfmodel + '.meta')) # set config tfconfig = tf.ConfigProto(allow_soft_placement=True) tfconfig.gpu_options.allow_growth = True # init session sess = tf.Session(config=tfconfig) # load network if demonet == 'vgg16': net = vgg16(batch_size=1) else: raise NotImplementedError net.create_architecture(sess, "TEST", 4, tag='default', anchor_scales=[8, 16, 32]) saver = tf.train.Saver() saver.restore(sess, tfmodel) # 保存圖 tf.train.write_graph(sess.graph_def, 'pb/pb_model', 'model.pb') # 把圖和參數結構一起 freeze_graph.freeze_graph('pb/pb_model/model.pb', '', False, tfmodel, 'vgg_16/cls_score/BiasAdd,vgg_16/cls_prob,vgg_16/bbox_pred/BiasAdd,vgg_16/rois/PyFunc', 'save/restore_all', 'save/Const:0', 'pb/pb_model/frozen_model.pb', False, "")