以BERT
为代表的预训练模型是目前NLP领域最火热的方向,但是Google发布的 BERT
是Tensorflow
格式的,这让使用pytorch
格式 程序猿
们很为难。
为解决这个问题,本篇以BERT
为例,介绍将Tensorflow
格式的模型转换为Pytorch
格式的模型。
1. 工具安装
使用工具为:Transformers
(链接),该工具对常用的预训练模型进行封装,可以非常方便的使用 pytorch
调用预训练模型。
使用如下命令安装:
pip install transformers
2. 模型转换
- 下载google的
BERT
模型; - 使用如下命令进行转换:
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
transformers bert \
$BERT_BASE_DIR/bert_model.ckpt \
$BERT_BASE_DIR/bert_config.json \
$BERT_BASE_DIR/pytorch_model.bin