TF的模型文件
標簽(空格分隔): TensorFlow
Saver
tensorflow模型保存函數為:
tf.train.Saver()
當然,除了上面最簡單的保存方式,也可以指定保存的步數,多長時間保存一次,磁盤上最多保有幾個模型(將前面的刪除以保持固定個數),如下:
創建saver時指定參數:
saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)
其中:
- savable_variables指定待保存的變量,比如指定為tf.global_variables()保存所有global變量;指定為[v1, v2]保存v1和v2兩個變量;如果省略,則保存所有;
- max_to_keep指定磁盤上最多保有幾個模型;
- keep_checkpoint_every_n_hours指定多少小時保存一次。
保存模型時指定參數:
saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)
如上,其中可以指定模型文件名,步數,write_meta_graph則用來指定是否保存meta文件記錄graph等等。
示例:
import tensorflow as tf
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
saver.save(sess,"checkpoint/model.ckpt",global_step=1)
運行后,保存模型保存,得到四個文件:
- checkpoint
- model.ckpt-1.data-00000-of-00001
- model.ckpt-1.index
- model.ckpt-1.meta
checkpoint中記錄了已存儲(部分)和最近存儲的模型:
model_checkpoint_path: "model.ckpt-1"
all_model_checkpoint_paths: "model.ckpt-1"
...
meta file保存了graph結構,包括 GraphDef,SaverDef等,當存在meta file,我們可以不在文件中定義模型,也可以運行,而如果沒有meta file,我們需要定義好模型,再加載data file,得到變量值。
index file為一個string-string table,table的key值為tensor名,value為serialized BundleEntryProto。每個BundleEntryProto表述了tensor的metadata,比如那個data文件包含tensor、文件中的偏移量、一些輔助數據等。
data file保存了模型的所有變量的值,TensorBundle集合。
Restore
Restore模型的過程可以分為兩個部分,首先是創建模型,可以手動創建,也可以從meta文件里加載graph進行創建。
模型加載為:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/xx/model.ckpt.meta')
saver.restore(sess, "/xx/model.ckpt")
.meta文件中保存了圖的結構信息,因此需要在導入checkpoint之前導入它。否則,程序不知道checkpoint中的變量對應的變量。另外也可以:
# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Now load the checkpoint variable values
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "/xx/model.ckpt")
#saver.restore(sess, tf.train.latest_checkpoint('./'))
PS:不存在model.ckpt文件,saver.py中:Users only need to interact with the user-specified prefix... instead of any physical pathname.
當然,還有一點需要注意,並非所有的TensorFlow模型都能將graph輸出到meta文件中或者從meta文件中加載進來,如果模型有部分不能序列化的部分,則此種方法可能會無效。
使用Restore的模型
查看模型的參數
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
tvs = [v for v in tf.trainable_variables()]
for v in tvs:
print(v.name)
print(sess.run(v))
如名所言,以上是查看模型中的trainable variables;或者我們也可以查看模型中的所有tensor或者operations,如下:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
gv = [v for v in tf.global_variables()]
for v in gv:
print(v.name)
上面通過global_variables()獲得的與前trainable_variables類似,只是多了一些非trainable的變量,比如定義時指定為trainable=False的變量,或Optimizer相關的變量。
下面則可以獲得幾乎所有的operations相關的tensor:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
ops = [o for o in sess.graph.get_operations()]
for o in ops:
print(o.name)
首先,上面的sess.graph.get_operations()可以換為tf.get_default_graph().get_operations(),二者區別無非是graph明確的時候可以直接使用前者,否則需要使用后者。
此種方法獲得的tensor比較齊全,可以從中一窺模型全貌。不過,最方便的方法還是推薦使用tensorboard來查看,當然這需要你提前將sess.graph輸出。
直接使用原始模型進行訓練或測試
這種操作比較簡單,無非是找到原始模型的輸入、輸出即可。
只要搞清楚輸入輸出的tensor名字,即可直接使用TensorFlow中graph的get_tensor_by_name函數,建立輸入輸出的tensor:
with tf.get_default_graph() as graph:
data = graph.get_tensor_by_name('data:0')
output = graph.get_tensor_by_name('output:0')
從模型中找到了輸入輸出之后,即可直接使用其繼續train整個模型,或者將輸入數據feed到模型里,並前傳得到test輸出了。
需要說明的是,有時候從一個graph里找到輸入和輸出tensor的名字並不容易,所以,在定義graph時,最好能給相應的tensor取上一個明顯的名字,比如:
data = tf.placeholder(tf.float32, shape=shape, name='input_data')
preds = tf.nn.softmax(logits, name='output')
諸如此類。這樣,就可以直接使用tf.get_tensor_by_name(‘input_data:0’)之類的來找到輸入輸出了。
擴展原始模型
除了直接使用原始模型,還可以在原始模型上進行擴展,比如對1中的output繼續進行處理,添加新的操作,可以完成對原始模型的擴展,如:
with tf.get_default_graph() as graph:
data = graph.get_tensor_by_name('data:0')
output = graph.get_tensor_by_name('output:0')
logits = tf.nn.softmax(output)
使用原始模型的某部分
有時候,我們有對某模型的一部分進行fine-tune的需求,比如使用一個VGG的前面提取特征的部分,而微調其全連層,或者將其全連層更換為使用convolution來完成,等等。TensorFlow也提供了這種支持,可以使用TensorFlow的stop_gradient函數,將模型的一部分進行凍結。
with tf.get_default_graph() as graph:
graph.get_tensor_by_name('fc1:0')
fc1 = tf.stop_gradient(fc1)
# add new procedure on fc1