近日,在使用Siamese網絡實現西儲大學軸承數據故障診斷中,測試的過程出現了
1 ValueError: The passed save_path is not a valid checkpoint
的錯誤。錯誤是由於在測試的過程中導入checkpoint時,傳入的save_path是無效的,或者是說,傳入的save_path在給定的路徑中沒有找到對應的文件。
網上關於該問題的解決方案主要包含兩個方面:
- checkpoint路徑應該使用相對路徑;
- 路徑字符不要太長
但均沒有從本質上解決遇到的問題,也沒有從源頭講明白bug出現的緣由。
Tensorflow會將模型保存生成四個文件,如下圖所示。
- 圖a的情況是模型保存時,僅傳入了地址,而地址中不包含文件的名稱。
如第6行代碼所示,傳入的save_path中只包含要保存checkpoint的路徑,未聲明保存文件的名稱。
在這種情況下,checkpoint_dir可以直接作為路徑傳入模型恢復save.restore()的函數中。
1 ... 2 TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now()) 3 checkpoint_dir = os.path.join(".\\checkpoint\\%s" % (TIMESTAMP)) 4 save = tf.train.Saver() 5 ... 6 save_path = save.save(sess,save_path=checkpoint_dir+"\\")
- 圖b的情況是模型保存時,地址中添加了需要保存文件的名稱filename,並且在save聲明時,使用了max_to_keep=1的設置,即保存的文件名稱中,在XXX.ckpt后包含 "-1" 的名稱,其表示當前保存模型的訓練代數。
在這種情況下,使用當前的checkpoint_dir作為模型恢復saver.restore()函數中的路徑,將會報錯。
1 ValueError: The passed save_path is not a valid checkpoint
1 ... 2 TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now()) 3 checkpoint_dir = ".\\checkpoint\\%s" % TIMESTAMP 4 filename = "few_shot_learning_fault_diagnosis" 5 checkpoint_dir = os.path.join(checkpoint_dir, filename+".ckpt") 6 save = tf.train.Saver(max_to_keep=1) 7 ... 8 saver.save(sess, diag_obj.checkpoint_dir, global_step=step)
總結: 在編寫時,如果使用的是save = tf.train.Saver() 使用了max_to_keep=1的設置,並且在模型訓練保存的過程中,是每訓練一代保存一次。 此時,checkpoint_dir將不再適用於save.restore(sess, checkpoint_dir)中的checkpoint_dir。因為從圖b中可以看到,其包含-1(訓練代數的后綴)。如果仍將checkpoint_dir作為模型參數讀入的地址傳入save.restore()中,將會報
1 ValueError: The passed save_path is not a valid checkpoint
的錯誤。
【解決方法:】
使用tf.train.latest_checkpoint()函數,將不包含文件名稱的路徑傳入函數中,獲取到文件的路徑module_file,並將其傳入saver.restore()中,便可以解決上述問題。
1 ... 2 module_file = tf.train.latest_checkpoint(diag_obj.save_path) 3 saver.restore(sess, module_file)
module_file獲取到的結果如下所示, 其包含訓練代數的信息,這也是為什么直接使用原始的checkpoint_dir 會報錯的原因。
1 module_dir: few_shot_learning_fault_diagnosis.ckpt-1
該錯誤的原因是由於每一代保存一次而造成的設置的保存文件名稱與實際保存文件名稱不一致。