背景
BERT是google出的, 理所当然使用tf框架, 但是目前很多项目是使用pytorch框架的, 所以就需要在两个框架之间转换bert模型.
方法
pytorch to tf
主要使用huggingface的转换脚本.但是有几个地方需要修改:
修改包导入:
from transformers import BertModel为from modeling_bert import BertModel
修改L102, load模型的参数:
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir,
)
为:
model = BertModel.from_pretrained(
state_dict=torch.load(args.pytorch_model_path)
)
最后运行脚本:
python convert_bert_pytorch_checkpoint_to_original_tf.py --model_name --pytorch_model_path --tf_cache_dir
其中model_name随便指定一个即可, 没有影响, 不过需要在当前目录下新建model_name目录, 然后把pytorch模型对应的config.json放到该目录下.其他两个参数就是对应的模型了, 没什么好解释的.
tf to pytorch
也是使用huggingface的转换脚本.
如果是bert base