TensorFlow學習筆記(8)--網絡模型的保存和讀取【轉】


轉自:http://blog.csdn.net/lwplwf/article/details/62419087

 

之前的筆記里實現了softmax回歸分類、簡單的含有一個隱層的神經網絡、卷積神經網絡等等,但是這些代碼在訓練完成之后就直接退出了,並沒有將訓練得到的模型保存下來方便下次直接使用。為了讓訓練結果可以復用,需要將訓練好的神經網絡模型持久化,這就是這篇筆記里要寫的東西。

TensorFlow提供了一個非常簡單的API,即tf.train.Saver類來保存和還原一個神經網絡模型。


下面代碼給出了保存TensorFlow模型的方法:

import tensorflow as tf

# 聲明兩個變量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer() # 初始化全部變量 saver = tf.train.Saver() # 聲明tf.train.Saver類用於保存模型 with tf.Session() as sess: sess.run(init_op) print("v1:", sess.run(v1)) # 打印v1、v2的值一會讀取之后對比 print("v2:", sess.run(v2)) saver_path = saver.save(sess, "save/model.ckpt") # 將模型保存到save/model.ckpt文件 print("Model saved in file:", saver_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

這段代碼中,通過saver.save函數將TensorFlow模型保存到了save/model.ckpt文件中,這里代碼中指定路徑為"save/model.ckpt",也就是保存到了當前程序所在文件夾里面的save文件夾中。

TensorFlow模型會保存在后綴為.ckpt的文件中。保存后在save這個文件夾中實際會出現3個文件,因為TensorFlow會將計算圖的結構和圖上參數取值分開保存。

  • model.ckpt.meta文件保存了TensorFlow計算圖的結構,可以理解為神經網絡的網絡結構
  • model.ckpt文件保存了TensorFlow程序中每一個變量的取值
  • checkpoint文件保存了一個目錄下所有的模型文件列表

這里寫圖片描述


下面代碼給出了加載TensorFlow模型的方法:

可以對比一下v1、v2的值是隨機初始化的值還是和之前保存的值是一樣的?

import tensorflow as tf

# 使用和保存模型代碼中一樣的方式來聲明變量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") saver = tf.train.Saver() # 聲明tf.train.Saver類用於保存模型 with tf.Session() as sess: saver.restore(sess, "save/model.ckpt") # 即將固化到硬盤中的Session從保存路徑再讀取出來 print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的進行對比 print("v2:", sess.run(v2)) print("Model Restored")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

運行結果:

v1: [[ 0.76705766 1.82217288]] v2: [[-0.98012197 1.2369734 0.5797025 ] [ 2.50458145 0.81897354 0.07858191]] Model Restored
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

這段加載模型的代碼基本上和保存模型的代碼是一樣的。也是先定義了TensorFlow計算圖上所有的運算,並聲明了一個tf.train.Saver類。兩段唯一的不同是,在加載模型的代碼中沒有運行變量的初始化過程,而是將變量的值通過已經保存的模型加載進來。 
也就是說使用TensorFlow完成了一次模型的保存和讀取的操作。



如果不希望重復定義圖上的運算,也可以直接加載已經持久化的圖:

import tensorflow as tf
# 在下面的代碼中,默認加載了TensorFlow計算圖上定義的全部變量 # 直接加載持久化的圖 saver = tf.train.import_meta_graph("save/model.ckpt.meta") with tf.Session() as sess: saver.restore(sess, "save/model.ckpt") # 通過張量的名稱來獲取張量 print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

運行程序,輸出:

[[ 0.76705766 1.82217288]]
  • 1
  • 1

有時可能只需要保存或者加載部分變量。 
比如,可能有一個之前訓練好的5層神經網絡模型,但現在想寫一個6層的神經網絡,那么可以將之前5層神經網絡中的參數直接加載到新的模型,而僅僅將最后一層神經網絡重新訓練。

為了保存或者加載部分變量,在聲明tf.train.Saver類時可以提供一個列表來指定需要保存或者加載的變量。比如在加載模型的代碼中使用saver = tf.train.Saver([v1])命令來構建tf.train.Saver類,那么只有變量v1會被加載進來。

…未完待續


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM