TensorFlow——訓練模型的保存和載入的方法介紹


我們在訓練好模型的時候,通常是要將模型進行保存的,以便於下次能夠直接的將訓練好的模型進行載入。

1.保存模型

首先需要建立一個saver,然后在session中通過saver的save即可將模型保存起來,具體的代碼流程如下

# 前面的是定義好的模型結構

# 前面的代碼是模型的定義代碼

saver = tf.train.Saver()    # 生成saver
 
with tf.Session() as sess:
    sess.run(init)          # 模型的初始化
    # 
    # 模型的訓練代碼,當模型訓練完畢后,下面就可以對模型進行保存了
    # 
    saver.save(sess, "model/linear")     # 當路徑不存在時,會自動創建路徑

2.載入模型

將模型保存后,在保存的路徑中,可以看到生成的模型路徑,下面我們就能夠加載模型了:

saver = tf.train.Saver()

with tf.Session() as sess:
    # 可以對模型進行初始化,也可以不進行模型的初始化,因為后面的加載會覆蓋之前的
    # 初始化操作
    sess.run(init)

    saver.restore(sess, "model/linear")

下面我們以linearmodel為例進行講解:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver()


if __name__ == '__main__':
    with tf.Session() as sess:
        sess.run(init)
        if os.path.exists("model/"):
            saver.restore(sess, "model/linear")

            w_, b_ = sess.run([w, b])

            print(" Finished ")
            print("W: ", w_, " b: ", b_)
            plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
            plt.grid(True)
            plt.show()
        else:
            loss_list = []
            for epoch in range(training_epochs):
                for (x, y) in zip(train_x, train_y):
                    sess.run(optimizer, feed_dict={X: x, Y: y})

                if epoch % display_step == 0:
                    loss = sess.run(cost, feed_dict={X: x, Y: y})
                    loss_list.append(loss)
                    print('Iter: ', epoch, ' Loss: ', loss)

            w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

            saver.save(sess, "model/linear")

            print(" Finished ")
            print("W: ", w_, " b: ", b_, " loss: ", loss)
            plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
            plt.grid(True)
            plt.show()

3.查看模型的內容

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
modeldir = 'model/'
print_tensors_in_checkpoint_file(modeldir + 'linear.cpkt', None, True)

在上述使用saver的代碼中,我們還可以將參數放入Saver中實現指定存儲參數的功能,可以指定存儲變量名字和變量的對應關系,如下形式:

saver = tf.train.Saver({'weight_':w, 'bias_':b})
# saver = tf.train.Saver([w, b])

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM