BERT預訓練模型有以下幾個:
BERT-Large, Uncased (Whole Word Masking): 24-layer, 1024-hidden, 16-heads, 340M parametersBERT-Large, Cased (Whole Word Masking): 24-layer, 1024-hidden, 16-heads, 340M parametersBERT-Base, Uncased: 12-layer, 768-hidden, 12-heads, 110M parametersBERT-Large, Uncased: 24-layer, 1024-hidden, 16-heads, 340M parametersBERT-Base, Cased: 12-layer, 768-hidden, 12-heads , 110M parametersBERT-Large, Cased: 24-layer, 1024-hidden, 16-heads, 340M parametersBERT-Base, Multilingual Cased (New, recommended): 104 languages, 12-layer, 768-hidden, 12-heads, 110M parametersBERT-Base, Multilingual Uncased (Orig, not recommended):(Not recommended, useMultilingual Casedinstead): 102 languages, 12-layer, 768-hidden, 12-heads, 110M parametersBERT-Base, Chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
數據集准備:
數據集(下載)包括訓練集(train.tsv)、驗證集(dev.tsv)和測試集(test.tsv),格式相同,每一行表示一條數據,每條數據格式為【標簽+TAB+內容】
#批量轉換數據格式
def _writeto_tsv(a): fr = open('/home/zwt/Desktop/testbert/caijing/{}.txt'.format(a), 'r') txt = fr.read() txt = txt.replace('\n', '') txt = txt.replace('\u3000', '') txt = txt.replace(' ', '') txt = txt[:128] txt = '財經\t' + txt + '\n' fw.write(txt) fr.close() fw = open('/home/zwt/Desktop/testbert/caijing.tsv','w') for a in range(799401,799440): _writeto_tsv(a) fw.close() #####
def _writeto_tsv(a): fr = open('/home/zwt/Desktop/testbert/yule/{}.txt'.format(a), 'r') txt = fr.read() txt = txt.replace('\n', '') txt = txt.replace('\u3000', '') txt = txt.replace(' ', '') txt = txt[:128] txt = '娛樂\t' + txt + '\n' fw.write(txt) fr.close() fw = open('/home/zwt/Desktop/testbert/yule.tsv','w') for a in range(157340,157379): _writeto_tsv(a) fw.close() #####
def _writeto_tsv(a): fr = open('/home/zwt/Desktop/testbert/keji/{}.txt'.format(a), 'r') txt = fr.read() txt = txt.replace('\n', '') txt = txt.replace('\u3000', '') txt = txt.replace(' ', '') txt = txt[:128] txt = '科技\t' + txt + '\n' fw.write(txt) fr.close() fw = open('/home/zwt/Desktop/testbert/keji.tsv','w') for a in range(482362,482401): _writeto_tsv(a) fw.close()
修改代碼:
run_classifier.py中有DataProcessor基類:
class DataProcessor(object): """Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError() def get_dev_examples(self, data_dir): """Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError() def get_test_examples(self, data_dir): """Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError() def get_labels(self): """Gets the list of labels for this data set."""
raise NotImplementedError() @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" with tf.gfile.Open(input_file, "r") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: lines.append(line) return lines
在這個基類中定義了一個讀取文件的靜態方法_read_tsv,四個分別獲取訓練集,驗證集,測試集和標簽的方法。接下來我們要定義自己的數據處理的類,我們將我們的類命名ZwtProcessor,繼承於DataProcessor,編寫ZwtProcessor(本例中使用三分類數據,如果需要更多分類,修改labels參數)
class ZwtProcessor(DataProcessor): """Processor for the News data set (GLUE version)."""
def __init__(self): self.labels = ['財經', '娛樂', '科技'] def get_train_examples(self, data_dir): return self._create_examples( self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") def get_dev_examples(self, data_dir): return self._create_examples( self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") def get_test_examples(self, data_dir): return self._create_examples( self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") def get_labels(self): return self.labels def _create_examples(self, lines, set_type): """Creates examples for the training and dev sets.""" examples = [] for (i, line) in enumerate(lines): guid = "%s-%s" % (set_type, i) text_a = tokenization.convert_to_unicode(line[1]) label = tokenization.convert_to_unicode(line[0]) examples.append( InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) return examples
注意這里有一個self._read_tsv()方法,規定讀取的數據是使用TAB分割的,如果你的數據集不是這種形式組織的,需要重寫一個讀取數據的方法,更改“_create_examples()”的實現。
在main函數的processors中加入自己的processors
修改前: processors = { "cola": ColaProcessor, "mnli": MnliProcessor, "mrpc": MrpcProcessor, "xnli": XnliProcessor, } 修改后: processors = { "cola": ColaProcessor, "mnli": MnliProcessor, "mrpc": MrpcProcessor, "xnli": XnliProcessor, "zwt": ZwtProcessor, }
至此已經完成准備工作,編寫一個run.sh文件運行即可,內容如下:
#!/usr/bin/bash
python3 /home/zwt/PycharmProjects/test/bert-master/run_classifier.py \ --task_name=zwt \ --do_train=true \ --do_eval=true \ --data_dir=/home/zwt/PycharmProjects/test/zwtBERT/data/ \ --vocab_file=/home/zwt/PycharmProjects/test/data/chinese_L-12_H-768_A-12/vocab.txt \ --bert_config_file=/home/zwt/PycharmProjects/test/data/chinese_L-12_H-768_A-12/bert_config.json \ --init_checkpoint=/home/zwt/PycharmProjects/test/data/chinese_L-12_H-768_A-12/bert_model.ckpt \ --max_seq_length=128 \ --train_batch_size=32 \ --learning_rate=2e-5 \ --num_train_epochs=3.0 \ --output_dir=/home/zwt/PycharmProjects/test/zwtBERT/zwt_output
######參數解釋#######
data_dir:存放數據集的文件夾
bert_config_file:bert中文模型中的bert_config.json文件
task_name:processors中添加的任務名“zbs”
vocab_file:bert中文模型中的vocab.txt文件
output_dir:訓練好的分類器模型的存放文件夾
init_checkpoint:bert中文模型中的bert_model.ckpt.index文件
do_train:是否訓練,設置為“True”
do_eval:是否驗證,設置為“True”
do_predict:是否測試,設置為“False”
max_seq_length:輸入文本序列的最大長度,也就是每個樣本的最大處理長度,多余會去掉,不夠會補齊。最大值512,當顯存不足時,可以適當降低max_seq_length。
train_batch_size: 訓練模型求梯度時,批量處理數據集的大小。值越大,訓練速度越快,內存占用越多。
eval_batch_size: 驗證時,批量處理數據集的大小。同上。
predict_batch_size: 測試時,批量處理數據集的大小。同上。
learning_rate: 反向傳播更新權重時,步長大小。值越大,訓練速度越快。值越小,訓練速度越慢,收斂速度慢,
容易過擬合。遷移學習中,一般設置較小的步長(小於2e-4)
num_train_epochs:所有樣本完全訓練一遍的次數。
warmup_proportion:用於warmup的訓練集的比例。
save_checkpoints_steps:檢查點的保存頻率。
終端輸入/bin/bash zwtBERTrun.sh即可運行
原生bert指標只有loss和accuracy,可自行修改
修改前: def metric_fn(per_example_loss, label_ids, logits, is_real_example): predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) accuracy = tf.metrics.accuracy( labels=label_ids, predictions=predictions, weights=is_real_example) loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) return { "eval_accuracy": accuracy, "eval_loss": loss, } 修改后: def metric_fn(per_example_loss, label_ids, logits, is_real_example): predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) accuracy = tf.metrics.accuracy( labels=label_ids, predictions=predictions, weights=is_real_example) loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) auc = tf.metrics.auc(labels=label_ids, predictions=predictions, weights=is_real_example) precision = tf.metrics.precision(labels=label_ids, predictions=predictions, weights=is_real_example) recall = tf.metrics.recall(labels=label_ids, predictions=predictions, weights=is_real_example) return { "eval_accuracy": accuracy, "eval_loss": loss, 'eval_auc': auc, 'eval_precision': precision, 'eval_recall': recall, }
https://www.cnblogs.com/jiangxinyang/p/10241243.html
https://www.jiqizhixin.com/articles/2018-12-03
https://cloud.tencent.com/developer/article/1356797
https://blog.csdn.net/xiaosa_kun/article/details/84868475
