tensorflow模型的保存與恢復,以及ckpt到pb的轉化


一、模型的保存

使用tensorflow訓練模型的過程中,需要適時對模型進行保存,以及對保存的模型進行restore,以便后續對模型進行處理。如:測試、部署、拿別的模型進行fine-tune等。

保存模型是整個內容的第一步,操作十分簡單,只需要創建一個saver,並在一個Session里完成保存。

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.save(sess, model_name)

以上代碼在0.11以下版本的tensorflow里會保存與下面類似的3個文件

checkpoint
model.ckpt-1000.meta
model.ckpt-1000.ckpt

在0.11及以上版本的tensorflow里則會保存與下類似的4個文件

checkpoint
model.ckpt-1000.index
model.ckpt-1000.data-00000-of-00001
model.ckpt-1000.meta

其中checkpoint列出保存的所有模型以及最近的模型;meta文件是模型定義的內容;ckpt(或data和index)文件是保存的模型數據。

除了上面最簡單的保存方式,也可以指定保存的步數,多長時間保存一次,磁盤上最多保存幾個模型(將前面的刪除以保持固定個數),需要做的是在創建saver時指定參數

saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)

其中,savable_variables指定待保存的變量,比如指定為tf.global_variables()保存所有global變量;指定為[v1, v2]保存v1和v2兩個變量,如果省略,則保存所有。

max_to_keep指定磁盤上最多保存有幾個模型。

keep_checkpoint_every_n_hours指定多少小時保存一次。

保存模型時指定參數

saver.save(sess, 'model_name', global_step=step, write_meta_graph=False)

其中,可以指定模型文件名,步數,write_meta_graph則用來指定是否保存meta文件記錄graph,等等。

 

二、模型的恢復及查看模型參數

with tf.Session() as sess:
    # 加載模型定義的graph
    saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
    # 方式一:加載指定文件夾下最近保存的一個模型的數據
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # 方式二:指定具體某個數據,需要注意的是,指定的文件不要包含后綴
    # saver.restore(sess, os.path.join(path, 'model.ckpt-1000'))

    # 查看模型中的trainable variables
    tvs = [v for v in tf.trainable_variables()]
    for v in tvs:
        print(v.name)
        print(sess.run(v))

    # 查看模型中的所有tensor或者operations
    gv = [v for v in tf.global_variables()]
    for v in gv:
        print(v.name)

    # 獲得幾乎所有的operations相關的tensor
    ops = [o for o in sess.graph.get_operations()]
    for o in ops:
        print(o.name)

說明:

1、global_variables()比trainable_variables()多了一些非trainable的變量,比如定義時指定為trainable=False的變量,或Optimizer相關的變量。

2、sess.graph.get_operations()可以換為tf.get_default_graph().get_operations(),二者區別無非是graph明確的時候可以直接使用前者,否則需要使用后者。

 

三、將ckpt轉化為pb

freeze_graph就是將模型固化,具體說就是將訓練數據和模型固化成pb文件。

參數: (必選: 表示必須有值;可選: 表示可以為空):
1、input_graph:(必選)模型文件,可以是二進制的pb文件,或文本的meta文件,用input_binary來指定區分(見下面說明)
2、input_saver:(可選)Saver解析器。保存模型和權限時,Saver也可以自身序列化保存,以便在加載時應用合適的版本。主要用於版本不兼容時使用。可以為空,為空時用當前版本的Saver。
3、input_binary:(可選)配合input_graph用,為true時,input_graph為二進制,為false時,input_graph為文件。默認False
4、input_checkpoint:(必選)檢查點數據文件。訓練時,給Saver用於保存權重、偏置等變量值。這時用於模型恢復變量值。
5、output_node_names:(必選)輸出節點的名字,有多個時用逗號分開。用於指定輸出節點,將沒有在輸出線上的其它節點剔除。
6、restore_op_name:(可選)從模型恢復節點的名字。升級版中已棄用。默認:save/restore_all
7、filename_tensor_name:(可選)已棄用。默認:save/Const:0
8、output_graph:(必選)用來保存整合后的模型輸出文件。
9、clear_devices:(可選),默認True。指定是否清除訓練時節點指定的運算設備(如cpu、gpu、tpu。cpu是默認)
10、initializer_nodes:(可選)默認空。權限加載后,可通過此參數來指定需要初始化的節點,用逗號分隔多個節點名字。
11、variable_names_blacklist:(可先)默認空。變量黑名單,用於指定不用恢復值的變量,用逗號分隔多個變量名字。 

if __name__ == '__main__':
    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        print(tfmodel)
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16(batch_size=1)
    else:
        raise NotImplementedError

    net.create_architecture(sess, "TEST", 4,
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    # 保存圖
    tf.train.write_graph(sess.graph_def, 'pb/pb_model', 'model.pb')
    # 把圖和參數結構一起
    freeze_graph.freeze_graph('pb/pb_model/model.pb',
                              '',
                              False,
                              tfmodel,
                              'vgg_16/cls_score/BiasAdd,vgg_16/cls_prob,vgg_16/bbox_pred/BiasAdd,vgg_16/rois/PyFunc',
                              'save/restore_all',
                              'save/Const:0',
                              'pb/pb_model/frozen_model.pb',
                              False,
                              "")

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM