TensorFlow 模型保存和導入、加載


在TensorFlow中,保存模型與加載模型所用到的是tf.train.Saver()這個類。我們一般的想法就是,保存模型之后,在另外的文件中重新將模型導入,我可以利用模型中的operation和variable來測試新的數據。


什么是TensorFlow中的模型

首先,我們先來理解一下TensorFlow里面的模型是什么。在保存模型后,一般會出現下面四個文件:

這里寫圖片描述

meta graph:保存了TensorFlow的graph。包括all variables,operations,collections等等。這個文件就是上面的.meta文件。

checkpoint files:二進制文件,保存了所有weights,biases,gradient and all the other variables的值。也就是上圖中的.data-00000-of-00001和.index文件。.data文件包含了所有的訓練變量。以前的TensorFlow版本是一個ckpt文件,現在就是這兩個文件了。與此同時,Tensorflow還有一個名為checkpoint的文件,只保存最新檢查點文件的記錄,即最新的保存路徑。


保存一個TensorFlow的模型

在TensorFlow中,如果想保存一個圖(graph)或者所有的參數的值,那么就需要用到tf.train.Saver()這個類。

import tensorflow as tf
saver = tf.train.Saver()
sess = tf.Session()
saver.save(sess, 'my_test_model')

  

 

上面這段代碼最后一句就是保存模型,第二個參數是一個路徑(包含模型的名字)。當然還有其他的形參,我們接下來講:
global_step:給一個數字,用於保存文件時tensorflow幫你命名。主要是說明了迭代多次后保存了。
write_meta_graph:bool型,說明要不要把TensorFlow的圖保存下來。
關於save函數更多的說明請參考:
https://www.tensorflow.org/api_docs/python/tf/train/Saver#save

例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

  


導入一個訓練好的模型

前門講了如何保存一個模型,現在要把模型導出來用了。

訓練好的模型,.meta文件中已經保存了整個graph,我們無需重建,只要導入.meta文件即可。

with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')#這個函數就是講graph導出來

  

下面用一個例子來說明一下,直接上完整代碼:

第一個文件,訓練模型並保存模型:

#定義模型
X = tf.placeholder(tf.float32,shape = [None,x_dim],name = 'X')
Y = tf.placeholder(tf.float32,shape = [None,1], name = 'Y')
W = tf.Variable(tf.random_normal([x_dim,1]),name='weight')
b = tf.Variable(tf.random_normal([1]),name='bias')
hypothesis = tf.sigmoid(tf.matmul(X,W)+b)
cost = -tf.reduce_mean(Y*tf.log(hypothesis) + (1-Y)*tf.log(1-hypothesis))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)

#假如想要保存hypothesis和cost,以便在保存模型后,重新導入模型時可以使用。
tf.add_to_collection('hypothesis',hypothesis)#必須有個名字,即第一個參數
tf.add_to_collection('cost',cost)

mysaver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(50):
    avg_cost, _ = sess.run([cost,train],feed_dict = {X:x_data,Y:y_data})

mysaver.save(sess, '../model/model_LR_test') #保存模型

  

第二個文件,加載模型,並利用訓練好的模型預測:

sess = tf.Session()
#本來我們需要重新像上一個文件那樣重新構建整個graph,但是利用下面這個語句就可以加載整個graph了,方便
new_saver = tf.train.import_meta_graph('../model/model_LR_test.meta')
new_saver.restore(sess,'../model/model_LR_test')#加載模型中各種變量的值,注意這里不用文件的后綴

#對應第一個文件的add_to_collection()函數
hyp = tf.get_collection('hypothesis')[0] #返回值是一個list,我們要的是第一個,這也說明可以有多個變量的名字一樣。

graph = tf.get_default_graph() 
X = graph.get_operation_by_name('X').outputs[0]#為了將placeholder加載出來

pred = sess.run(hyp,feed_dict = {X:x_valid})
print('auc:',auc(y_valid,pred))

是這樣的,使用TensorFlow構建模型的時候,如果一些operation想要在加載模型時用到。那么需要使用add_to_collection()函數來將operation存起來。然后再加載模型后可以調用。當然tensorflow無論怎樣都需要給每個東西一個名字(string型),只有通過名字才可以找到對應的operation。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM