[阿里DIN] 模型保存,加載和使用


[阿里DIN] 模型保存,加載和使用

0x00 摘要

Deep Interest Network(DIN)是阿里媽媽精准定向檢索及基礎算法團隊在2017年6月提出的。其針對電子商務領域(e-commerce industry)的CTR預估,重點在於充分利用/挖掘用戶歷史行為數據中的信息。

本系列文章會解讀論文以及源碼,順便梳理一些深度學習相關概念和TensorFlow的實現。

本文是系列第 12 篇 :介紹DIN模型的保存,加載和使用。

0x01 TensorFlow模型

1.1 模型文件

TensorFlow模型會保存在checkpoint相關文件中。因為TensorFlow會將計算圖的結構和圖上參數取值分開保存,所以保存后在相關文件夾中會出現3個文件。

下面就是DIN,DIEN相關生成的文件,可以通過名稱來判別。

checkpoint				

ckpt_noshuffDIN3.data-00000-of-00001
ckpt_noshuffDIN3.meta
ckpt_noshuffDIN3.index

ckpt_noshuffDIEN3.data-00000-of-00001	
ckpt_noshuffDIEN3.index			
ckpt_noshuffDIEN3.meta

所以我們可以認為和保存的模型直接相關的是以下這四個文件:

  • checkpoint文件保存了一個目錄下所有的模型文件列表,這個文件是TensorFlow自動生成且自動維護的。在 checkpoint文件中維護了由一個TensorFlow持久化的所有TensorFlow模型文件的文件名。當某個保存的TensorFlow模型文件被刪除時,這個模型所對應的文件名也會從checkpoint文件中刪除。checkpoint中內容的格式為CheckpointState Protocol Buffer.
  • .meta文件 保存了TensorFlow計算圖的結構,可以理解為神經網絡的網絡結構。
    TensorFlow通過元圖(MetaGraph)來記錄計算圖中節點的信息以及運行計算圖中節點所需要的元數據。TensorFlow中元圖是由MetaGraphDef Protocol Buffer定義的。MetaGraphDef 中的內容構成了TensorFlow持久化時的第一個文件。保存MetaGraphDef 信息的文件默認以.meta為后綴名。
  • .index文件保存了當前參數名。
  • model.ckpt文件保存了TensorFlow程序中每一個變量的取值,這個文件是通過SSTable格式存儲的,可以大致理解為就是一個(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在這個文件中存儲的變量列表。列表剩下的每一行保存了一個變量的片段,變量片段的信息是通過SavedSlice Protocol Buffer定義的。SavedSlice類型中保存了變量的名稱、當前片段的信息以及變量取值。TensorFlow提供了tf.train.NewCheckpointReader類來查看model.ckpt文件中保存的變量信息。

1.2 freeze_graph

正如前文所述,tensorflow在訓練過程中,通常不會將權重數據保存的格式文件里,反而是分開保存在一個叫checkpoint的檢查點文件里,當初始化時,再通過模型文件里的變量Op節點來從checkoupoint文件讀取數據並初始化變量。這種模型和權重數據分開保存的情況,使得發布產品時不是那么方便,所以便有了freeze_graph.py腳本文件用來將這兩文件整合合並成一個文件。

freeze_graph.py是怎么做的呢?

  • 它先加載模型文件
  • 提供checkpoint文件地址后,它從checkpoint文件讀取權重數據初始化到模型里的權重變量;
  • 將權重變量轉換成權重常量 (因為常量能隨模型一起保存在同一個文件里);
  • 再通過指定的輸出節點沒用於輸出推理的Op節點從圖中剝離掉;
  • 使用tf.train.writegraph保存圖,這個圖會提供給freeze_graph使用;
  • 再使用freeze_graph重新保存到指定的文件里;

0x02 DIN代碼

因為 DIN 源碼中沒有實現此部分,所以我們需要自行添加。

2.1 輸出結點

首先,在model.py中,需要聲明輸出結點。

def build_fcn_net(self, inp, use_dice = False):
    .....
    # 此處需要給 y_hat 添加一個name
    self.y_hat = tf.nn.softmax(dnn3, name='final_output') + 0.00000001

2.2 保存函數

其次,需要添加一個保存函數,調用 freeze_graph 來進行保存。

需要注意幾點:

  • write_graph 的 as_text 參數默認是 True,我們這里設置為 False。有的環境如果設置為 True 會有問題;
  • 因為write_graph 的 as_text 參數做了設置,所以freeze_graph的參數也做相應設置: input_binary=True
  • input_checkpoint 參數需要針對DIN或者DIEN做相應調整;

具體代碼如下:

def din_freeze_graph(sess):
    # 模型持久化,將變量值固定
    output_graph_def = convert_variables_to_constants(
            sess=sess,
            input_graph_def=sess.graph_def, # 等於:sess.graph_def
            output_node_names=['final_output']) # 如果有多個輸出節點,以逗號隔開
    tf.train.write_graph(output_graph_def, 'dnn_best_model', 'model.pb', False)

    freeze_graph.freeze_graph(
            input_graph='./dnn_best_model/model.pb',
            input_saver='',
            input_binary=True,
            input_checkpoint='./dnn_best_model/ckpt_noshuffDIN3',
            output_node_names='final_output', # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
            restore_op_name='save/restore_all',
            filename_tensor_name='save/Const:0',
            output_graph='./dnn_best_model/frozen_model.pb',
            clear_devices=False,
            initializer_nodes=''
            )

2.2 調用保存

我們在train函數中,存儲模型之后,進行調用。

def train(...):
                if (iter % save_iter) == 0:
                    print('save model iter: %d' %(iter))
                    model.save(sess, model_path+"--"+str(iter))
                    freeze_graph(sess) # 此處調用

0x03 驗證

3.1 加載

加載函數如下:

def load_graph(fz_gh_fn):
    with tf.gfile.GFile(fz_gh_fn, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

        with tf.Graph().as_default() as graph:
            tf.import_graph_def(
                graph_def,
                input_map=None,
                return_elements=None,
                name="prefix"  # 此處可以自己修改
            )
    return graph

調用加載函數如下,我們在加載之后,打印出圖中對應節點:

graph = load_graph('./dnn_best_model/frozen_model.pb')
for op in graph.get_operations():
    print(op.name, op.values())

從打印結果我們可以看出來,有些op是Inputs相關,final_output節點則是我們之前設定的。

(u'prefix/Inputs/mid_his_batch_ph', (<tf.Tensor 'prefix/Inputs/mid_his_batch_ph:0' shape=(?, ?) dtype=int32>,))
(u'prefix/Inputs/cat_his_batch_ph', (<tf.Tensor 'prefix/Inputs/cat_his_batch_ph:0' shape=(?, ?) dtype=int32>,))
(u'prefix/Inputs/uid_batch_ph', (<tf.Tensor 'prefix/Inputs/uid_batch_ph:0' shape=(?,) dtype=int32>,))
(u'prefix/Inputs/mid_batch_ph', (<tf.Tensor 'prefix/Inputs/mid_batch_ph:0' shape=(?,) dtype=int32>,))
(u'prefix/Inputs/cat_batch_ph', (<tf.Tensor 'prefix/Inputs/cat_batch_ph:0' shape=(?,) dtype=int32>,))
(u'prefix/Inputs/mask', (<tf.Tensor 'prefix/Inputs/mask:0' shape=(?, ?) dtype=float32>,))
(u'prefix/Inputs/seq_len_ph', (<tf.Tensor 'prefix/Inputs/seq_len_ph:0' shape=(?,) 
                               
......            
                               
(u'prefix/final_output', (<tf.Tensor 'prefix/final_output:0' shape=(?, 2) dtype=float32>,))

3.2 驗證

驗證數據可以自己炮制,或者就是從測試數據中取出兩條即可,我們的驗證文件名字為 local_predict_splitByUser

0	A3BI7R43VUZ1TY	B00JNHU0T2	Literature & Fiction	0989464105B00B01691C14778097321608442845	BooksLiterature & FictionBooksBooks

1	A3BI7R43VUZ1TY	0989464121	Books	0989464105B00B01691C14778097321608442845	BooksLiterature & FictionBooksBooks

驗證代碼如下,其中feed_dict如何填充,需要根據上節的輸出結果來進行相關配置。

def predict(
        graph,
        predict_file = "local_predict_splitByUser",
        uid_voc = "uid_voc.pkl",
        mid_voc = "mid_voc.pkl",
        cat_voc = "cat_voc.pkl",
        batch_size = 128,
        maxlen = 100):
    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph = graph) as sess:
        predict_data = DataIterator(predict_file, uid_voc, mid_voc, cat_voc, batch_size, maxlen)
        for src, tgt in predict_data:
            uids, mids, cats, mid_his, cat_his, mid_mask, target, sl, noclk_mids, noclk_cats = prepare_data(src, tgt, maxlen, return_neg=True)
            final_output = "prefix/final_output:0"
            feed_dict = {
                'prefix/Inputs/mid_his_batch_ph:0' : mid_his,
                'prefix/Inputs/cat_his_batch_ph:0':cat_his,
                'prefix/Inputs/uid_batch_ph:0':uids,
                'prefix/Inputs/mid_batch_ph:0':mids,
                'prefix/Inputs/cat_batch_ph:0':cats,
                'prefix/Inputs/mask:0':mid_mask,
                'prefix/Inputs/seq_len_ph:0':sl
            }
            y_hat = sess.run(final_output, feed_dict = feed_dict)
            print(y_hat)

預測結果如下:

[[0.95820646 0.04179354]
 [0.09431148 0.9056886 ]]

3.3 為什么要在tensor后面加:0

在上節中,我們可以看到在feed_dict之中,給定的tensor名字后面都帶了 :0

feed_dict = {
    'prefix/Inputs/mid_his_batch_ph:0' : mid_his,
    'prefix/Inputs/cat_his_batch_ph:0':cat_his,
    'prefix/Inputs/uid_batch_ph:0':uids,
    'prefix/Inputs/mid_batch_ph:0':mids,
    'prefix/Inputs/cat_batch_ph:0':cats,
    'prefix/Inputs/mask:0':mid_mask,
    'prefix/Inputs/seq_len_ph:0':sl
}

這里需要注意,TensorFlow的運算結果不是一個數,而是一個張量結構。張量的命名形式:“node : src_output”,node為節點的名稱,src_output 表示當前張量來自來自節點的第幾個輸出。

在我們這里,prefix/Inputs/mid_batch_ph 是操作節點,prefix/Inputs/mid_batch_ph:0 才是變量的名字。冒號后面的數字編號表示這個張量是計算節點上的第幾個結果

0xFF 參考

【TensorFlow】freeze_graph

[深度學習] TensorFlow中模型的freeze_graph

TensorFlow模型冷凍以及為什么tensor名字要加:0

tensorflow實戰筆記(19)----使用freeze_graph.py將ckpt轉為pb文件

Tensorflow-GraphDef、MetaGraph、CheckPoint


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM