import argparse
import os

from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)
# Convert the TensorFlow chinese_wobert_plus checkpoint into a PyTorch weights
# file using transformers' BERT conversion helper.
# Run directly. Every path can be overridden on the command line, e.g.
#   python this_script.py --tf-checkpoint <dir>/bert_model.ckpt \
#       --config <dir>/bert_config.json --output <out>/pytorch_model.bin

# Directory of the original TensorFlow model (chinese_wobert_plus).
path = r"C:\Users\zmmm\Desktop\TSProjects\chinese_wobert_plus_L-12_H-768_A-12"
# Directory the converted PyTorch weights are written to.
# NOTE(review): the user folder here ("zm") differs from the one in `path`
# ("zmmm"); one of the two is probably a typo -- confirm before running.
output_dir = r"C:\Users\zm\Desktop\TSProjects\wobert_chinese_plus_base"

# Pass "bert_model.ckpt" unchanged. No code change is needed for
# chinese_wobert_plus even though its files do not end in ".ckpt".
# Presumably this works because it is treated as a checkpoint prefix.
tf_checkpoint_path = os.path.join(path, "bert_model.ckpt")
bert_config_file = os.path.join(path, "bert_config.json")
# Where the PyTorch weights are saved.
pytorch_dump_path = os.path.join(output_dir, "pytorch_model.bin")


def convert(checkpoint, config_file, dump_path, converter=None):
    """Convert a TensorFlow BERT checkpoint into a PyTorch weights file.

    Args:
        checkpoint: TensorFlow checkpoint path, e.g. ``<dir>/bert_model.ckpt``.
        config_file: Path to the model's ``bert_config.json``.
        dump_path: Destination file for the PyTorch weights.
        converter: Callable taking ``(checkpoint, config_file, dump_path)``.
            Defaults to transformers' ``convert_tf_checkpoint_to_pytorch``.
            Mainly a hook for testing or for swapping the conversion routine.

    Returns:
        ``dump_path``, unchanged.
    """
    if converter is None:
        converter = convert_tf_checkpoint_to_pytorch
    out_dir = os.path.dirname(dump_path)
    if out_dir:
        # Create the destination folder up front, so a missing directory does
        # not make the final save fail after the whole model has been loaded.
        os.makedirs(out_dir, exist_ok=True)
    converter(checkpoint, config_file, dump_path)
    return dump_path


def main(argv=None):
    """Parse optional path overrides from ``argv`` and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert a TensorFlow BERT/WoBERT checkpoint to PyTorch."
    )
    parser.add_argument(
        "--tf-checkpoint",
        default=tf_checkpoint_path,
        help="TensorFlow checkpoint path (default: %(default)s)",
    )
    parser.add_argument(
        "--config",
        default=bert_config_file,
        help="bert_config.json path (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        default=pytorch_dump_path,
        help="output PyTorch weights file (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    convert(args.tf_checkpoint, args.config, args.output)


if __name__ == "__main__":
    main()
# 参考 (References):