【深度学习】Tensorflow转Pytorch(BERT配置文件)

0.前言

        写在...终于从实习公司离职回来自己搞毕设的研二下,这段时间觉得很自由,但是又很怀念在上海的日子,那段让人难熬但又让人怀念的日子。

1.下载Tensorflow版Bert

        Google-research在Github上开源了很多不同形状大小的Bert预训练模型,按需下载。

        这里我们以uncased_L-2_H-128_A-2为例,下载好后,你就会得到这个模型相应的配置文件

        在tensorflow的框架中,一般是如此加载该预训练模型。

from keras_bert import load_trained_model_from_checkpoint

config_path = './bert-base-uncased/bert_config.json'
checkpoint_path = './bert-base-uncased/bert_model.ckpt'
dict_path = './uncased_L-2_H-128_A-2/vocab.txt'

bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)

        但,如果我想用transformer的库对其进行加载就无法利用已经下载好的模型参数,hugging face上又没有那么多版本大小的Bert(至少我没找到,我只看到了一些base),所以,需要对Tensorflow版的模型参数进行转化成Pytorch版的,然后基于transformer库对其进行加载。 

   2. Tensorflow to Pytorch

        这里主要是对config文件和checkpoint文件进行转化保存成bin类型的文件。

import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

# 设置 TensorFlow 模型文件路径
tf_checkpoint_path = "./uncased_L-2_H-256_A-4/bert_model.ckpt"

# 设置 PyTorch 模型配置
config = BertConfig.from_json_file("./uncased_L-2_H-256_A-4/bert_config.json")
model = BertForPreTraining(config)

# 加载 TensorFlow 模型参数
load_tf_weights_in_bert(model, config, tf_checkpoint_path)

# 保存 PyTorch 模型
torch.save(model.state_dict(), "./uncased_L-2_H-256_A-4/pytorch_model.bin")

         最后,你就会看见你的模型文件夹里多了一个bin文件。

        ok!大功告成! 

  • 8
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值