用google-research官方的bert源码(tensorflow版本)对新的法律语料进行微调,迭代次数为100000次,每隔1000次保存一下模型,得到的结果如下:
将最后三个文件取出,改名为bert_model.ckpt.data-00000-of-00001、bert_model.ckpt.index、bert_model.ckpt.meta
加上之前微调时使用过的config.json以及vocab.txt文件,运行如下脚本后会生成pytorch_model.bin,之后就可以被PyTorch的代码调用了。
# Convert a TensorFlow BERT checkpoint into a PyTorch state-dict file.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import logging

import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint to a PyTorch weights file.

    Builds a ``BertForPreTraining`` model from the JSON configuration, loads
    the TensorFlow checkpoint weights into it, and writes the resulting
    ``state_dict`` to disk with ``torch.save``.

    Args:
        tf_checkpoint_path: Path handed to ``load_tf_weights_in_bert``; the
            script default points at a ``bert_model.ckpt`` checkpoint prefix.
        bert_config_file: JSON file describing the model architecture.
        pytorch_dump_path: Destination file for the PyTorch ``state_dict``.
    """
    # Initialise the PyTorch model from the architecture description.
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Copy the TensorFlow checkpoint weights into the PyTorch model.
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Persist only the weights (state_dict), not the whole module object.
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Every argument has a default, so the script can be run without any
    # flags when the files live in ./chinese_L-12_H-768_A-12_improve1/.
    parser.add_argument(
        "--tf_checkpoint_path",
        default='./chinese_L-12_H-768_A-12_improve1/bert_model.ckpt',
        type=str,
        help="Path to the TensorFlow checkpoint path.",
    )
    parser.add_argument(
        "--bert_config_file",
        default='./chinese_L-12_H-768_A-12_improve1/config.json',
        type=str,
        help="The config json file corresponding to the pre-trained BERT model. \n"
             "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path",
        default='./chinese_L-12_H-768_A-12_improve1/pytorch_model.bin',
        type=str,
        help="Path to the output PyTorch model.",
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(
        args.tf_checkpoint_path,
        args.bert_config_file,
        args.pytorch_dump_path,
    )
Tip:如果要转换的不是标准BERT模型,而是BERT的变种(比如MobileBERT、DistilBERT等),权重的结构可能与BertForPreTraining不匹配,此时会报错AttributeError: 'BertForPreTraining' object has no attribute 'bias'。
此时需要根据transformers库里的源码修改convert_tf_checkpoint_to_pytorch函数,以MobileBERT为例:
# Based on transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
# in the transformers library.
import torch
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
    """Convert a TensorFlow MobileBERT checkpoint to a PyTorch weights file.

    Builds a ``MobileBertForPreTraining`` model from the JSON configuration,
    loads the TensorFlow checkpoint weights into it, and writes the resulting
    ``state_dict`` to disk with ``torch.save``.

    Args:
        tf_checkpoint_path: Path handed to ``load_tf_weights_in_mobilebert``.
        mobilebert_config_file: JSON file describing the model architecture.
        pytorch_dump_path: Destination file for the PyTorch ``state_dict``.
    """
    # Initialise the PyTorch model from the architecture description.
    config = MobileBertConfig.from_json_file(mobilebert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = MobileBertForPreTraining(config)

    # Load weights from the TF checkpoint; the loader's return value is the
    # model that gets saved below.
    model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path)

    # Persist only the weights (state_dict), not the whole module object.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)