Tensorflow模型的格式


轉載:https://cloud.tencent.com/developer/article/1009979

tensorflow模型的格式通常支持多種,主要有CheckPoint(*.ckpt)、GraphDef(*.pb)、SavedModel。

 

1. CheckPoint(*.ckpt)

在訓練 TensorFlow 模型時,每迭代若干輪需要保存一次權值到磁盤,稱為“checkpoint”,如下圖所示:

這種格式文件是由 tf.train.Saver() 對象調用 saver.save() 生成的,只包含若干 Variables 對象序列化后的數據,不包含圖結構,所以只給 checkpoint 模型不提供代碼是無法重新構建計算圖的。

載入 checkpoint 時,調用 saver.restore(session, checkpoint_path)。

缺點:首先模型文件是依賴 TensorFlow 的,只能在其框架下使用;其次,在恢復模型之前還需要再定義一遍網絡結構,然后才能把變量的值恢復到網絡中。

 

2. GraphDef(*.pb)

這種格式文件包含 protobuf 對象序列化后的數據,包含了計算圖,可以從中得到所有運算符(operators)的細節,也包含張量(tensors)和 Variables 定義,但不包含 Variable 的值,因此只能從中恢復計算圖,但一些訓練的權值仍需要從 checkpoint 中恢復。下面代碼實現了利用 *.pb 文件構建計算圖:

TensorFlow 一些例程中用到 *.pb 文件作為預訓練模型,這和上面 GraphDef 格式稍有不同,屬於凍結(Frozen)后的 GraphDef 文件,簡稱 FrozenGraphDef 格式。這種文件格式不包含 Variables 節點。將 GraphDef 中所有 Variable 節點轉換為常量(其值從 checkpoint 獲取),就變為 FrozenGraphDef 格式。代碼可以參考 tensorflow/python/tools/freeze_graph.py

*.pb 為二進制文件,實際上 protobuf 也支持文本格式(*.pbtxt),但包含權值時文本格式會占用大量磁盤空間,一般不用。

 

3. SavedModel

https://juejin.im/post/5bbfedd65188255c9b13d964

https://zhuanlan.zhihu.com/p/31417693

這是谷歌推薦的模型保存方式,它具有語言獨立性,可獨立運行,封閉的序列化格式,任何語言都可以解析它,它允許其他語言和深度學習框架讀取、繼續訓練和遷移 TensorFlow 的模型。該格式為 GraphDef 和 CheckPoint 的結合體,另外還有標記模型輸入和輸出參數的 SignatureDef。從 SavedModel 中可以提取 GraphDef 和 CheckPoint 對象。

SavedModel 目錄結構如下:

其中 saved_model.pb(或 saved_model.pbtxt)包含使用 MetaGraphDef protobuf 對象定義的計算圖;assets 包含附加文件;variables 目錄包含 tf.train.Saver() 對象調用 save() API 生成的文件。

以下代碼實現了保存 SavedModel:

 
#創建signature
def signature_def(self):
        inputs = {'char_inputs': tf.saved_model.utils.build_tensor_info(self.char_inputs),
                  'seg_inputs': tf.saved_model.utils.build_tensor_info(self.seg_inputs),
                  'dropout': tf.saved_model.utils.build_tensor_info(self.dropout)}

        outputs = {'decode_tags': tf.saved_model.utils.build_tensor_info(self.decode_tags)}

        return tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs
                    ,outputs=outputs
                    ,method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)


#保存模型
    def save_model(self, sess, signature, save_path):
        builder = tf.saved_model.builder.SavedModelBuilder(save_path)
        builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], {'predict': signature}, clear_devices=True)
        builder.save()

 

載入 SavedModel:

model = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], checkpoint_path)
signature = model.signature_def

char_inputs_ = signature['predict'].inputs['char_inputs'].name
seg_inputs_ = signature['predict'].inputs['seg_inputs'].name

dropout_ = signature['predict'].inputs['dropout'].name
decode_tags_ = signature['predict'].outputs['decode_tags'].name
# get tensor
char_inputs = sess.graph.get_tensor_by_name(char_inputs_)
seg_inputs = sess.graph.get_tensor_by_name(seg_inputs_)
dropout = sess.graph.get_tensor_by_name(dropout_)
decode_tags = sess.graph.get_tensor_by_name(decode_tags_)
decode_tags_ = sess.run([decode_tags], feed_dict={char_inputs: inputs[1], seg_inputs:inputs[2], dropout:1.0 })
更多細節可以參考 tensorflow/python/saved_model/README.md。

4. 各模式之間的轉換

https://zhuanlan.zhihu.com/p/47649285

 

5. 小結

本文總結了 TensorFlow 常見模型格式和載入、保存方法。部署在線服務(Serving)時官方推薦使用 SavedModel 格式,而部署到手機等移動端的模型一般使用 FrozenGraphDef 格式(最近推出的 TensorFlow Lite 也有專門的輕量級模型格式 *.lite,和 FrozenGraphDef 十分類似)。這些格式之間關系密切,可以使用 TensorFlow 提供的 API 來互相轉換。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM