tf 模型保存


tf用 tf.train.Saver類來實現神經網絡模型的保存和讀取。無論保存還是讀取,都首先要創建saver對象。

 

用saver對象的save方法保存模型

保存的是所有變量

save(
    sess,
    save_path,
    global_step=None,  
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)

保存模型需要session,初始化變量

 

用法示例

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "Model/model.ckpt", global_step=3)

輸出

1. global_step 放在文件名后面,起個標記作用

2. save方法輸出4個文件

  // checkpoint 里面是一堆路徑,model_checkpoint_path 記錄了最新模型的路徑,all_model_checkpoint_paths 記錄了之前模型的路徑

  // model.ckpt-3.data-00000-of-00001 存放的是模型參數

  // model.ckpt-3.meta 存放的是計算圖

3. 最多只能保存近5次模型,比如我們迭代100次,每次保存一下,最后只留下了最近的5次。

 

用saver對象的restore方法加載模型

加載的是所有變量,以name為准,假如保存的模型中有變量叫 a ,value是2,那么在加載后,即使重新建立變量a,並賦其他value,其value仍然是2

restore(
    sess,
    save_path
)

加載模型需要session,不需要初始化變量

 

用法示例(接前例)

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22")           # Key v22 not found in checkpoint
result = v1 + v2

saver = tf.train.Saver()
#
with tf.Session() as sess:
    saver.restore(sess, "./Model/model.ckpt-3") # 注意此處路徑前添加"./"
    print(sess.run(result)) # [ 3.]

1. 重新給 name為 v2的變量 賦值,其結果仍然是3,說明加載了之前的v2

2. 新建name為 v22 的變量,報錯, 在保存的模型中沒找到v2 。說明尋找變量以name為准,不以變量名為准

 

繼續做如下嘗試

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22")           # Key v22 not found in checkpoint
result = v1 + v3

saver = tf.train.Saver()
#
with tf.Session() as sess:
    # sess.run(tf.global_variables_initializer())                     # Key v22 not found in checkpoint
    saver.restore(sess, "./Model/model.ckpt-3") # 注意此處路徑前添加"./"
    # sess.run(tf.global_variables_initializer())                       # Key v22 not found in checkpoint
    print(sess.run(result)) # [ 3.]

1. 新建name為v22的變量v3,仍然報錯,說明新的變量沒有被接受

2. 在加載模型前初始化v3,仍然報錯,加載模型后初始化v3,仍然報錯,這說明在加載的模型中不接受新的變量。

 

繼續嘗試

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22")           # Key v22 not found in checkpoint
result = v1 + v3

saver = tf.train.Saver()
#
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())                     # Key v22 not found in checkpoint
    print(sess.run(v3))                                             # [7.]
    saver.restore(sess, "./Model/model.ckpt-3") # 注意此處路徑前添加"./"
    sess.run(tf.global_variables_initializer())                       # Key v22 not found in checkpoint
    print(sess.run(result)) # [ 3.]

在加載模型前初始化變量,正確輸出,但在加載后,報錯,證實了我上面的說法,“不接受新的變量”

 

總結:

1. 模型加載加載的是所有變量,以name為准

2. 模型加載后不接受任何新的變量

3. 在加載模型時需要重新定義計算圖上的所有節點,但是變量無需初始化

 

加載計算圖

直接加載計算圖就無需重新定義計算圖上的節點

 

用法示例

saver = tf.train.import_meta_graph("Model/model.ckpt-3.meta")

with tf.Session() as sess:
    saver.restore(sess, "./Model/model.ckpt-3") # 注意路徑寫法
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))     # [3.]
    # print(sess.run(sess.graph.get_tensor_by_name('add:0')))                 # [3.]

 

重命名變量

在加載模型時不接受新的變量,這會造成很多麻煩。

為解決這個問題,加載模型時可以給變量重命名。

 

用法示例

u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = u1 + u2

# 若直接聲明Saver類對象,會報錯變量找不到
# 使用一個字典dict重命名變量即可,{"已保存的變量的名稱name": 重命名變量名}
# 原來名稱name為v1的變量現在加載到變量u1(名稱name為other-v1)中
saver = tf.train.Saver({"v1": u1, "v2": u2})

with tf.Session() as sess:
    saver.restore(sess, "./Model/model.ckpt-3")
    print(sess.run(result)) # [ 3.]

注意重命名格式  老變量的name: 新變量名

 

 

參考資料:

https://blog.csdn.net/marsjhao/article/details/72829635

https://blog.csdn.net/shuzfan/article/details/79197432


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM