tf.train.Saver類的使用
保存模型:
import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1') v2=tf.Variable(tf.constant(2.0,shape=[1]),name='v2') result=v1+v2 init_op=tf.global_variables_initializer() saver=tf.train.Saver() with tf.Session() as sess: sess.run(init_op) saver.save(sess,'log/model.ckpt')
加載模型:
import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1') v2=tf.Variable(tf.constant(2.0,shape=[1]),name='v2') result=v1+v2 saver=tf.train.Saver() with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state('log') if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path)
在加載模型時,也是先定義tensorflow計算圖上的所有運算,但不需要運行變量的初始化,因為變量的值可以通過已經保存的模型加載進來。如果不希望重復定義圖上的運算,也可以直接加載已經 持久化的圖。
加載計算圖:
import tensorflow as tf # 直接加載持久化的圖 saver=tf.train.import_meta_graph('log/model.ckpt.meta') with tf.Session() as sess: saver.restore(sess,'log/model.ckpt') # 通過張量的名稱來獲取張量 print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
tf.train.Saver類還支持在保存和加載模型時給變量重命名。
在加載模型時給變量重命名:
import tensorflow as tf # 這里聲明的變量名稱和已經保存的模型中變量的名稱不同。 v1=tf.Variable(tf.constant(1.0,shape=[1]),name='new-v1') v2=tf.Variable(tf.constant(2.0,shape=[1]),name='new-v2') # 直接使用tf.train.Saver()加載模型會提示變量找不到的錯誤# 需要使用一個字典來重命名變量。這個字典指定 # 原來名稱為v1的變量現在加載在變量v1中('new-v1'),名稱為v2的變量加載到 # 變量v2中('new-v2') saver=tf.train.Saver({'v1':v1,'v2':v2}) with tf.Session() as sess: saver.restore(sess, 'log/model.ckpt')
重命名的好處是可以方便使用變量的滑動平均值。使用變量的滑動平均值可以讓神經網絡模型更加健壯。在tensorflow中,每一個變量的滑動平均值是通過影子變量維護的,獲取變量的滑動平均值實際上就是獲取這個影子變量的取值。如果在加載模型時直接將影子變量映射到變量自身,那么在使用訓練好的模型時就不需要再調用函數來獲取變量的滑動平均值了。這樣方便了滑動平均模型的使用。以下代碼給出了一個保存滑動平均模型的樣例。
import tensorflow as tf v=tf.Variable(0,dtype=tf.float32,name='v') # 沒有聲明滑動平均模型時,只有一個變量v,所以下面語句只會輸出'v:0' for variables in tf.global_variables(): print(variables.name) ema=tf.train.ExponentialMovingAverage(0.99) maintain_averages_op=ema.apply(tf.global_variables()) # 在聲明滑動平均模型后,tensorflow會自動生成一個影子變量 # 下面語句會輸出:'v:0'和'v/ExponentialMovingAverage:0' for variables in tf.global_variables(): print(variables.name) saver=tf.train.Saver() with tf.Session() as sess: init_op=tf.global_variables_initializer() sess.run(init_op) sess.run(tf.assign(v,10)) sess.run(maintain_averages_op) # 保存時,tensorflow會將'v:0'和'v/ExponentialMovingAverage:0'兩個變量都保存下來 saver.save(sess,'log/model.ckpt') print(sess.run([v,ema.average(v)])) # 輸出:[10.0, 0.099999905]
基於上面的代碼,通過變量重命名直接讀取變量的滑動平均值。從程序輸出可以看出,讀取的變量v的值實際上是上面代碼中變量v的滑動平均值。通過該方法,就可以使用完全一樣的代碼來計算滑動平均模型前向傳播的結果。
v=tf.Variable(0,dtype=tf.float32,name='v') # 通過變量重命名將原來變量v的滑動平均值直接賦值給v saver=tf.train.Saver({'v/ExponentialMovingAverage':v}) with tf.Session() as sess: saver.restore(sess,'log/model.ckpt') print(sess.run(v)) # 輸出:0.099999905
為了方便加載時重命名滑動平均變量,tf.train.ExpoentialMovingAverage類提供了variables_to_restore函數來生成tf.train.Saver類所需要的變量重命名字典。示例代碼如下:
import tensorflow as tf v=tf.Variable(0,dtype=tf.float32,name='v') ema=tf.train.ExponentialMovingAverage(0.99) # 通過使用variables_to_restore函數來直接生成上面代碼中提供的字典 # {'v/ExponentialMovingAverage':v} # 以下代碼會輸出: # {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} print(ema.variables_to_restore()) saver=tf.train.Saver(ema.variables_to_restore()) with tf.Session() as sess: saver.restore(sess,'log/model.ckpt') print(sess.run(v)) # 輸出:0.099999905
tf.train.Saver的缺點就是每次會保存程序的全部信息,但有時並不需要全部信息。比如在測試或離線預測時,只需要知道如何從神經網絡的輸入層經過前向傳播計算得到輸出層即可,而不需要類似於變量初始化、模型保存等輔助結點的信息。而且,將變量取值和計算圖結構分成不同文件存儲有時候也不方便,tensorflow中提供了convert_variables_to_constants函數,可以將計算圖中的變量及其取值通過常量的方式保存,這樣可以將整個計算圖統一存放在一個文件中。示例代碼如下:
import tensorflow as tf from tensorflow.python.framework import graph_util v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1') v2=tf.Variable(tf.constant(2.0,shape=[1]),name='v2') result=v1+v2 init_op=tf.global_variables_initializer() with tf.Session() as sess: sess.run(init_op) # 導出當前計算圖的GraphDef部分,只需要這一部分就可以完成從輸入層到輸出層的計算 # 過程 graph_def=tf.get_default_graph().as_graph_def() # 將圖中的變量及其取值轉化為常量,同時將圖中不必要的結點去掉。在下面一行代碼中, # 最后一個參數['add']給出了需要保存的節點名稱。add節點是上面定義的兩個變量相加 # 的操作。注意,'add:0'表示某個計算節點的第一個輸出,是一個張量名。 output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['add']) # 將導出的模型存入文件 with tf.gfile.GFile('log/combined_model.pb','wb') as f: f.write(output_graph_def.SerializeToString())
通過下面的代碼可以直接計算定義的加法運算的結果。這種方法可以使用訓練的模型完成遷移學習
import tensorflow as tf with tf.Session() as sess: model_filename='log/combined_model.pb' # 讀取保存的模型文件,將文件解析成對應的GraphDef Protocol Buffer with tf.gfile.FastGFile(model_filename,'rb') as f: graph_def=tf.GraphDef() graph_def.ParseFromString(f.read()) # 將graph_def中保存的圖加載到當前圖中。return_elements=['add:0']給出了返回的 # 張量的名稱。在保存時給出的是計算節點的名稱,所以為'add'。在加載時給出的是 # 張量的名稱,所以是add:0 result=tf.import_graph_def(graph_def,return_elements=['add:0']) print(sess.run(result))