0.前言
写在...终于从实习公司离职回来自己搞毕设的研二下,这段时间觉得很自由,但是又很怀念在上海的日子,那段让人难熬但又让人怀念的日子。
1.下载Tensorflow版Bert
Google-research在Github上开源了很多不同形状大小的Bert预训练模型,按需下载。
这里我们以uncased_L-2_H-128_A-2为例,下载好后,你就会得到这个模型相应的配置文件
在tensorflow的框架中,一般是如此加载该预训练模型。
from keras_bert import load_trained_model_from_checkpoint
config_path = './bert-base-uncased/bert_config.json'
checkpoint_path = './bert-base-uncased/bert_model.ckpt'
dict_path = './uncased_L-2_H-128_A-2/vocab.txt'
bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
但,如果我想用transformer的库对其进行加载就无法利用已经下载好的模型参数,hugging face上又没有那么多版本大小的Bert(至少我没找到,我只看到了一些base),所以,需要对Tensorflow版的模型参数进行转化成Pytorch版的,然后基于transformer库对其进行加载。
2. Tensorflow to Pytorch
这里主要是对config文件和checkpoint文件进行转化保存成bin类型的文件。
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
# 设置 TensorFlow 模型文件路径
tf_checkpoint_path = "./uncased_L-2_H-256_A-4/bert_model.ckpt"
# 设置 PyTorch 模型配置
config = BertConfig.from_json_file("./uncased_L-2_H-256_A-4/bert_config.json")
model = BertForPreTraining(config)
# 加载 TensorFlow 模型参数
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
# 保存 PyTorch 模型
torch.save(model.state_dict(), "./uncased_L-2_H-256_A-4/pytorch_model.bin")
最后,你就会看见你的模型文件夹里多了一个bin文件。
ok!大功告成!