tensorflow的斷點續訓


tensorflow的斷點續訓

2019-09-07

顧名思義,斷點續訓的意思是因為某些原因模型還沒有訓練完成就被中斷,下一次訓練可以在上一次訓練的基礎上繼續訓練而不用從頭開始;這種方式對於你那些訓練時間很長的模型來說非常友好。

如果要進行斷點續訓,那么得滿足兩個條件:

(1)本地保存了模型訓練中的快照;(即斷點數據保存)

(2)可以通過讀取快照恢復模型訓練的現場環境。(斷點數據恢復)

這兩個操作都用到了tensorflow中的train.Saver類。

 

1.tensorflow.trainn.Saver類

__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)
這里不對所有參數進行介紹,只介紹常用的參數
max_to_keep:允許保存的模型的個數,默認為5;當保存的個數超過5時,自動刪除最舊的模型,以保證最多同時存在5個模型;如果設置為0或者None,則會對所有訓練中的模型進行保存,但是這樣除了多占硬盤外沒什么意義。
其他的參數一般就使用默認值就可以了。
saver = tf.train.Saver(max_to_keep=10)

有機會再補充其他參數的用法。

2.斷點數據的保存

使用saver對象的save方法即可保存模型:

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True,
    strip_default_attrs=False,
    save_debug_info=False
)

常用參數:

sess:需要保存的會話,一般就是我們程序中的sess;

save_path:保存模型的文件路徑以及名稱,例如“ckpt/my_model”,注意如果要保存在ckpt文件夾下,那么需要在ckpt后面加個斜杠/;

global_step:訓練次數,saver會自動將這個值加入到保存的文件名字中。

saver.save(sess,"my_model",global_step=1)
saver.save(sess,"my_model",global_step=100)
saver.save(sess,"ckpt/my_model",global_step=1)

其中1,2,3行代碼分別會:

1:在代碼的路徑下生成名為“my_model_1文件”;

2:在代碼的路徑下生成名為“my_model_100文件”;

3:在ckpt文件夾下生成名為“my_model_1文件”。

 最常見的用法:

for epoch in range(n_iter):
    '''
    training process
    '''
    saver.save(sess,ckpt_dir+"model_name",global_step=epoch)

其中ckpt_dir是斷點數據存放的路徑。

 

3.斷點數據的恢復

3.1 只加載參數,不加載圖

需要先建立一個與之前相同的模型;然后再檢查有沒有斷點數據,如果有,則進行恢復。

'''
模型圖創建
'''
ckpt_dir = "ckpt/"
#創建Saver對象
saver = tf.train.Saver()
#如果有斷點文件,讀取最近的斷點文件
ckpt = tf.train.latest_checkpoint(ckpt_dir)

if ckpt != None:
    saver.restore(sess,ckpt)

不需要提供模型的名字,tf.train.latest_checkpoint(ckpt_dir)會去ckpt_dir文件夾中自動尋找最新的模型文件。

這個方法要求模型圖建立好之后才允許創建saver,然后進行變量恢復,否則會報錯。

當我們基於checkpoint文件(ckpt)加載參數時,實際上我們使用Saver.restore取代了initializer的初始化。

3.2 圖結構與參數都加載

不需要自己建立模型圖了,全部靠加載:

import tensorflow as tf
#獲取最新斷點數據路徑
ckpt = tf.train.latest_checkpoint("./ckpt/")
#加載圖結構
saver = tf.train.import_meta_graph(ckpt+".meta")

sess = tf.Session()
#加載參數
saver.restore(sess,ckpt)
#運行sess
sess.run(tf.get_default_graph().get_tensor_by_name("x:0"))

 

 可以通過 tf.get_default_graph().get_tensor_by_name("x:0")獲取模型節點,其中“x:0”是創建節點的時候節點的name。

 

4.模型文件解析

在程序訓練過程中保存的模型文件如下圖所示:

 

 checkpoint文件會記錄保存信息,通過它可以定位最新保存的模型;

.meta文件保存了當前圖結構

.data文件保存了當前參數名和值

.index文件保存了輔助索引信息

至於文件名后面的數字表示的是模型訓練的不同批次,我們一般只需要最新的那個;由於之前設置最多保存5個模型,所以批次號是從6開始的。

 

4.1 查看checkpoint

ckpt = tf.train.get_checkpoint_state("./ckpt/")
print(ckpt)

結果是文件的斷點狀態信息:

 斷點狀態信息下有一個“model_checkpoint_path”屬性,屬性內容是最新的那個模型的路徑,用str類型來表示;

ckpt.model_checkpoint_path

 

 這個與tf.train.latest_checkpoint("./ckpt/")得出的結果是相同的,可以通過這個路徑來加載模型參數。

 

4.2 通過data文件查看變量名和變量值

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file("./ckpt/model.ckpt-10",None,True)
print_tensors_in_checkpoint_file中輸入的第一個參數即上一節中獲取到的模型路徑;結果會以字典的形式展現出來。

 

 

4.3 通過meta文件加載圖結構

saver = tf.train.import_meta_graph('./ckpt/model.ckpt-10.meta')

注意這里的參數是完整的路徑加上meta文件的文件名,后面需要加上“.meta”。

返回的是一個saver對象,這個對象中包含了之前模型的圖結構。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM