【NLP】使用bert


# 參考 https://blog.csdn.net/luoyexuge/article/details/84939755 ,稍作改動

需要:

  github上下載bert的代碼:https://github.com/google-research/bert

  下載google訓練好的中文語料模型:https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

使用:

  使用bert,其實是使用幾個checkpoint(ckpt)文件。上面下載的zip是google訓練好的bert,我們可以在那個zip內的ckpt文件基礎上繼續訓練,獲得更貼近具體任務的ckpt文件。

如果直接使用訓練好的ckpt文件(即BERT模型),只需如下代碼:定義model,再取出model輸出的embedding。

from bert import modeling

# Build a BertModel on the input ids to obtain the per-token embeddings.
# (bert_config, input_ids, input_mask, segment_ids, is_training and
# use_one_hot_embeddings are defined as described in the text below.)
model = modeling.BertModel(
    config=bert_config,
    is_training=is_training,
    input_ids=input_ids,
    input_mask=input_mask,
    token_type_ids=segment_ids,
    use_one_hot_embeddings=use_one_hot_embeddings
)
# Final encoder hidden states: a [batch_size, seq_length, hidden_size] tensor.
embedding = model.get_sequence_output()

 

這裡的bert_config是之前定義的 bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file);輸入是input_ids、input_mask、segment_ids三個向量;另外還有兩個設置:is_training(False,即推理模式,不啟用dropout)和use_one_hot_embeddings(False)。這樣的設置還有很多,這裡只列舉這兩個。

關於FLAGS,需要提到TensorFlow的flags:它用來定義運行參數,每個參數都有默認值,也可以在命令行中覆蓋。設置如下:

import os

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

# Paths of the pre-trained Chinese BERT model and of the project.
bert_path = '/home/xiangbo_wang/xiangbo/NER/chinese_L-12_H-768_A-12/'
root_path = '/home/xiangbo_wang/xiangbo/NER/BERT-BiLSTM-CRF-NER'

# Register bert_config_file; its default points into bert_path.
flags.DEFINE_string(
    "bert_config_file", os.path.join(bert_path, 'bert_config.json'),
    "The config json file corresponding to the pre-trained BERT model."
)

 關於輸入的三個向量,具體內容可以參照之前的博客https://www.cnblogs.com/rucwxb/p/10277217.html

input_ids 和 segment_ids 分別是查找 token embedding 和 segment embedding 所用的索引:input_ids 是每個字在詞表中的id,segment_ids 標示每個位置屬於第一句(0)還是第二句(1)。

position embedding 由模型內部自動生成,不需要額外輸入。

input_mask 是注意力遮罩(attention mask):真實輸入的位置為1,補齊(padding)的位置為0,讓模型忽略補齊的部分。它與預訓練時隨機遮蓋部分詞語的MLM遮蔽無關;下面的inputs函數會把所有真實輸入位置都設為1。

獲得輸入的這三個向量的方式如下:

 

# Build the three BERT input vectors (ids, mask, segment ids) of length maxlen.
def inputs(vectors, maxlen=10):
    """Pad or truncate *vectors* to exactly ``maxlen`` entries.

    Returns a tuple ``(ids, mask, segment)``:
      * ids     -- ``vectors`` cut to ``maxlen`` or right-padded with 0;
      * mask    -- 1 for every position holding an original entry, 0 for padding;
      * segment -- all zeros.
    A sequence of exactly ``maxlen`` entries takes the truncation path.
    """
    n_real = len(vectors)
    if n_real < maxlen:
        n_pad = maxlen - n_real
        return vectors + [0] * n_pad, [1] * n_real + [0] * n_pad, [0] * maxlen
    # Long enough already: keep the first maxlen entries, all of them real.
    return vectors[:maxlen], [1] * maxlen, [0] * maxlen

# Sentence to embed, read from the `text` query parameter of the request.
text = request.args.get('text')
# [CLS] + one vocabulary id per character (unknown characters -> [UNK]) + [SEP].
vectors = (
    [di.get("[CLS]")]
    + [di.get(ch, di.get("[UNK]")) for ch in text]
    + [di.get("[SEP]")]
)

# Pad/truncate to maxlen, then shape each vector as a 1 x maxlen batch.
ids, mask, segment = inputs(vectors)
input_ids, input_mask, segment_ids = (
    np.array(seq).reshape([1, -1]) for seq in (ids, mask, segment)
)

 

最後,將這些變量輸入模型,得到最終的BERT向量:

# Placeholders for batches of int32 ids with shape [batch_size, seq_length];
# the explicit names are what the feed_dict below refers to ("<name>:0").
input_ids_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_ids_p")
input_mask_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_mask_p")
segment_ids_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name="segment_ids_p")

model = modeling.BertModel(
    config=bert_config,
    is_training=is_training,
    input_ids=input_ids_p,
    input_mask=input_mask_p,
    token_type_ids=segment_ids_p,
    use_one_hot_embeddings=use_one_hot_embeddings,
)

# Load the graph's variables from the pre-trained checkpoint.
restore_saver = tf.train.Saver()
restore_saver.restore(sess, init_checkpoint)

# Sequence output with size-1 dimensions squeezed away; for a single
# sentence the batch dimension disappears.
embedding = tf.squeeze(model.get_sequence_output())
# Run the graph on the numpy inputs prepared earlier.
ret = sess.run(
    embedding,
    feed_dict={
        "input_ids_p:0": input_ids,
        "input_mask_p:0": input_mask,
        "segment_ids_p:0": segment_ids,
    },
)

完整可運行代碼如下:

import tensorflow as tf 
from bert import modeling
import collections
import os
import numpy as np 
import json

flags = tf.flags
FLAGS = flags.FLAGS
# Directory of the unpacked pre-trained Chinese BERT release.
bert_path = '/home/xiangbo_wang/xiangbo/NER/chinese_L-12_H-768_A-12/'

# String flags whose defaults are files inside bert_path:
# (flag name, file name, help text).
for _flag_name, _file_name, _flag_help in (
    ('bert_config_file', 'bert_config.json',
     'config json file corresponding to the pre-trained BERT model.'),
    ('bert_vocab_file', 'vocab.txt',
     'the config vocab file'),
    ('init_checkpoint', 'bert_model.ckpt',
     'from a pre-trained BERT get an initial checkpoint'),
):
    flags.DEFINE_string(_flag_name, os.path.join(bert_path, _file_name), _flag_help)
# Do not leave the loop temporaries behind as module globals.
del _flag_name, _file_name, _flag_help

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

def convert2Uni(text):
    """Return *text* as a ``str``.

    A ``str`` is returned unchanged; ``bytes`` are decoded as UTF-8 with
    undecodable bytes dropped.

    Raises:
        TypeError: if *text* is neither ``str`` nor ``bytes``.
    """
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    # Fail loudly here instead of returning None, which would only crash
    # later at the caller (e.g. on ``.strip()``).
    raise TypeError(
        'convert2Uni expects str or bytes, got %s' % type(text).__name__
    )


def load_vocab(vocab_file):
    """Read a BERT ``vocab.txt`` into an ordered token -> id mapping.

    The id of each token is its 0-based line number in the file; surrounding
    whitespace (including the newline) is stripped from every line.

    Args:
        vocab_file: path to the vocabulary file, one token per line; expected
            to be UTF-8 encoded.

    Returns:
        collections.OrderedDict mapping token (str) to id (int).
    """
    vocab = collections.OrderedDict()
    # Pre-seeded entry kept from the original code; it is overwritten if the
    # vocabulary file itself contains a 'blank' line.
    # NOTE(review): the purpose of 'blank' -> 2 is unclear; confirm before relying on it.
    vocab.setdefault('blank', 2)
    # Decode explicitly as UTF-8 so reading the Chinese vocabulary does not
    # depend on the platform's default (locale) encoding.
    with open(vocab_file, encoding='utf-8') as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab


def inputs(vectors, maxlen=50):
    """Turn a list of token ids into BERT's (ids, mask, segment) inputs.

    Sequences longer than ``maxlen`` are cut to their first ``maxlen`` ids;
    shorter (or equal-length) ones are right-padded with 0. The mask marks
    original ids with 1 and padding with 0; the segment ids are all 0.
    """
    size = len(vectors)
    if size > maxlen:
        # Too long: keep only the first maxlen ids, every kept slot is real.
        return vectors[:maxlen], [1] * maxlen, [0] * maxlen
    padding = [0] * (maxlen - size)
    ids = vectors + padding
    mask = [1] * size + padding
    segment = [0] * maxlen
    return ids, mask, segment


def response_request(text, maxlen=50):
    """Return the BERT sequence output for *text* as a JSON string.

    Each character of *text* is mapped through the module-level
    ``dictionary`` (unknown characters map to ``[UNK]``), wrapped in
    ``[CLS]`` ... ``[SEP]``, padded by ``inputs`` to *maxlen* and run through
    the module-level ``model`` in ``sess``.

    Args:
        text: the sentence to embed; processed character by character.
        maxlen: total sequence length fed to the model, including the
            ``[CLS]`` and ``[SEP]`` tokens. Defaults to 50, matching the
            default of ``inputs``.

    Returns:
        JSON string of the sequence output with size-1 dimensions removed
        (for one sentence: a nested list of shape [maxlen, hidden_size]).
    """
    unk_id = dictionary.get('[UNK]')
    # Truncate the characters *before* adding [CLS]/[SEP], so the final [SEP]
    # is never cut off by the truncation inside `inputs`.
    chars = list(text)[:max(maxlen - 2, 0)]
    vectors = (
        [dictionary.get('[CLS]')]
        + [dictionary.get(ch, unk_id) for ch in chars]
        + [dictionary.get('[SEP]')]
    )
    ids, mask, segment = inputs(vectors, maxlen)

    feed = {
        'input_ids_p:0': np.reshape(np.array(ids), [1, -1]),
        'input_mask_p:0': np.reshape(np.array(mask), [1, -1]),
        'segment_ids_p:0': np.reshape(np.array(segment), [1, -1]),
    }
    # Fetch the existing output tensor and squeeze in NumPy: creating a
    # tf.squeeze op here would add a new node to the graph on every call.
    rst = np.squeeze(sess.run(model.get_sequence_output(), feed_dict=feed))

    return json.dumps(rst.tolist(), ensure_ascii=False)


# Vocabulary (token -> id) used by response_request to encode characters.
dictionary = load_vocab(FLAGS.bert_vocab_file)
init_checkpoint = FLAGS.init_checkpoint

sess = tf.Session()
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

# Placeholders for [batch_size, seq_length] int32 id batches. Their names
# matter: response_request feeds them as 'input_ids_p:0' etc.
input_ids_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name='input_ids_p')
input_mask_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name='input_mask_p')
segment_ids_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name='segment_ids_p')

model = modeling.BertModel(
    config=bert_config,
    # Inference only: keep dropout off regardless of the TPU setting.
    is_training=False,
    input_ids=input_ids_p,
    input_mask=input_mask_p,
    token_type_ids=segment_ids_p,
    use_one_hot_embeddings=FLAGS.use_tpu,
)
print('####################################')
# Load all model variables from the pre-trained checkpoint.
restore_saver = tf.train.Saver()
restore_saver.restore(sess, init_checkpoint)

print(response_request('我叫水奈樾。'))
View Code

 

 


免責聲明!

本站轉載的文章為個人學習借鑑使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM