bert,albert的快速訓練和預測


  隨着預訓練模型越來越成熟,預訓練模型也會更多的在業務中使用,本文提供了bert和albert的快速訓練和部署,實際上目前的預訓練模型在用起來時都大致相同。

  基於不久前發布的中文數據集chineseGLUE,將所有任務分成四大類:文本分類,句子對判斷,實體識別,閱讀理解。同類可以共享代碼,除上面四個任務之外,還加了一個learning to rank ,基於pair wise的方式的任務,代碼見:https://github.com/jiangxinyang227/bert-for-task

  具體使用見readme

  模型定義在每個項目下的model.py文件中,直接調用bert和albert的源碼modeling.py將預訓練模型引入,將預訓練模型作為encoder部分,也可以只作為embedding層,再自己定義encoder部分,總之可以非常方便的接入下游任務網絡層,尤其是當你只想使用預訓練模型作為embedding層時,我們需要自己些encoder部分。

     bert_config = modeling.BertConfig.from_json_file(self.__bert_config_path)

        model = modeling.BertModel(config=bert_config,
                                   is_training=self.__is_training,
                                   input_ids=self.input_ids,
                                   input_mask=self.input_masks,
                                   token_type_ids=self.segment_ids,
                                   use_one_hot_embeddings=False)
        output_layer = model.get_pooled_output()

        hidden_size = output_layer.shape[-1].value
        if self.__is_training:
            # I.e., 0.1 dropout
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

        with tf.name_scope("output"):
            output_weights = tf.get_variable(
                "output_weights", [self.__num_classes, hidden_size],
                initializer=tf.truncated_normal_initializer(stddev=0.02))

            output_bias = tf.get_variable(
                "output_bias", [self.__num_classes], initializer=tf.zeros_initializer())

            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            self.predictions = tf.argmax(logits, axis=-1, name="predictions")

  在訓練時加載預訓練的參數值來初始化預訓練模型的變量,具體在trainer.py文件中

tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
                tvars, self.__bert_checkpoint_path)
print("init bert model params")
tf.train.init_from_checkpoint(self.
__bert_checkpoint_path, assignment_map) print("init bert model params done") sess.run(tf.variables_initializer(tf.global_variables()))

  在預測時可以直接實例化predict.py文件中的Predictor類就會加載checkpoint模型文件,調用類中的predict方法就可以進行預測,在不需要考慮模型代碼加密,模型優化等情況下,可以直接線上部署。

import json

from predict import Predictor


with open("config/tnews_config.json", "r") as fr:
    config = json.load(fr)


predictor = Predictor(config)
text = "殲20座艙蓋上的兩條“花紋”是什么?"
res = predictor.predict(text)
print(res)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM