Tensorflow 模型的保存、讀取和凍結、執行


轉載自https://www.jarvis73.cn/2018/04/25/Tensorflow-Model-Save-Read/

本文假設讀者已經懂得了 Tensorflow 的一些基礎概念, 如果不懂, 則移步 TF 官網 .

在 Tensorflow 中我們一般使用 tf.train.Saver() 定義的存儲器對象來保存模型, 並得到形如下面列表的文件:

checkpoint
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

其中 checkpoint 文件中記錄了該儲存器歷史上所有保存過的模型(三件套文件)的名稱, 以及最近一次保存的文件, 這里我們並不需要 checkpoint .

Tensorflow 模型凍結是指把計算圖的定義模型權重合並到同一個文件中, 可以按照以下步驟實施:

  • 恢復已保存的計算圖: 把預先保存的計算圖(meta graph) 載入到默認的計算圖中, 並將計算圖序列化.
  • 加載權重: 開啟一個會話(Session), 把權重載入到計算圖中
  • 刪除推導所需以外的計算圖元數據(metadata): 凍結模型之后是不需要訓練的, 所以只保留推導(inference) 部分的計算圖 (這部分可以通過指定模型輸出來自動完成)
  • 保存到硬盤: 序列化凍結的 graph_def 協議緩沖區(Protobuf) 並轉儲到硬盤

注意: 前兩步實際上就是 Tensorflow 中的加載計算圖和權重, 關鍵的部分就是圖的凍結, 而凍結 TF 已經提供了函數.

1. 模型的保存

TF 使用 saver = tf.train.Saver() 定義一個存儲器對象, 然后使用 saver.save() 函數保存模型. saver 定義時可以指定需要保存的變量列表, 最大的檢查點數量, 是否保存計算圖等. 官網例子如下:

v1 = tf.Variable(..., name='v1') v2 = tf.Variable(..., name='v2') # 使用字典指定要保存的變量, 此時可以為每個變量重命名(保存的名字) saver = tf.train.Saver({'v1': v1, 'v2': v2}) # 使用列表指定要保存的變量, 變量名字不變. 以下兩種保存方式等價 saver = tf.train.Saver([v1, v2]) saver = tf.train.Saver({v.op.name: v for v in [v1, v2]}) # 保存相應變量到指定文件, 如果指定 global_step, 則實際保存的名稱變為 model.ckpt-xxxx saver.save(sess, "./model.ckpt", global_step) 

每保存一次, 就會產生前言所述的四個文件, 其中 checkpoint 文件會更新. 其中 saver.save() 函數的 write_meta_graph 參數默認為 True , 即保存權重時同時保存計算圖到 meta 文件.

2. 模型的讀取

TF 模型的讀取分為兩種, 一種是我們僅讀取模型變量, 即 index 文件和 data 文件; 另一種是讀取計算圖. 通常來說如果是我們自己保存的模型, 那么完全可以設置 saver.save() 函數的 write_meta_graph 參數為 False以節省空間和保存的時間, 因為我們可以使用已有的代碼直接重新構建計算圖. 當然如果為了模型遷移到其他地方, 則最好同時保存變量和計算圖.

2.1 讀取計算圖

2.1.1 讀取計算圖核心函數

從 meta 文件讀取計算圖使用 tf.train.import_meta_graph() 函數, 比如:

with tf.Session() as sess: new_saver = tf.train.import_meta_graph("model.ckpt.meta") 

此時計算圖就會加載到 sess 的默認計算圖中, 這樣我們就無需再次使用大量的腳本來定義計算圖了. 實際上使用上面這兩行代碼即可完成計算圖的讀取. 注意可能我們獲取的模型(meta文件)同時包含定義在CPU主機(host)和GPU等設備(device)上的, 上面的代碼保留了原始的設備信息. 此時如果我們想同時加載模型權重, 那么如果當前沒有指定設備的話就會出現錯誤, 因為tensorflow無法按照模型中的定義把某些變量(的值)放在指定的設備上. 那么有一個辦法是增加一個參數清楚設備信息.

with tf.Session() as sess: new_saver = tf.train.import_meta_graph("model.ckpt.meta", clear_devices=True) 

2.1 節剩下的內容我們嘗試探索一下 TF 中圖的一些內容和基本結構, 不感興趣可以跳過直接看 2.2 節.

2.1.2 獲取計算圖內的任意變量/操作

接下來可以使用 get_all_collection_keys() 來獲取該計算圖中所有的收集器的鍵:

sess.graph.get_all_collection_keys() # 或 sess.graph.collections # 或 tf.get_default_graph().get_all_collection_keys() # 輸出 ['summaries', 'train_op', 'trainable_variables', 'variables'] 

進一步我們可以通過 get_collection() 函數來獲取每個收集器的內容:

from pprint import pprint pprint(sess.graph.get_collection("summaries")) pprint(sess.graph.get_collection("variables")) ... 

通過瀏覽 variables , trainable_variables , sumamries 和 train_op 中的變量我們可以初步推斷計算圖的結構和重要信息. 此外, 讀取計算圖后還可以直接使用 tf.summary.FileWriter() 保存計算圖到 tensorboard, 從而獲得更直觀的計算圖.

要注意的是, get_collection() 方法只能獲得保存在收集器中的變量, 而無法看到其他操作(如 placeholder), 除非在腳本中構建計算圖時刻意把某些操作加入到某個 collection . 所以我們可以用更騷的方法來獲取這些沒有包含在 collection 中的操作:

sess.graph.get_operations() # 或 for op in sess.graph.get_operations(): print(op.name, op.values()) 

函數 get_operations() 返回一個列表, 列表的每個元素均為計算圖中的一個 Operation 對象. 舉個栗子, 當我們使用 reshape() 函數時 tf.reshape(x, [-1, 28, 28, -1]) 在計算圖中會產生這樣的計算節點

圖 1: Tensorboard 中操作 tf.reshape(x, shape) 的計算圖

其中 x 就是上圖中左下角的 input , 右側的小柱狀圖表示我對 Reshape 的輸出做了 summary 並命名為 input . Tensorboard 中類似於 shape 這樣的小圓點表示常數(類型仍然是 Operation), 點擊后可以看到該操作的屬性

圖 2: Tensorboard 中常量 shape 的屬性

而屬性中的 tensor_content 的值就是該常數被賦予的值. 實際上我們也可以通過代碼開查看計算圖中操作的屬性:

sess.graph.get_operation_by_name("input_reshape/Reshape/shape").node_def 

通過名稱索引該 reshape 操作, 並獲取其 node_def 屬性即可得到和圖 2 相同的信息. 注意到, shape 的值是一個字符串 "\377\377\377\377\034\000\000\000\034\000\000\000\001\000\000\000" , 該字符串可以這么理解: 沒餓過形如 \377 的單元表示一個字節, 該字節用八進制來表示, 比如 \377 還原為二進制為 011 111 111, 由於我們可以看到該常量的類型為 DT_INT32, 即四個字節, 所以每四個字節拼成一個長整型數字, 即 \377\377\377\377 表示成十六進制為FFFFFFFF , 十進制為 -1; 而 \034\000\000\000 (注意這里是小端表示法, litter endian, 即從后往前讀取字節)表示成十六進制為 1C000000 , 十進制為 28 .

2.2 讀取模型變量

2.2.1 讀取模型變量核心函數

讀取模型權重也很簡單, 仍然使用 tf.train.Saver() 來讀取:

# 首先定義一系列變量 ... # 載入變量的值 saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, "path/to/model.ckpt") 

注意模型路徑中應當以諸如 .ckpt 之類的來結尾, 即需要保證實際存在的文件是 model.ckpt.data-00000-of-00001 和 model.ckpt.index , 而指定的路徑是 model.ckpt 即可. 類似地, 如果我們只需要載入部分模型變量, 則和保存模型變量類似地可以在 tf.train.Saver() 中使用字典或列表來指定相應的變量. 注意, 載入的模型變量是不需要再初始化的(即不需要 tf.variable_initializer() 初始化), 所以如果只載入部分變量, 則要么手動指定, 要么先初始化所有的變量, 再從檢查點載入變量的值.

2.2.2 獲取任意模型變量的屬性

另外, 我們還可以使用 TF 內置的函數 tf.train.get_checkpoint_state() 來獲得最近的一次檢查點的文件名:

ckpt = tf.train.get_checkpoint_state(log_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) 

有時候我們需要瀏覽變量中變量名/形狀/值, 則可以預先通過下面的代碼進行查看:

from tensorflow.python import pywrap_tensorflow as pt reader = pt.NewCheckpointReader("path/to/model.ckpt") # 獲取 變量名: 形狀 vars = reader.get_variable_to_shape_map() for k in sorted(vars): print(k, vars[k]) # 獲取 變量名: 類型 vars = reader.get_variable_to_dtype_map() for k in sorted(vars): print(k, vars[k]) # 獲取張量的值 value = reader.get_tensor("tensor_name") 

其中 get_variable_to_shape_map() 函數會生成一個 {變量名: 形狀} 的字典, 而 get_variable_to_dtype_map()類似. 而 get_tensor() 函數會返回相應變量名的變量值, 返回一個 numpy 數組.

另一種獲取方法則是 TF 官方文檔給出的使用 tensorflow.python.tools.insepct_checkpoint , 示例代碼如下, 不再贅述:

from tensorflow.python.tools import inspect_checkpoint as chkp # 打印檢查點所有的變量 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True) # tensor_name: v1 # [ 1. 1. 1.] # tensor_name: v2 # [-1. -1. -1. -1. -1.] # 僅打印檢查點中的 v1 變量 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False) # tensor_name: v1 # [ 1. 1. 1.] 

3. 模型的凍結

我們從已有的三個檢查點文件出發生成凍結模型:

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

假設我們已經通過上面模型的讀取知道了我們需要的最終輸出的張量名為 "Accuracy/prediction" 和 "Metric/Dice" , 則按照前言部分的步驟來凍結模型:

import tensorflow as tf # 指定模型輸出, 這樣可以允許自動裁剪無關節點. 這里認為使用逗號分割 output_nodes = ["Accuracy/prediction", "Metric/Dice"] # 1. 加載模型 saver = tf.train.import_meta_graph("model.ckpt.meta", clear_devices=True) with tf.Session(graph=tf.get_default_graph()) as sess: # 序列化模型 input_graph_def = sess.graph.as_graph_def() # 2. 載入權重 saver.restore(sess, "model.ckpt") # 3. 轉換變量為常量 output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_nodes) # 4. 寫入文件 with open("frozen_model.pb", "wb") as f: f.write(output_graph_def.SerializeToString()) 

注意, 我們凍結模型的目的是不再訓練, 而僅僅做正向推導使用, 所以才會把變量轉換為常量后同計算圖結構保存在協議緩沖區文件(.pb)中, 因此需要在計算圖中預先定義輸出節點的名稱.

4. 模型的執行

模型的執行過程也很簡單, 首先從協議緩沖區文件(*.pb)中讀取模型, 然后導入計算圖

# 讀取模型並保存到序列化模型對象中 with open(frozen_graph_path, "rb") as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) # 導入計算圖 graph = tf.Graph() with graph.as_default(): tf.import_graph_def(graph_def, name="MyGraph") 

之后就是獲取輸入和輸出的張量對象, 注意, 在 TF 的計算圖結構中, 我們只能使用 feed_dict 把數值數組傳入張量 Tensor , 同時也只能獲取張量的值, 而不能給Operation 賦值. 由於我們導入序列化模型到計算圖時給定了 name 參數, 所以導入所有操作都會加上 MyGraph 前綴.

接下來我們獲取輸入和輸出對應的張量:

x_tensor = graph.get_tensor_by_name("MyGraph/input/image-input:0") y_tensor = graph.get_tensor_by_name("MyGraph/input/label-input:0") keep_prob = graph.get_tensor_by_name("MyGraph/dropout/Placeholder:0") y_target_tensor = graph.get_tensor_by_name("MyGraph/accuracy/accuracy:0") 

注意 TF 中的張量名均是 op:num 的形式, 其中的 op 表示產生該張量的操作名(可由 tensor.op.name 獲取), 而冒號后面的數字表示該張量是其對應操作的第幾個輸出, 下面的圖給出了張量和操作名的關系

圖 3: Tensorflow 中張量和相應操作的命名關系

最后我們提取 mnist 的數據, 並執行驗證:

from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("mnist_data", one_hot=True) x_values, y_values = mnist.test.next_batch(10000) with tf.Session(graph=graph) as sess: acc = sess.run(acc_tensor, feed_dict={x_tensor: x_values, y_tensor: y_values, keep_prob: 1.0}) print(acc) # 輸出 0.9665


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM