記錄:tf.saved_model 模塊的簡單使用(TensorFlow 模型存儲與恢復)


雖然說 TensorFlow 2.0 即將問世,但是有一些模塊的內容卻是不大變化的。其中就有 tf.saved_model 模塊,主要用於模型的存儲和恢復。為了防止學習記錄文件丟失或者蠢笨的腦子直接遺忘掉這部分內容,在此做點簡單的記錄,以便將來查閱。

最近為了一個課程作業,不得已涉及到關於圖像超分辨率恢復的內容,不得不准備隨時存儲訓練的模型,只好再回過頭來瞄一眼 TensorFlow 文檔,真是太痛苦了。

tf.saved_model 模塊下面有很多文件和函數,精力有限,只好選擇於自己有用的東西來看,可能並不全面,望日后補上。

其中最重要的就是該模塊下的一個類:tf.saved_model.builder.SavedModelBuilder

tf.saved_model.builder.SavedModelBuilder:

# 構造函數
.__init__(export_dir)
"""
作用:
  創建一個保存模型的實例對象
參數:
    export_dir: 模型導出路徑,由於 TensorFlow 會在你指定的路徑上創建文件夾和文件,所以指定的路徑最后不需要帶 /,
   例如:export_dir='/home/***/saved_model' 即可,最后不需要加上 /
"""

# 方法
# 1
.add_meta_graph_and_variables(sess, tags, signature_def_map=None, assets_collection=None,
                              clear_devices=False, main_op=None, strip_default_attrs=False, saver=None)
"""
作用: 
  保存會話對象中的 graph 和所有變量,具體描述可參見文檔
參數:
  sess: TensorFlow 會話對象,用於保存元圖和變量
  tags: 用於保存元圖的標記集(如果存在多個圖對象,需要設置保證每個圖標簽不一樣),是一個列表
  signature_def_map: 一個字典,保存模型時傳入的參數,key 可以是字符串,也可以是 tf.saved_model.signature_constants 文件下預定義的變量,
                      值為 signatureDef protobuf(protobuf 是一種結構化的數據存儲格式)
  assets_collection: 略
  clear_devices: 如果需要清除默認圖上的設備信息,則設置為 true
  main_op: 這個參數包括后面一系列與其相關的東西沒有弄明白
  strip_default_attrs: 如果設置為 True,將從 NodeDefs 中刪除默認值屬性
  saver: tf.train.Saver 的一個實例,用於導出元圖並保存變量
"""

# 2
.add_meta_graph()
"""
作用:
  其除了沒有 sess 參數以外,其他參數和 .add_meta_graph_and_variables() 一模一樣
  調用此方法之前必須先調用 .add_meta_graph_and_variables() 方法
"""

# 3
.save(as_text=False)
"""
作用: 
  將內建的 savedModel protobuf 寫入磁盤
"""

除了這個最重要的類以外,tf.saved_model 模塊還提供了一些方便構建 builder 和加載模型的函數方法。

# 1
tf.saved_model.utils.build_tensor_info(tensor)
"""
作用:
    構建 TensorInfo protobuf,根據輸入的 tensor 構建相應的 protobuf,返回的 TensorInfo 中包含輸入 tensor 的 name,shape,dtype 信息
參數:
    tensor: Tensor 或 SparseTensor
"""

# 2
tf.saved_model.signature_def_utils.build_signature_def(inputs=None, outputs=None, method_name=None)
"""
作用:
    構建 SignatureDef protobuf,並返回 SignatureDef protobuf
參數:
    inputs: 一個字典,鍵為字符串類型,值為關於 tensor 的信息,也就是上述的 .build_tensor_info() 函數返回的 TensorInfo protobuf
    outputs: 一個字典,同上
    method_name: SignatureDef 名稱
"""

# 3
tf.saved_model.utils.get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None)
"""
作用:
    根據一個 TensorInfo protobuf 解析出一個 tensor
參數:
    tensor_info: 一個 TensorInfo protobuf
    graph: tensor 所存在的 graph,參數為 None 時,使用默認圖
    import_scope: 給 tensor 的 name 加上前綴
"""

# 4
tf.saved_model.loader.load(sess, tags, export_dir, import_scope=None, **saver_kwargs)
"""
作用:
    加載已存儲的模型
參數:
    sess: 用於恢復模型的 tf.Session() 對象
    tags: 用於標識 MetaGraphDef 的標記,應該和存儲模型時使用的此參數完全一致
    export_dir: 模型存儲路徑
    import_scope: 加前綴
"""

除了這些以外,還有一些 TensorFlow 為了方便而預定義的一些變量,這些變量完全可以使用自定義字符串代替,不再贅述。詳情:https://tensorflow.google.cn/api_docs/python/tf/saved_model

如果只看這些內容的話,確實會使人產生巨大的疑惑,下面是具體實踐的例子:

import tensorflow as tf
from tensorflow import saved_model as sm


# 首先定義一個極其簡單的計算圖
X = tf.placeholder(tf.float32, shape=(3, ))
scale = tf.Variable([10, 11, 12], dtype=tf.float32)
y = tf.multiply(X, scale)

# 在會話中運行
with tf.Session() as sess:
    sess.run(tf.initializers.global_variables())
    value = sess.run(y, feed_dict={X: [1., 2., 3.]})
    print(value)
    
    # 准備存儲模型
    path = '/home/×××/tf_model/model_1'
    builder = sm.builder.SavedModelBuilder(path)
    
    # 構建需要在新會話中恢復的變量的 TensorInfo protobuf
    X_TensorInfo = sm.utils.build_tensor_info(X)
    scale_TensorInfo = sm.utils.build_tensor_info(scale)
    y_TensorInfo = sm.utils.build_tensor_info(y)

    # 構建 SignatureDef protobuf
    SignatureDef = sm.signature_def_utils.build_signature_def(
                                inputs={'input_1': X_TensorInfo, 'input_2': scale_TensorInfo},
                                outputs={'output': y_TensorInfo},
                                method_name='what'
    )

    # 將 graph 和變量等信息寫入 MetaGraphDef protobuf
    # 這里的 tags 里面的參數和 signature_def_map 字典里面的鍵都可以是自定義字符串,TensorFlow 為了方便使用,不在新地方將自定義的字符串忘記,可以使用預定義的這些值
    builder.add_meta_graph_and_variables(sess, tags=[sm.tag_constants.TRAINING], 
                                         signature_def_map={sm.signature_constants.CLASSIFY_INPUTS: SignatureDef}
  ) 

 # 將 MetaGraphDef 寫入磁盤
    builder.save()

這樣我們就把模型整體存儲到了磁盤中,而且我們將三個變量 X, scale, y 全部序列化后存儲到了其中,所以恢復模型時便可以將他們完全解析出來:

import tensorflow as tf
from tensorflow import saved_model as sm


# 需要建立一個會話對象,將模型恢復到其中
with tf.Session() as sess:
    path = '/home/×××/tf_model/model_1'
    MetaGraphDef = sm.loader.load(sess, tags=[sm.tag_constants.TRAINING], export_dir=path)

    # 解析得到 SignatureDef protobuf
    SignatureDef_d = MetaGraphDef.signature_def
    SignatureDef = SignatureDef_d[sm.signature_constants.CLASSIFY_INPUTS]

    # 解析得到 3 個變量對應的 TensorInfo protobuf
    X_TensorInfo = SignatureDef.inputs['input_1']
    scale_TensorInfo = SignatureDef.inputs['input_2']
    y_TensorInfo = SignatureDef.outputs['output']

    # 解析得到具體 Tensor
    # .get_tensor_from_tensor_info() 函數中可以不傳入 graph 參數,TensorFlow 自動使用默認圖
    X = sm.utils.get_tensor_from_tensor_info(X_TensorInfo, sess.graph)
    scale = sm.utils.get_tensor_from_tensor_info(scale_TensorInfo, sess.graph)
    y = sm.utils.get_tensor_from_tensor_info(y_TensorInfo, sess.graph)

    print(sess.run(scale))
    print(sess.run(y, feed_dict={X: [3., 2., 1.]}))

# 輸出
[10. 11. 12.]
[30. 22. 12.]

可以看出模型整體和變量個體都被完整地保存了下來。其中涉及的關於 protobuf 的知識,需要補習,在 TensorFlow 中好多地方都用到了相關的知識。上述恢復模型的代碼中對具體的 TensorInfo protobuf 解析時,還可以使用另一種方式得到相應的 Tensor:

# 已知 X_TensorInfo, scale_TensorInfo, y_TensorInfo
X = sess.graph.get_tensor_by_name(X_TensorInfo.name)
scale = sess.grpah.get_tensor_by_name(scale_TensorInfo.name)
y = sess.graph.get_tensor_by_name(y_TensorInfo.name)

# 因為 TensorFlow 構建 TensorInfo protobuf 時,使用了 Tensor 的 name 信息,所以可以直接讀出來使用

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM