将electra的tf版本转化为pytorch版本
将electra的tf版本转化为pytorch版本
最近参加CAIL比赛需要使用electra,因为不想直接使用最简单的
三行快速加载代码加载模型
。
// 快速加载,通过hugging face的transformers
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-legal-electra-base-discriminator")
model = AutoModel.from_pretrained("hfl/chinese-legal-electra-base-discriminator")
本次尝试使用的方法是下载模型,并在本地进行加载。
后转到electra的github网站寻找方法。(https://github.com/ymcui/Chinese-ELECTRA)
其中刚好发现了有能够和最近参加的比赛相吻合的司法领域版本electra,决定使用这个
但是只有tensoflow的版本,后来决定使用transformers库的convert_electra_original_tf_checkpoint_to_pytorch.py文件来进行tf转换到pytorch
转换方法
在github将tf版本的下载下来后,再从github的config文件夹下下载对应得config文件
convert_electra_original_tf_checkpoint_to_pytorch.py文件76行代码如下:
args = parser.parse_args()
将其修改为如下:
args = parser.parse_args(['--tf_checkpoint_path', './legal_electra_base/legal_electra_base.ckpt',
'--config_file', './legal_electra_base/base_discriminator_config.json',
'--pytorch_dump_path', './output/model.bin',
'--discriminator_or_generator', 'discriminator'])
我的文件结构设置是这样的,读者可相应的做更改
最容易出错得是–tf_checkpoint_path,这个地方我一开始写得是
./legal_electra_base/
然后就一直报错
后来百度了一下接找不到原因,倒是遇到一个老哥转换bert有类似得情况。观察文件后发现有很多文件都包含ckpt
于是尝试将–tf_checkpoint_path换成
./legal_electra_base/legal_electra_base.ckpt
重新运行后转换成功
分析
为什么转化的时候使用的是legal_electra_base.ckpt,查阅资料之后将electra得5个文件的作用做一个说明:
$ tree legal_electra_base/
legal_electra_base/
├── base_discriminator_config.json <- 模型配置文件
├── legal_electra_base.ckpt.data-00000-of-00001 <- 保存断点文件列表,可以用来迅速查找最近一次的断点文件
├── legal_electra_base.ckpt.index <- 为数据文件提供索引,存储的核心内容是以tensor name为键以BundleEntry为值的表格entries,BundleEntry主要内容是权值的类型、形状、偏移、校验和等信息。
├── legal_electra_base.ckpt.ckpt.meta <- 是MetaGraphDef序列化的二进制文件,保存了网络结构相关的数据,包括graph_def和saver_def等
└── vocab.txt <- 模型词汇表文件
参考文章
https://blog.csdn.net/sunyueqinghit/article/details/103458365/