BERT fine-tuning steps:

Start with the main function: copy run_classifier.py and rename the copy to my_classifier.py (any name will do). First, look at the main function:
if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
1. data_dir

flags.mark_flag_as_required("data_dir") marks data_dir, the folder holding the input data, as required. The expected data format is already defined:
class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    """Constructs a InputExample.

    Args:
      guid: Unique id for the example.
      text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
      text_b: (Optional) string. The untokenized text of the second sequence.
        Only must be specified for sequence pair tasks.
      label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label
The required fields are guid and text_a; text_b and label are optional. A single-sentence classification task does not need text_b, and test examples do not need a label.
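For a single-sentence task, examples can be constructed like this (a minimal, self-contained sketch: the class body is copied from run_classifier.py, and the guids and texts are made up):

```python
# Minimal copy of InputExample from run_classifier.py, so this sketch runs
# on its own.
class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label

# Single-sentence task: text_b stays None; test examples carry no label.
train_example = InputExample(guid="train-1", text_a="great movie", label="1")
test_example = InputExample(guid="test-1", text_a="unwatchable plot")
```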
2. task_name
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
}
Here task_name must be one of the keys of the processors dictionary. BERT ships with four: "cola", "mnli", "mrpc", and "xnli"; for any other task, add your own entry.
Note the following:
task_name = FLAGS.task_name.lower()

if task_name not in processors:
  raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()

label_list = processor.get_labels()
task_name selects the processor. The BERT source code provides 4 processors; to fine-tune on our own data we need to define our own processor, modeled on one of them:
class MrpcProcessor(DataProcessor):
  """Processor for the MRPC data set (GLUE version)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1"]  # TODO: replace with your own label set

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue  # skip the header line
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[3])
      text_b = tokenization.convert_to_unicode(line[4])
      if set_type == "test":
        label = "0"  # dummy label; test examples are unlabeled
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples
A processor is simply a class that preprocesses the input data; it inherits from DataProcessor. The files in data_dir are in .tsv format, and since the task here is binary classification, the labels are set to "0" and "1". _create_examples() shows how guid is built and how text_a, text_b, and label are filled in for each line.
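As a sketch, a processor for a single-sentence binary task (the name MyTaskProcessor is hypothetical) might build its examples like this. In the real script it would subclass DataProcessor and use tokenization.convert_to_unicode; minimal stand-ins are defined here so the sketch runs on its own, and the assumed input format is label, tab, sentence with no header line:

```python
class InputExample(object):
  """Stand-in for run_classifier.InputExample."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label


def convert_to_unicode(text):
  """Stand-in for tokenization.convert_to_unicode."""
  return text if isinstance(text, str) else text.decode("utf-8")


class MyTaskProcessor(object):  # subclass DataProcessor in the real script
  """Hypothetical processor for a single-sentence binary task."""

  def get_labels(self):
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    # Train/dev lines are [label, sentence]; test lines are [sentence] only.
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      if set_type == "test":
        text_a = convert_to_unicode(line[0])
        label = "0"  # dummy label for unlabeled test data
      else:
        text_a = convert_to_unicode(line[1])
        label = convert_to_unicode(line[0])
      # Single-sentence task, so text_b is left as None.
      examples.append(InputExample(guid=guid, text_a=text_a, label=label))
    return examples


examples = MyTaskProcessor()._create_examples(
    [["1", "great movie"], ["0", "terrible plot"]], "train")
```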
That covers the first two required flags; back to the main function:
if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
3. vocab_file, bert_config_file, output_dir

vocab_file and bert_config_file come with the downloaded pretrained model; output_dir is where the fine-tuned model is written.
As mentioned above, the .tsv format is like .csv, but tab-separated.

train.tsv and dev.tsv format:

label + "\t" (tab) + sentence

test.tsv format:

sentence
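Illustrative only: generating files in this format (the file names train.tsv/test.tsv are fixed by the processor above; the directory and sentences here are made up):

```python
import os
import tempfile

data_dir = tempfile.mkdtemp()  # stand-in for the real data_dir folder

# train.tsv: one "label<TAB>sentence" pair per line.
train_rows = [("1", "great movie"), ("0", "terrible plot")]
with open(os.path.join(data_dir, "train.tsv"), "w", encoding="utf-8") as f:
  for label, sentence in train_rows:
    f.write("%s\t%s\n" % (label, sentence))

# test.tsv: the sentence only, no label.
with open(os.path.join(data_dir, "test.tsv"), "w", encoding="utf-8") as f:
  f.write("unwatchable plot\n")
```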
4. Add your own task to the processors dictionary
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
    # add your own entry here, e.g. "my_task": MyTaskProcessor
}
5. Set the flags and fine-tune
python my_classifier.py \
  --task_name=mrpc \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/
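Before running the command, the two environment variables it references have to point at the pretrained model and the data; the paths below are placeholders for wherever you unpacked them:

```shell
# Placeholders: replace with the actual locations on your machine.
export BERT_BASE_DIR=/path/to/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue_data
```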