中文ALBERT:TF 转成 Pytorch

1、从下面的链接下载需要的中文ALBERT版本:https://github.com/google-research/ALBERT

下载解压后,文件夹的文件包含:

albert_config.json                                          
checkpoint                                 
model.ckpt-best.data-00000-of-00001
model.ckpt-best.index
model.ckpt-best.meta
vocab_chinese.txt

2、转换方法参考:https://blog.csdn.net/roger_royer/article/details/107144599

上面链接中给出了转换用的脚本文件,注意这个脚本文件需要运行在同时安装了TensorFlow和Pytorch的环境中:

(我的mac上直接 pip3 install tensorflow 和 conda install pytorch 即可安装这两个环境,安装的pytorch版本需要跟自己的生产环境的版本一直,目前遇到转换的时候使用pytorch1.6,生产环境是pytorch1.2,在生产环境中加载转换后的模型参数的时候报如下错误的情况:

OSError: Unable to load weights from pytorch checkpoint file for '/home/.../albert_base_chinese' at '/home/.../albert_base_chinese/pytorch_model.bin'If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.

解决办法就是保证运行下面转换脚本的pytorch版本和生产环境版本一致。)

"""Convert ALBERT checkpoint."""

import argparse
import logging

import torch

from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert

logging.basicConfig(level=logging.INFO)

def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = AlbertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained ALBERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)

'''
python3 convert_albert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=./model.ckpt-best --albert_config_file=./albert_config.json --pytorch_dump_path=./pytorch_model.bin
'''

3、运行脚本进行转换:

不需要修改ALBERT解压出来的文件名,将2中的脚本放到ALBERT解压出来的目录中,运行命令:

python3 convert_albert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=./model.ckpt-best --albert_config_file=./albert_config.json --pytorch_dump_path=./pytorch_model.bin

运行上述命令,即可正常进行转换。

注意:

上述运行命令的参数 --tf_checkpoint_path=./model.ckpt-best,其不是目录名,也不是文件名,也不需要对ALBERT解压出来的文件进行修改名称,保持这样的设置即可进行正常的转换。

4、如果出现以下错误:

tensorflow.python.framework.errors_impl.DataLossError: Unable to open table file /Users/wxs/AI/NLP/wordvector/fGat_bert/data/albert_base/model.ckpt-best.data-00000-of-00001: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

原因很可能是 参数 --tf_checkpoint_path=./model.ckpt-best 写错了。

5、另外参考:

http://blog.sina.com.cn/s/blog_53dd83fd0102x6ja.html

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值