Converting a TensorFlow .ckpt model into a PyTorch .bin model


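I fine-tuned on a new legal-domain corpus using the official google-research BERT source code (the TensorFlow version), running 100,000 iterations and saving a checkpoint every 1,000 steps. The results were as follows: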

Take out the last three files, which together form the final checkpoint, and rename them to bert_model.ckpt.data-00000-of-00001, bert_model.ckpt.index and bert_model.ckpt.meta.
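When the output directory holds many checkpoints, picking the newest one by hand is error-prone. Below is a minimal sketch that automates this step, assuming the usual TF Estimator naming `model.ckpt-<step>.{data-00000-of-00001,index,meta}`; the helper names and the paths in the usage line are hypothetical. It copies instead of renaming, so the original checkpoints stay intact. (Since TensorFlow addresses a checkpoint by its prefix, passing something like `.../model.ckpt-100000` directly as the checkpoint path should also work without renaming.)

```python
import os
import re
import shutil

# Assumed Estimator naming: model.ckpt-<step>.data-00000-of-00001 / .index / .meta
CKPT_FILE = re.compile(r"^model\.ckpt-(\d+)\.(data-\d{5}-of-\d{5}|index|meta)$")


def latest_step(output_dir):
    """Return the highest global step among the checkpoint files, or None."""
    steps = []
    for name in os.listdir(output_dir):
        match = CKPT_FILE.match(name)
        if match:
            steps.append(int(match.group(1)))  # compare numerically, not as strings
    return max(steps) if steps else None


def export_latest_checkpoint(output_dir, target_dir, new_prefix="bert_model.ckpt"):
    """Copy the newest checkpoint's files into target_dir as <new_prefix>.<suffix>."""
    step = latest_step(output_dir)
    if step is None:
        raise FileNotFoundError(f"no model.ckpt-* files found in {output_dir}")
    os.makedirs(target_dir, exist_ok=True)
    exported = []
    for name in sorted(os.listdir(output_dir)):
        match = CKPT_FILE.match(name)
        if match and int(match.group(1)) == step:
            new_name = f"{new_prefix}.{match.group(2)}"
            shutil.copyfile(os.path.join(output_dir, name),
                            os.path.join(target_dir, new_name))
            exported.append(new_name)
    return step, exported


# Hypothetical usage:
# step, files = export_latest_checkpoint("./output", "./chinese_L-12_H-768_A-12_improve1")
```

Comparing steps as integers matters because a plain string sort would place `model.ckpt-99000` after `model.ckpt-100000`.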

Add the config.json and vocab.txt files used for fine-tuning, then run the script below. It generates pytorch_model.bin, which can then be loaded by PyTorch code.
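Before running the conversion, a quick check that the directory is complete can save a confusing traceback. This is a minimal sketch assuming the file names used in this post; `check_model_dir` is a hypothetical helper. It also compares the number of tokens in vocab.txt with `vocab_size` in config.json, which catches a vocabulary file that does not belong to the model.

```python
import json
import os

# File names follow this post; adjust them if your checkpoint prefix differs
REQUIRED_FILES = (
    "config.json",
    "vocab.txt",
    "bert_model.ckpt.index",
    "bert_model.ckpt.meta",
    "bert_model.ckpt.data-00000-of-00001",
)


def check_model_dir(model_dir):
    """Return a list of problems found in model_dir; an empty list means it looks ready."""
    problems = [f"missing {name}" for name in REQUIRED_FILES
                if not os.path.isfile(os.path.join(model_dir, name))]
    config_path = os.path.join(model_dir, "config.json")
    vocab_path = os.path.join(model_dir, "vocab.txt")
    if os.path.isfile(config_path) and os.path.isfile(vocab_path):
        with open(config_path, encoding="utf-8") as f:
            vocab_size = json.load(f).get("vocab_size")
        with open(vocab_path, encoding="utf-8") as f:
            n_tokens = sum(1 for line in f if line.rstrip("\n"))  # one token per line
        if vocab_size != n_tokens:
            problems.append(f"vocab.txt has {n_tokens} tokens "
                            f"but config.json says vocab_size={vocab_size}")
    return problems


# Hypothetical usage:
# print(check_model_dir("./chinese_L-12_H-768_A-12_improve1"))
```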

import argparse
import logging

import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise the PyTorch model from the same config.json used for fine-tuning
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from the TF checkpoint. This requires TensorFlow to be installed,
    # and tf_checkpoint_path is the prefix "bert_model.ckpt", not one of the three files.
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save the PyTorch weights (state dict only)
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Paths; the defaults point at the fine-tuned model directory
    parser.add_argument("--tf_checkpoint_path",
                        default="./chinese_L-12_H-768_A-12_improve1/bert_model.ckpt",
                        type=str,
                        help="Path prefix of the TensorFlow checkpoint.")
    parser.add_argument("--bert_config_file",
                        default="./chinese_L-12_H-768_A-12_improve1/config.json",
                        type=str,
                        help="The config json file corresponding to the pre-trained BERT model. "
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default="./chinese_L-12_H-768_A-12_improve1/pytorch_model.bin",
                        type=str,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                     args.bert_config_file,
                                     args.pytorch_dump_path)
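To confirm that pytorch_model.bin really loads from PyTorch, one option is to rebuild `BertForPreTraining` from the same config.json and load the saved state dict back, looking at the missing and unexpected keys. This is a sketch using standard `transformers`/`torch` calls; `load_converted_bert` is a hypothetical name. Note that it only proves the file loads: since the file was saved from the same class, the key sets always match, whether or not every TF variable was actually copied.

```python
import torch
from transformers import BertConfig, BertForPreTraining


def load_converted_bert(config_file, weights_file):
    """Rebuild BertForPreTraining and load the state dict written by torch.save."""
    config = BertConfig.from_json_file(config_file)
    model = BertForPreTraining(config)
    state_dict = torch.load(weights_file, map_location="cpu")
    # strict=False so that mismatched keys are reported instead of raising
    result = model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model, result.missing_keys, result.unexpected_keys


# Hypothetical usage with the default paths of the script above:
# model, missing, unexpected = load_converted_bert(
#     "./chinese_L-12_H-768_A-12_improve1/config.json",
#     "./chinese_L-12_H-768_A-12_improve1/pytorch_model.bin")
```

Because the directory then holds config.json, vocab.txt and pytorch_model.bin, in the transformers 4.x releases I know of it should also be loadable with `BertModel.from_pretrained(...)` and `BertTokenizer.from_pretrained(...)` pointed at that directory.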

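Tip: if the model is not BERT itself but a BERT variant such as MobileBERT or DistilBERT, the checkpoint's variables may not match the layout of BertForPreTraining, and the conversion fails with AttributeError: 'BertForPreTraining' object has no attribute 'bias'.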
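Why does the error mention `bias`? As far as I can tell from the transformers source, `load_tf_weights_in_bert` splits each TF variable name on "/" and walks the PyTorch model one attribute at a time: names such as `kernel`/`gamma` become `weight`, `beta`/`output_bias` become `bias`, and components it cannot find are skipped. If a variant's scopes are unknown, the walk can reach `beta` while still sitting on the top-level model, which has no `bias`. The sketch below is a simplified, standalone imitation of that walk for illustration only (not the library function; it omits details such as transposing `kernel` arrays), and the skip list, renaming table, toy model and variable names are assumptions that may differ between versions.

```python
import re
from types import SimpleNamespace

# Optimizer slots and counters that the converter ignores (assumed list)
SKIPPED = {"adam_v", "adam_m", "AdamWeightDecayOptimizer",
           "AdamWeightDecayOptimizer_1", "global_step"}
# TF leaf names mapped onto PyTorch parameter names (assumed table)
RENAMED = {"kernel": "weight", "gamma": "weight", "output_weights": "weight",
           "beta": "bias", "output_bias": "bias"}


def resolve(model, tf_name):
    """Simplified imitation of the name walk in load_tf_weights_in_bert."""
    parts = tf_name.split("/")
    if any(part in SKIPPED for part in parts):
        return None  # optimizer state is not converted
    pointer = model
    for part in parts:
        if re.fullmatch(r"[A-Za-z]+_\d+", part):
            base, index = re.split(r"_(\d+)", part)[:2]  # "layer_3" -> "layer", "3"
        else:
            base, index = part, None
        if base in RENAMED:
            # No fallback here: a missing weight/bias raises AttributeError
            pointer = getattr(pointer, RENAMED[base])
        elif hasattr(pointer, base):
            pointer = getattr(pointer, base)
        else:
            continue  # unknown component: skipped, pointer stays on the same object
        if index is not None:
            pointer = pointer[int(index)]
    return pointer


# Toy stand-in for BertForPreTraining: bert.encoder.layer[0].output.LayerNorm
layer_norm = SimpleNamespace(weight="ln.weight", bias="ln.bias")
toy_model = SimpleNamespace(bert=SimpleNamespace(encoder=SimpleNamespace(
    layer=[SimpleNamespace(output=SimpleNamespace(LayerNorm=layer_norm))])))

print(resolve(toy_model, "bert/encoder/layer_0/output/LayerNorm/gamma"))  # ln.weight
# With a hypothetical variant-only scope, e.g. "variant_scope/LayerNorm/beta",
# every component is skipped and "beta" is looked up on the top-level object,
# raising AttributeError: ... has no attribute 'bias'
```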

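In that case, adapt the convert_tf_checkpoint_to_pytorch function following the corresponding conversion script in the transformers library. Taking MobileBERT as an example: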

# Based on transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py in the transformers library
import torch
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = MobileBertConfig.from_json_file(mobilebert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = MobileBertForPreTraining(config)
    # Load weights from tf checkpoint (requires TensorFlow); returns the model
    model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path)
    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)

# The argparse/__main__ block of the BERT script can be reused as is,
# since the three arguments are passed positionally.

