BERT預訓練tensorflow模型轉換為pytorch模型


在Bert的預訓練模型中,主流的模型都是以tensorflow的形勢開源的。但是huggingface在Transformers中提供了一份可以轉換的接口(convert_bert_original_tf_checkpoint_to_pytorch.py)。
但是如何在windows的IDE中執行呢?

  • 首先,需要安裝transformers (可以掛國內清華、豆瓣源之類的加速)
pip install transformers
  • 其次,下載tf版本的bert預訓練模型goole的預訓練模型,下載的模型文件解壓后如下:
    image
  • 寫tf2torch.py腳本且放在模型同目錄中,腳本內容如下:
    image
import transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch as con


con.convert_tf_checkpoint_to_pytorch(
    r'.\bert_model.ckpt',
    r'.\bert_config.json',
    r'.\pytorch_bert.bin'
)

convert_tf_checkpoint_to_pytorch中三個參數分別是:bert模型名稱、config文件地址,輸出的pytorch文件保存地址

  • 然后運行tf2torch.py文件得到如下文件,多了一個pytorch_bert.bin文件
    image
  • 最后注意:可以忽略TensorFlow checkpoint(以bert_model.ckpt開頭的三個文件),但是一定要保留配置文件(bert_config.json)和詞匯表文件(vocab.txt),因為PyTorch模型也需要這些文件。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM