本文記錄使用BERT預訓練模型,修改最頂層softmax層,微調幾個epoch,進行文本分類任務。
BERT源碼
首先BERT源碼來自谷歌官方tensorflow版:https://github.com/google-research/bert
注意,這是tensorflow 1.x 版本的。
BERT預訓練模型
預訓練模型采用哈工大訊飛聯合實驗室推出的WWM(Whole Word Masking)全詞覆蓋預訓練模型,主要考量是BERT對於中文模型來說,是按照字符進行切割,但是注意到BERT隨機mask掉15%的詞,這里是完全隨機的,對於中文來說,很有可能一個詞的某些字被mask掉了,比如說讓我預測這樣一句話:
原話: ”我今天早上去打羽毛球了,然后又去蒸了桑拿,感覺身心愉悅“
MASK:”我[MASK]天早上去打[MASK]毛球了,然后[MASK]去蒸了[MASK]拿,感覺身心[MASK]悅“
雖然說從統計學意義上來講這樣做依然可以學得其特征,但這樣實際上破壞了中文特有的詞結構,那么全詞覆蓋主要就是針對這個問題,提出一種機制保證在MASK的時候要么整個詞都不MASK,要么MASK掉整個詞。
WWM MASK:”我今天早上去打[MASK][MASK][MASK]了,然后又去蒸了[MASK][MASK],感覺身心愉悅“
例子可能舉得不是很恰當,但大概是這個意思,可以參考這篇文章:
https://www.jiqizhixin.com/articles/2019-06-21-01
修改源碼
首先看到下下來的項目結構:
可以看到run_classifier.py文件,這個是我們需要用的。另外,chinese開頭的文件是我們的模型地址,data文件是我們的數據地址,這個每個人可以自己設置。
在run_classifier.py文件中,有一個基類DataProcessor類,這個是我們需要繼承並重寫的:
class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set.""" raise NotImplementedError() def get_dev_examples(self, data_dir): """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() def get_test_examples(self, data_dir): """Gets a collection of `InputExample`s for prediction.""" raise NotImplementedError() def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" with tf.gfile.Open(input_file, "r") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: lines.append(line) return lines
可以看到我們需要實現獲得訓練、驗證、測試數據接口,以及獲得標簽的接口。
這里我自己用的一個類。注釋比較詳細,就不解釋了,主要體現了只要能獲得數據,不論我們的文件格式是什么樣的,都可以,所以不需要專門為了這個項目去改自己的輸入數據格式。
class StatutesProcessor(DataProcessor): def _read_txt_(self, data_dir, x_file_name, y_file_name): # 定義我們的讀取方式,我的工程中已經將x文本和y文本分別存入txt文件中,沒有分隔符 # 用gfile讀取,打開一個沒有線程鎖的的文件IO Wrapper # 基本上和python原生的open是一樣的,只是在某些方面更高效一點 with tf.gfile.Open(data_dir + x_file_name, 'r') as f: lines_x = [x.strip() for x in f.readlines()] with tf.gfile.Open(data_dir + y_file_name, 'r') as f: lines_y = [x.strip() for x in f.readlines()] return lines_x, lines_y def get_train_examples(self, data_dir): lines_x, lines_y = self._read_txt_(data_dir, 'train_x.txt', 'train_y.txt') examples = [] for (i, line) in enumerate(zip(lines_x, lines_y)): guid = 'train-%d' % i # 規范輸入編碼 text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) # 這里有一些特殊的任務,一般任務直接用上面的就行,下面的label操作可以注釋掉 # 這里因為y會有多個標簽,這里按單標簽來做 label = label.strip().split()[0] # 這里不做匹配任務,text_b為None examples.append( InputExample(guid=guid, text_a=text_a, label=label) ) return examples def get_dev_examples(self, data_dir): lines_x, lines_y = self._read_txt_(data_dir, 'val_x.txt', 'val_y.txt') examples = [] for (i, line) in enumerate(zip(lines_x, lines_y)): guid = 'train-%d' % i # 規范輸入編碼 text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) label = label.strip().split()[0] # 這里不做匹配任務,text_b為None examples.append( InputExample(guid=guid, text_a=text_a, label=label) ) return examples def get_test_examples(self, data_dir): lines_x, lines_y = self._read_txt_(data_dir, 'test_x.txt', 'test_y.txt') examples = [] for (i, line) in enumerate(zip(lines_x, lines_y)): guid = 'train-%d' % i # 規范輸入編碼 text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) label = label.strip().split()[0] # 這里不做匹配任務,text_b為None examples.append( InputExample(guid=guid, text_a=text_a, label=label) ) return examples def get_labels(self): # 我事先統計了所有出現的y值,放在了vocab_y.txt里 # 因為這里沒有原生的接口,這里暫時這么做了,只要保證能讀到所有的類別就行了 with tf.gfile.Open('data/statutes_small/vocab_y.txt', 'r') as f: vocab_y = [x.strip() for x in f.readlines()] return vocab_y
寫好了之后需要更新一下processors列表,在main函數中,最下面一條就是我新加的。
執行訓練微調
python run_classifier.py --data_dir=data/statutes_small/ --task_name=cail2018 --vocab_file=chinese_wwm_ext_L-12_H-768_A-12/vocab.txt --bert_config_file=chinese_wwm_ext_L-12_H-768_A-12/bert_config.json --output_dir=output/ --do_train=true --do_eval=true --init_checkpoint=chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt --max_seq_length=200 --train_batch_size=16 --learning_rate=5e-5 --num_train_epoch=3
相信我,寫在一行,這個會有很多小問題,在centos服務器上如果不能按上返回上一條命令,將會很痛苦。。具體參數含義就和參數名是一致的,不需要解釋。
另外,可以稍稍修改一些東西來動態輸入訓練集上的loss,因為BERT源碼封裝的太高了,所以只能按照這篇文章:https://www.cnblogs.com/jiangxinyang/p/10241243.html里面講的方法,每100個step輸出一次train loss(就是100個batch),這樣做雖然意義不大,但是可以看在你的數據集上模型是不是在收斂,方便調整學習率。
在測試集上進行測試
默認test_batch_size = 8
python run_classifier.py --data_dir=data/statutes_small/ --task_name=cail2018 --vocab_file=chinese_wwm_ext_L-12_H-768_A-12/vocab.txt --bert_config_file=chinese_wwm_ext_L-12_H-768_A-12/bert_config.json --output_dir=output/ --do_predict=true --max_seq_length=200
需要注意的是,調用測試接口會在out路徑中生成一個test_results.tsv,這是一個以’\t’為分隔符的文件,記錄了每一條輸入測試樣例,輸出的每一個維度的值(維度數就是類別數目),需要手動做一點操作來得到最終分類結果,以及計算指標等等。
# 計算測試結果 # 因為原生的predict生成一個test_results.tsv文件,給出了每一個sample的每一個維度的值 # 卻並沒有給出具體的類別預測以及指標,這里再對這個“中間結果手動轉化一下” def cal_accuracy(rst_file_dir, y_test_dir): rst_contents = pd.read_csv(rst_file_dir, sep='\t', header=None) # value_list: ndarray value_list = rst_contents.values pred = value_list.argmax(axis=1) labels = [] # 這一步是獲取y標簽到id,id到標簽的對應dict,每個人獲取的方式應該不一致 y2id, id2y = get_y_to_id(vocab_y_dir='../data/statutes_small/vocab_y.txt') with open(y_test_dir, 'r', encoding='utf-8') as f: line = f.readline() while line: # 這里因為y有多個標簽,我要取第一個標簽,所以要單獨做操作 label = line.strip().split()[0] labels.append(y2id[label]) line = f.readline() labels = np.asarray(labels) # 預測,pred,真實標簽,labels accuracy = metrics.accuracy_score(y_true=labels, y_pred=pred) # 這里只舉例了accuracy,其他的指標也類似計算 print(accuracy) def get_y_to_id(vocab_y_dir): # 這里把所有的y標簽值存在了文件中 y_vocab = open(vocab_y_dir, 'r', encoding='utf-8').read().splitlines() y2idx = {token: idx for idx, token in enumerate(y_vocab)} idx2y = {idx: token for idx, token in enumerate(y_vocab)} return y2idx, idx2y
這部分代碼在classifier/cal_test_matrix.py中。
我的代碼地址:
參考:
https://github.com/google-research/bert
https://www.cnblogs.com/jiangxinyang/p/10241243.html
https://www.jiqizhixin.com/articles/2019-06-21-01
https://arxiv.org/abs/1906.08101