Converting a TensorFlow model to a PyTorch model

Here is the code. It uses transformers' BertForPreTraining together with load_tf_weights_in_bert (which requires TensorFlow to be installed in order to read the checkpoint) and saves the result as a PyTorch state dict.

# -*- coding: utf-8 -*-
# @Time : 2022/6/28 11:42
# @Author : lwb
# @File : convert.py
# @Project : xunfei

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

import logging
logging.basicConfig(level=logging.INFO)

def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save the PyTorch model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)

if __name__ == "__main__":
    # parser = argparse.ArgumentParser()
    # ## Required parameters
    # parser.add_argument("--tf_checkpoint_path",
    #                     default = './chinese_L-12_H-768_A-12_improve1/bert_model.ckpt',
    #                     type = str,
    #                     help = "Path to the TensorFlow checkpoint path.")
    # parser.add_argument("--bert_config_file",
    #                     default = './chinese_L-12_H-768_A-12_improve1/config.json',
    #                     type = str,
    #                     help = "The config json file corresponding to the pre-trained BERT model. \n"
    #                         "This specifies the model architecture.")
    # parser.add_argument("--pytorch_dump_path",
    #                     default = './chinese_L-12_H-768_A-12_improve1/pytorch_model.bin',
    #                     type = str,
    #                     help = "Path to the output PyTorch model.")
    # args = parser.parse_args()
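    # Sanity check (added comment): confirm the config file exists before running the conversion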
    import os
    print(os.path.exists("./user_data/model_param/pretrained_model_param/mc_bert/mc_bert_base/bert_config.json"))
    
    # The first path is the TF checkpoint prefix (do not append the .data-* suffix to the
    # file name), the second is bert_config.json, and the third is where the converted
    # PyTorch model is saved.
    convert_tf_checkpoint_to_pytorch("./user_data/model_param/pretrained_model_param/mc_bert/mc_bert_base/bert_model.ckpt",
                                     "./user_data/model_param/pretrained_model_param/mc_bert/mc_bert_base/bert_config.json",
                                     "./user_data/model_param/pretrained_model_param/mc_bert/pytorch_model.bin")