Bert+BiLSTM+CRF with kashgari
kashgari: a TensorFlow-based library for building models such as Bert+LSTM.
The kashgari releases based on TensorFlow 2.0 do not include CRF, so the following versions are used:
Environment:
- conda create -n myTestNER python=3.6
- conda activate myTestNER
- pip install tensorflow==1.14.0
- pip install kashgari==1.1.5
Data preparation:
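```python
# Assumed file format (the original does not specify it): one sentence per line,
# the sentence followed by one tag per character, whitespace-separated,
# e.g. "我愛北京 O O B-D I-D"
train_x = []
train_y = []
with open('data_file_name', encoding='utf-8') as f:
    for line in f:
        cur = line.strip().split()
        if not cur:  # skip blank lines
            continue
        train_x.append(list(cur[0]))  # sentence -> list of characters
        train_y.append(cur[1:])       # remaining fields -> list of tags
# train_x = [seq1, seq2, seq3, ...]
# train_y = [tag1, tag2, tag3, ...]
# seq1 = ['我', '愛', '北', '京']
# tag1 = ['O', 'O', 'B-D', 'I-D']
```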
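The tags follow the BIO scheme (`B-` begins an entity, `I-` continues it, `O` is outside), so it can be worth checking the data for length mismatches or broken tag sequences before a long training run. A minimal sketch in plain Python; `check_bio` is a hypothetical helper, not part of kashgari:

```python
def check_bio(seq, tags):
    """Return a list of problems found in one (characters, tags) pair.

    Assumes the BIO scheme: 'O', 'B-<type>', 'I-<type>'. An 'I-<type>' tag
    must directly follow a 'B-<type>' or 'I-<type>' tag of the same type.
    """
    problems = []
    if len(seq) != len(tags):
        problems.append(f"length mismatch: {len(seq)} chars vs {len(tags)} tags")
    prev = "O"
    for i, tag in enumerate(tags):
        if tag == "O":
            prev = tag
            continue
        prefix, _, label = tag.partition("-")
        if prefix not in ("B", "I") or not label:
            problems.append(f"position {i}: malformed tag {tag!r}")
        elif prefix == "I" and prev[2:] != label:
            problems.append(f"position {i}: {tag!r} does not continue an entity")
        prev = tag
    return problems


if __name__ == "__main__":
    print(check_bio(['我', '愛', '北', '京'], ['O', 'O', 'B-D', 'I-D']))  # prints []
```

In practice it would be run over every `(seq, tags)` pair in `zip(train_x, train_y)`, reporting any sample that returns a non-empty list.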
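Training requires a pretrained BERT model, which can be downloaded from: https://github.com/ymcui/Chinese-BERT-wwm (the training code uses the RoBERTa-wwm-ext checkpoint, unzipped into the folder chinese_roberta_wwm_ext_L-12_H-768_A-12).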
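Training itself is straightforward; for hyperparameter tuning, see: https://github.com/BrikerMan/Kashgari/blob/v2-trunk/docs/tutorial/text-labeling.md (this page is on the v2-trunk branch, so some API names there differ from the 1.x API used here).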
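One hyperparameter worth checking against the data is `sequence_length` (256 in the training code), since longer sentences will not fit in full. A minimal sketch for looking at the length distribution of `train_x`; `length_percentile` and `count_longer_than` are hypothetical helpers, and using the 95th percentile as a cap is an illustrative assumption, not a kashgari setting:

```python
import math


def length_percentile(sequences, percentile=95):
    """Smallest length L such that at least `percentile`% of the sequences
    have len <= L (nearest-rank method; `percentile` is expected to be an int)."""
    lengths = sorted(len(s) for s in sequences)
    if not lengths:
        raise ValueError("no sequences given")
    # Multiply before dividing so integer percentiles give an exact rank
    rank = max(1, math.ceil(percentile * len(lengths) / 100))
    return lengths[rank - 1]


def count_longer_than(sequences, max_len):
    """How many sequences have more than max_len characters."""
    return sum(1 for s in sequences if len(s) > max_len)


if __name__ == "__main__":
    # Hypothetical toy data; in this post it would be train_x
    data = [list("我愛北京"), list("北京歡迎你"), list("你好"), list("天安門廣場在北京市中心")]
    print(length_percentile(data))     # 11
    print(count_longer_than(data, 8))  # 1
```

If `count_longer_than(train_x, 256)` is large, either raising `sequence_length` (at a memory cost) or splitting long sentences before training are the usual options.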
Training:
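```python
import kashgari
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.labeling import BiLSTM_CRF_Model
# Pretrained Chinese RoBERTa-wwm-ext checkpoint used as the embedding layer;
# input sequences are padded or truncated to sequence_length
bert = BERTEmbedding(model_folder="chinese_roberta_wwm_ext_L-12_H-768_A-12", sequence_length=256, task=kashgari.LABELING)
model = BiLSTM_CRF_Model(bert)  # BiLSTM + CRF labeling head on top of BERT
# The training set doubles as the validation set here; use a held-out split in practice
model.fit(train_x, train_y, x_validate=train_x, y_validate=train_y, epochs=2, batch_size=32)
# save() writes a directory (model info + weights), despite the .h5 suffix
model.save('my_bert_crf.h5')
```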
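Passing the training data as `x_validate`/`y_validate` makes the validation scores overly optimistic, because the model is scored on sentences it has already seen. A minimal sketch of a held-out split in plain Python; `train_valid_split` is a hypothetical helper, and the 10% ratio and the seed are arbitrary assumptions:

```python
import random


def train_valid_split(x, y, valid_ratio=0.1, seed=42):
    """Shuffle paired samples and split off a validation portion.

    Returns (train_x, train_y, valid_x, valid_y); x[i] stays paired with y[i].
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    indices = list(range(len(x)))
    random.Random(seed).shuffle(indices)  # local RNG: reproducible, no global state
    n_valid = int(len(indices) * valid_ratio)
    valid_idx, train_idx = indices[:n_valid], indices[n_valid:]
    return ([x[i] for i in train_idx], [y[i] for i in train_idx],
            [x[i] for i in valid_idx], [y[i] for i in valid_idx])
```

The four lists would then go into `model.fit(tr_x, tr_y, x_validate=va_x, y_validate=va_y, epochs=2, batch_size=32)`. Shuffling indices rather than the lists themselves keeps each sentence aligned with its tags.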
Prediction:
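```python
import kashgari

# Load the model directory written by model.save() during training
model = kashgari.utils.load_model('my_bert_crf.h5')
# Input is 2-D: a list of sentences, each already split into characters
predict = model.predict([['我', '愛', '北', '京']])
print(predict)
# e.g. [['O', 'O', 'B-D', 'I-D']]
```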
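`predict` returns only the tag sequences; to get the entities themselves, the BIO tags can be merged back with the input characters. A minimal sketch assuming BIO tags as in the training data; `extract_entities` is a hypothetical helper, and treating a stray `I-` tag (one that does not continue an entity of the same type) as the start of a new entity is one possible convention, not kashgari behavior:

```python
def extract_entities(chars, tags):
    """Merge BIO tags into (text, type, start, end) tuples; end is exclusive."""
    entities = []
    start, etype = None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel "O" flushes the last entity
        prefix, _, label = tag.partition("-")
        if prefix == "I" and start is not None and label == etype:
            continue  # the current entity continues
        if start is not None:  # close the open entity at position i
            entities.append(("".join(chars[start:i]), etype, start, i))
        if prefix in ("B", "I") and label:
            start, etype = i, label  # a new entity begins here
        else:
            start, etype = None, None
    return entities


if __name__ == "__main__":
    chars = ['我', '愛', '北', '京']
    tags = ['O', 'O', 'B-D', 'I-D']  # e.g. predict[0]
    print(extract_entities(chars, tags))  # [('北京', 'D', 2, 4)]
```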