官方中文文檔的網址先貼出來:https://tensorflow.google.cn/programmers_guide/saved_model
tf.train.Saver
類別提供了保存和恢復模型的方法。tf.train.Saver
構造函數針對圖中所有變量或指定列表的變量將 save
和 restore
op 添加到圖中。Saver
對象提供了運行這些 op 的方法,指定了寫入或讀取檢查點文件的路徑。
TensorFlow 將變量保存在二進制檢查點文件中,簡略而言,這類文件將變量名稱映射到張量值。
保存變量
使用 tf.train.Saver()
創建 Saver
來管理模型中的所有變量。例如,以下代碼片段展示了如何調用 tf.train.Saver.save
方法以將變量保存到檢查點文件中:
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)
恢復變量
tf.train.Saver
對象不僅將變量保存到檢查點文件中,還將恢復變量。請注意,當您恢復變量時,您不必事先將其初始化。例如,以下代碼片段展示了如何調用 tf.train.Saver.restore
方法以從檢查點文件中恢復變量:
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Check the values of the variables
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
選擇要保存和恢復的變量(選擇網絡中的部分變量保存或者恢復)
如果您沒有向 tf.train.Saver()
傳遞任何參數,則 Saver 會處理圖中的所有變量。每個變量都保存在創建變量時所傳遞的名稱下。
在檢查點文件中明確指定變量名稱的這種做法有時會非常有用。例如,您可能已經使用名為"weights"
的變量訓練了一個模型,而您想要將該變量的值恢復到名為"params"
的變量中。
有時候,僅保存或恢復模型使用的變量子集也會很有裨益。例如,您可能已經訓練了一個五層的神經網絡,現在您想要訓練一個六層的新模型,並重用該五層的現有權重。您可以使用 Saver 只恢復這前五層的權重。
您可以向 tf.train.Saver()
的構造函數傳遞以下任一內容來輕松指定要保存或加載的名稱和變量:
- 變量列表(將以其本身的名稱保存)。
- Python 字典,其中,鍵是要使用的名稱,鍵值是要管理的變量。
繼續前面所示的保存/恢復示例:
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
# Add ops to save and restore only `v2` using the name "v2"
saver = tf.train.Saver({"v2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
# Initialize v1 since the saver will not.
v1.initializer.run()
saver.restore(sess, "/tmp/model.ckpt")
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
注意:
-
如果需要保存和恢復模型變量的不同子集,您可以根據需要創建任意數量的
Saver
對象。同一個變量可以列在多個 Saver 對象中,變量的值只有在Saver.restore()
方法運行時才會更改。 -
如果您僅在會話開始時恢復模型變量的子集,則必須為其他變量運行初始化 op。有關詳情,請參閱
tf.variables_initializer
。 -
要檢查某個檢查點的變量,您可以使用
inspect_checkpoint
庫,尤其是print_tensors_in_checkpoint_file
函數。 -
默認情況下,
Saver
會為每個變量使用tf.Variable.name
屬性的值。但是,當您創建一個Saver
對象時,您可以選擇為檢查點文件中的變量選擇名稱(此為可選操作)。
檢查某個檢查點的變量
可以使用 inspect_checkpoint
庫快速檢查某個檢查點的變量。
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)
# tensor_name: v1
# [ 1. 1. 1.]
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False)
# tensor_name: v1
# [ 1. 1. 1.]
# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensors=False)
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
保存和恢復模型概述
如果您想保存和加載變量、圖,以及圖的元數據 - 簡而言之,如果您想保存或恢復模型 - 我們推薦使用 SavedModel。SavedModel 是一種與語言無關,可恢復的密封式序列化格式。SavedModel 可讓較高級別的系統和工具創建、使用和變換 TensorFlow 模型。TensorFlow 提供了多種與 SavedModel 交互的機制,如 tf.saved_model API、Estimator API 和 CLI。
用於構建和加載 SavedModel 的 API