# This is the code (这是代码).
# -*- coding: utf-8 -*-
# @Time : 2022/6/28 11:42
# @Author : lwb
# @File : convert.py
# @Project : xunfei
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
import logging
# Configure the root logger at INFO level so that INFO records emitted by the
# libraries used below (presumably including transformers' weight loader) are shown.
logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a saved PyTorch ``state_dict``.

    Args:
        tf_checkpoint_path: Path prefix of the TensorFlow checkpoint
            (e.g. ``.../bert_model.ckpt``). It is passed unchanged to
            ``load_tf_weights_in_bert``.
        bert_config_file: JSON config file read with
            ``BertConfig.from_json_file``.
        pytorch_dump_path: Output file written with ``torch.save``. Missing
            parent directories are created first.
    """
    import os

    # Instantiate a BertForPreTraining model from the JSON configuration.
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Copy the TensorFlow checkpoint variables into the PyTorch model.
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # torch.save does not create directories. dirname is "" for a bare file
    # name, so only call makedirs when there is a parent directory to create.
    out_dir = os.path.dirname(pytorch_dump_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save only the parameters (state_dict), not the pickled model object.
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
#
def parse_args(argv=None):
    """Parse the command-line options for the conversion.

    Each option defaults to the MC-BERT path the script was previously
    hard-coded to, so running without arguments converts the same checkpoint.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``tf_checkpoint_path``, ``bert_config_file``
        and ``pytorch_dump_path`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Convert a TensorFlow BERT checkpoint to a PyTorch state_dict.")
    parser.add_argument(
        "--tf_checkpoint_path",
        default="./user_data/model_param/pretrained_model_param/mc_bert/mc_bert_base/bert_model.ckpt",
        type=str,
        help="Path prefix of the TensorFlow checkpoint (e.g. '.../bert_model.ckpt'); "
             "do not append the '.data-*' or '.index' suffix.")
    parser.add_argument(
        "--bert_config_file",
        default="./user_data/model_param/pretrained_model_param/mc_bert/mc_bert_base/bert_config.json",
        type=str,
        help="The config json file corresponding to the pre-trained BERT model. "
             "This specifies the model architecture.")
    parser.add_argument(
        "--pytorch_dump_path",
        default="./user_data/model_param/pretrained_model_param/mc_bert/pytorch_model.bin",
        type=str,
        help="Path to the output PyTorch model.")
    return parser.parse_args(argv)


def main(argv=None):
    """Run the TF -> PyTorch conversion using command-line options.

    Raises:
        SystemExit: if the BERT config file does not exist.
    """
    import os

    args = parse_args(argv)
    # Fail fast with a readable message. The previous code only printed
    # os.path.exists(...) and carried on regardless.
    if not os.path.exists(args.bert_config_file):
        raise SystemExit("BERT config file not found: {}".format(args.bert_config_file))
    convert_tf_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)


if __name__ == "__main__":
    main()