tensorflow的斷點續訓
2019-09-07
顧名思義,斷點續訓的意思是因為某些原因模型還沒有訓練完成就被中斷,下一次訓練可以在上一次訓練的基礎上繼續訓練而不用從頭開始;這種方式對於你那些訓練時間很長的模型來說非常友好。
如果要進行斷點續訓,那么得滿足兩個條件:
(1)本地保存了模型訓練中的快照;(即斷點數據保存)
(2)可以通過讀取快照恢復模型訓練的現場環境。(斷點數據恢復)
這兩個操作都用到了tensorflow中的train.Saver類。
1.tensorflow.trainn.Saver類
__init__( var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=tf.train.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None )
這里不對所有參數進行介紹,只介紹常用的參數
max_to_keep:允許保存的模型的個數,默認為5;當保存的個數超過5時,自動刪除最舊的模型,以保證最多同時存在5個模型;如果設置為0或者None,則會對所有訓練中的模型進行保存,但是這樣除了多占硬盤外沒什么意義。
其他的參數一般就使用默認值就可以了。
saver = tf.train.Saver(max_to_keep=10)
有機會再補充其他參數的用法。
2.斷點數據的保存
使用saver對象的save方法即可保存模型:
save( sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False )
常用參數:
sess:需要保存的會話,一般就是我們程序中的sess;
save_path:保存模型的文件路徑以及名稱,例如“ckpt/my_model”,注意如果要保存在ckpt文件夾下,那么需要在ckpt后面加個斜杠/;
global_step:訓練次數,saver會自動將這個值加入到保存的文件名字中。
saver.save(sess,"my_model",global_step=1) saver.save(sess,"my_model",global_step=100) saver.save(sess,"ckpt/my_model",global_step=1)
其中1,2,3行代碼分別會:
1:在代碼的路徑下生成名為“my_model_1文件”;
2:在代碼的路徑下生成名為“my_model_100文件”;
3:在ckpt文件夾下生成名為“my_model_1文件”。
最常見的用法:
for epoch in range(n_iter): ''' training process ''' saver.save(sess,ckpt_dir+"model_name",global_step=epoch)
其中ckpt_dir是斷點數據存放的路徑。
3.斷點數據的恢復
3.1 只加載參數,不加載圖
需要先建立一個與之前相同的模型;然后再檢查有沒有斷點數據,如果有,則進行恢復。
''' 模型圖創建 ''' ckpt_dir = "ckpt/" #創建Saver對象 saver = tf.train.Saver() #如果有斷點文件,讀取最近的斷點文件 ckpt = tf.train.latest_checkpoint(ckpt_dir) if ckpt != None: saver.restore(sess,ckpt)
不需要提供模型的名字,tf.train.latest_checkpoint(ckpt_dir)會去ckpt_dir文件夾中自動尋找最新的模型文件。
這個方法要求模型圖建立好之后才允許創建saver,然后進行變量恢復,否則會報錯。
當我們基於checkpoint文件(ckpt)加載參數時,實際上我們使用Saver.restore取代了initializer的初始化。
3.2 圖結構與參數都加載
不需要自己建立模型圖了,全部靠加載:
import tensorflow as tf #獲取最新斷點數據路徑 ckpt = tf.train.latest_checkpoint("./ckpt/") #加載圖結構 saver = tf.train.import_meta_graph(ckpt+".meta") sess = tf.Session() #加載參數 saver.restore(sess,ckpt) #運行sess sess.run(tf.get_default_graph().get_tensor_by_name("x:0"))
可以通過 tf.get_default_graph().get_tensor_by_name("x:0")獲取模型節點,其中“x:0”是創建節點的時候節點的name。
4.模型文件解析
在程序訓練過程中保存的模型文件如下圖所示:
checkpoint文件會記錄保存信息,通過它可以定位最新保存的模型;
.meta文件保存了當前圖結構
.data文件保存了當前參數名和值
.index文件保存了輔助索引信息
至於文件名后面的數字表示的是模型訓練的不同批次,我們一般只需要最新的那個;由於之前設置最多保存5個模型,所以批次號是從6開始的。
4.1 查看checkpoint
ckpt = tf.train.get_checkpoint_state("./ckpt/") print(ckpt)
結果是文件的斷點狀態信息:
斷點狀態信息下有一個“model_checkpoint_path”屬性,屬性內容是最新的那個模型的路徑,用str類型來表示;
ckpt.model_checkpoint_path
這個與tf.train.latest_checkpoint("./ckpt/")得出的結果是相同的,可以通過這個路徑來加載模型參數。
4.2 通過data文件查看變量名和變量值
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file print_tensors_in_checkpoint_file("./ckpt/model.ckpt-10",None,True)
print_tensors_in_checkpoint_file中輸入的第一個參數即上一節中獲取到的模型路徑;結果會以字典的形式展現出來。

4.3 通過meta文件加載圖結構
saver = tf.train.import_meta_graph('./ckpt/model.ckpt-10.meta')
注意這里的參數是完整的路徑加上meta文件的文件名,后面需要加上“.meta”。
返回的是一個saver對象,這個對象中包含了之前模型的圖結構。