bert文本分類模型保存為savedmodel方式


默認bert是ckpt,在進行后期優化和部署時,savedmodel方式更加友好寫。

train完成后,調用如下函數:

def save_savedmodel(estimator, serving_dir, seq_length, is_tpu_estimator):
    feature_map = {
        "input_ids": tf.placeholder(tf.int32, shape=[None, seq_length], name='input_ids'),
        "input_mask": tf.placeholder(tf.int32, shape=[None, seq_length], name='input_mask'),
        "segment_ids": tf.placeholder(tf.int32, shape=[None, seq_length], name='segment_ids'),
        "label_ids": tf.placeholder(tf.int32, shape=[None], name='label_ids'),
    }
    serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_map)
    estimator.export_savedmodel(serving_dir,
                                serving_input_receiver_fn,
                                strip_default_attrs=True)
    print("保存savedmodel")

estimator:estimator = Estimator(model_fn=model_fn,params={},config=run_config)

serving_dir:存儲目錄

seq_length:樣本長度

is_tpu_estimator: tpu標志位

 
 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM