『TensorFlow』模型保存和載入方法匯總


『TensorFlow』第七彈_保存&載入會話_霸王回馬

一、TensorFlow常規模型加載方法

保存模型

tf.train.Saver()類,.save(sess, ckpt文件目錄)方法

參數名稱 功能說明 默認值
var_list Saver中存儲變量集合 全局變量集合
reshape 加載時是否恢復變量形狀 True
sharded 是否將變量輪循放在所有設備上 True
max_to_keep 保留最近檢查點個數 5
restore_sequentially 是否按順序恢復變量,模型較大時順序恢復內存消耗小 True

 

var_list是字典形式{變量名字符串: 變量符號},相對應的restore也根據同樣形式的字典將ckpt中的字符串對應的變量加載給程序中的符號。

如果Saver給定了字典作為加載方式,則按照字典來,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否則每個變量尋找自己的name屬性在ckpt中的對應值進行加載。

加載模型

當我們基於checkpoint文件(ckpt)加載參數時,實際上我們使用Saver.restore取代了initializer的初始化

checkpoint文件會記錄保存信息,通過它可以定位最新保存的模型:

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)

 

.meta文件保存了當前圖結構

.data文件保存了當前參數名和值

.index文件保存了輔助索引信息

.data文件可以查詢到參數名和參數值,使用下面的命令可以查詢保存在文件中的全部變量{名:值}對,

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(os.path.join(savedir,savefile),None,True)

tf.train.import_meta_graph函數給出model.ckpt-n.meta的路徑后會加載圖結構,並返回saver對象

ckpt = tf.train.get_checkpoint_state('./model/')

tf.train.Saver函數會返回加載默認圖的saver對象,saver對象初始化時可以指定變量映射方式,根據名字映射變量(『TensorFlow』滑動平均)

saver = tf.train.Saver({"v/ExponentialMovingAverage":v})  

saver.restore函數給出model.ckpt-n的路徑后會自動尋找參數名-值文件進行加載

saver.restore(sess,'./model/model.ckpt-0')
saver.restore(sess,ckpt.model_checkpoint_path)

1.不加載圖結構,只加載參數

由於實際上我們參數保存的都是Variable變量的值,所以其他的參數值(例如batch_size)等,我們在restore時可能希望修改,但是圖結構在train時一般就已經確定了,所以我們可以使用tf.Graph().as_default()新建一個默認圖(建議使用上下文環境),利用這個新圖修改和變量無關的參值大小,從而達到目的。

'''
使用原網絡保存的模型加載到自己重新定義的圖上
可以使用python變量名加載模型,也可以使用節點名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

with tf.Graph().as_default() as g:

    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
    y = Net.inference_1(x, N_CLASS=5, train=False)

    with tf.Session() as sess:
        # 程序前面得有 Variable 供 save or restore 才不報錯
        # 否則會提示沒有可保存的變量
        saver = tf.train.Saver()

        ckpt = tf.train.get_checkpoint_state('./model/')
        img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
        img = sess.run(tf.expand_dims(tf.image.resize_images(
            tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))

        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess,'./model/model.ckpt-0')
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            res = sess.run(y, feed_dict={x: img})
            print(global_step,sess.run(tf.argmax(res,1)))

  2.加載圖結構和參數

'''
直接使用使用保存好的圖
無需加載python定義的結構,直接使用節點名稱加載模型
由於節點形狀已經定下來了,所以有不便之處,placeholder定義batch后單張傳會報錯
現階段不推薦使用,以后如果理解深入了可能會找到使用方法
'''
import AlexNet_train as train
import random
import tensorflow as tf

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'


ckpt = tf.train.get_checkpoint_state('./model/')                          # 通過檢查點文件鎖定最新的模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')   # 載入圖結構,保存在.meta文件中

with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)                        # 載入參數,參數保存在兩個文件中,不過restore會自己尋找

    img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
    img = sess.run(tf.image.resize_images(
        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
    imgs = []
    for i in range(128):
       imgs.append(img)
    print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))

    '''
    img = sess.run(tf.expand_dims(tf.image.resize_images(
        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
    print(img)
    imgs = []
    for i in range(128):
        imgs.append(img)
    print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),
                   feed_dict={'Placeholder:0':img}))

注意,在所有兩種方式中都可以通過調用節點名稱使用節點輸出張量,節點.name屬性返回節點名稱。

  3.簡化版本

# 連同圖結構一同加載
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)
            
# 只加載數據,不加載圖結構,可以在新圖中改變batch_size等的值
# 不過需要注意,Saver對象實例化之前需要定義好新的圖結構,否則會報錯
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    saver.restore(sess,ckpt.model_checkpoint_path)

二、TensorFlow二進制模型加載方法

這種加載方法一般是對應網上各大公司已經訓練好的網絡模型進行修改的工作

# 新建空白圖
self.graph = tf.Graph()
# 空白圖列為默認圖
with self.graph.as_default():
    # 二進制讀取模型文件
    with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
        # 新建GraphDef文件,用於臨時載入模型中的圖 
        graph_def = tf.GraphDef()
        # GraphDef加載模型中的圖
        graph_def.ParseFromString(f.read())
        # 在空白圖中加載GraphDef中的圖
        tf.import_graph_def(graph_def,name='')
        # 在圖中獲取張量需要使用graph.get_tensor_by_name加張量名
        # 這里的張量可以直接用於session的run方法求值了
        # 補充一個基礎知識,形如'conv1'是節點名稱,而'conv1:0'是張量名稱,表示節點的第一個輸出張量
        self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
        self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name   in self.layer_operation_names]

 『TensorFlow』遷移學習_他山之石,可以攻玉

『cs231n』通過代碼理解風格遷移

上面兩篇都使用了二進制加載模型的方式

三、二進制模型制作

這節是關於tensorflow的Freezing,字面意思是冷凍,可理解為整合合並;整合什么呢,就是將模型文件和權重文件整合合並為一個文件,主要用途是便於發布。

tensorflow在訓練過程中,通常不會將權重數據保存的格式文件里(這里我理解是模型文件),反而是分開保存在一個叫checkpoint的檢查點文件里,當初始化時,再通過模型文件里的變量Op節點來從checkoupoint文件讀取數據並初始化變量。這種模型和權重數據分開保存的情況,使得發布產品時不是那么方便,我們可以將tf的圖和參數文件整合進一個后綴為pb的二進制文件中,由於整合過程回將變量轉化為常量,所以我們在日后讀取模型文件時不能夠進行訓練,僅能向前傳播,而且我們在保存時需要指定節點名稱。

將圖變量轉換為常量的API:tf.graph_util.convert_variables_to_constants

轉換后的graph_def對象轉換為二進制數據(graph_def.SerializeToString())后,寫入pb即可。

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './tmodel/test_model.ckpt')
    gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), ['add'])
with tf.gfile.GFile('./tmodel/model.pb', 'wb') as f:
    f.write(gd.SerializeToString())

我們可以直接查看gd:

node {
  name: "v1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
        }
        float_val: 1.0
      }
    }
  }
}
……
node {
  name: "add"
  op: "Add"
  input: "v1/read"
  input: "v2/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}

四、從圖上讀取張量

上面的代碼實際上已經包含了本小節的內容,但是由於從圖上讀取特定的張量是如此的重要,所以我仍然單獨的補充上這部分的內容。

無論如何,想要獲取特定的張量我們必須要有張量的名稱圖的句柄,比如 'import/pool_3/_reshape:0' 這種,有了張量名和圖,索引就很簡單了。

從二進制模型加載張量

第二小節的代碼很好的展示了這種情況

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # 瓶頸層輸出張量名稱
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'  # 輸入層張量名稱
MODEL_DIR = './inception_dec_2015'  # 模型存放文件夾
MODEL_FILE = 'tensorflow_inception_graph.pb'  # 模型名


# 加載模型
# with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),'rb') as f:   # 閱讀器上下文
with open(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:  # 閱讀器上下文
    graph_def = tf.GraphDef()  # 生成圖
    graph_def.ParseFromString(f.read())  # 圖加載模型
# 加載圖上節點張量(按照句柄理解)
bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(  # 從圖上讀取張量,同時導入默認圖
    graph_def,
    return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

從當前圖中獲取對應張量

這個就是很普通的情況,從我們當前操作的圖中獲取某個張量,用於feed啦或者用於輸出等操作,API也很簡單,用法如下:

g.get_tensor_by_name('import/pool_3/_reshape:0')

 g表示當前圖句柄,可以簡單的使用 g = tf.get_default_graph() 獲取。

從圖中獲取節點信息

有的時候我們對於模型中的節點並不夠了解,此時我們可以通過圖句柄來查詢圖的構造:

g = tf.get_default_graph()
print(g.as_graph_def().node)

這個操作將返回圖的構造結構。從這里,對比前面的代碼,我們也可以了解到:graph_def 實際就是圖的結構信息存儲形式,我們可以將之還原為圖(二進制模型加載代碼中展示了),也可以從圖中將之提取出來(本部分代碼)。

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM