1. Download the Chinese ALBERT version you need from: https://github.com/google-research/ALBERT
After downloading and extracting it, the folder contains these files:
albert_config.json
checkpoint
model.ckpt-best.data-00000-of-00001
model.ckpt-best.index
model.ckpt-best.meta
vocab_chinese.txt
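Before converting, it can help to confirm that the extracted directory really contains the six files listed above. This is a minimal sketch. It assumes the file names are exactly as shown, and other ALBERT releases may name them differently. The function name missing_albert_files is hypothetical.

```python
import os

# File names as listed above for the Chinese ALBERT download (assumed; other releases may differ)
EXPECTED_FILES = [
    "albert_config.json",
    "checkpoint",
    "model.ckpt-best.data-00000-of-00001",
    "model.ckpt-best.index",
    "model.ckpt-best.meta",
    "vocab_chinese.txt",
]


def missing_albert_files(model_dir):
    """Return the expected files that are not present in model_dir (hypothetical helper)."""
    return [name for name in EXPECTED_FILES
            if not os.path.isfile(os.path.join(model_dir, name))]
```

For example, running print(missing_albert_files(".")) inside the extracted folder should print an empty list if nothing is missing.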
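2. For the conversion method, see: https://blog.csdn.net/roger_royer/article/details/107144599
That post provides the conversion script. Note that the script must run in an environment where both TensorFlow and PyTorch are installed.
(On my Mac, pip3 install tensorflow and conda install pytorch were enough to install both. The PyTorch version used for the conversion must match the one in your production environment. I converted with PyTorch 1.6 while production ran PyTorch 1.2, and loading the converted weights in production then failed with the following error: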
OSError: Unable to load weights from pytorch checkpoint file for '/home/.../albert_base_chinese' at '/home/.../albert_base_chinese/pytorch_model.bin'If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
The fix is to make sure the PyTorch version that runs the conversion script below is the same as the production version.)
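The fix is to keep the conversion and production PyTorch versions in line, so a quick check is to print torch.__version__ in both environments and compare the results. This is a minimal sketch. It treats matching major.minor versions (for example 1.2.x vs 1.2.y) as "the same". That threshold is my assumption and is not something the post established. Both helper names are hypothetical.

```python
def major_minor(version):
    """Turn a version string like '1.6.0+cpu' into the tuple (1, 6)."""
    release = version.split("+")[0]  # drop local build tags such as '+cpu' or '+cu92'
    major, minor = release.split(".")[:2]
    return int(major), int(minor)


def torch_versions_match(convert_version, production_version):
    """True if both environments share the same major.minor PyTorch release (assumed criterion)."""
    return major_minor(convert_version) == major_minor(production_version)
```

Get each version string with python3 -c "import torch; print(torch.__version__)" in each environment. For the case described above, torch_versions_match("1.6.0", "1.2.0") returns False.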
"""Convert ALBERT checkpoint."""
import argparse
import logging

import torch

from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert

logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = AlbertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained ALBERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
'''
python3 convert_albert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=./model.ckpt-best --albert_config_file=./albert_config.json --pytorch_dump_path=./pytorch_model.bin
'''
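3. Run the script to do the conversion:
There is no need to rename any of the files extracted from the ALBERT archive. Put the script from step 2 into the extracted directory and run: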
python3 convert_albert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path=./model.ckpt-best --albert_config_file=./albert_config.json --pytorch_dump_path=./pytorch_model.bin
This command performs the conversion.
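Because of the version issue from step 2, it may be worth checking the output before shipping it to an older production environment. Starting with PyTorch 1.6, torch.save writes a zip-based archive by default. Earlier versions write a non-zip format. The sketch below uses only the standard library. It assumes that a zip archive here means the file was saved in the newer format, which older releases such as PyTorch 1.2 cannot load. The helper name describe_torch_file is hypothetical.

```python
import os
import zipfile


def describe_torch_file(path):
    """Report whether a torch.save output looks like the zip-based (PyTorch >= 1.6 default) format."""
    if not os.path.isfile(path) or os.path.getsize(path) == 0:
        return "missing or empty"
    if zipfile.is_zipfile(path):
        return "zip-based format (PyTorch >= 1.6 default)"
    return "legacy (non-zip) format"
```

For example, call describe_torch_file("./pytorch_model.bin") after the conversion finishes. If it reports the zip-based format and production runs an older PyTorch, the fix used in this post is to match the versions. PyTorch 1.6's torch.save also accepts _use_new_zipfile_serialization=False to write the older format, which might work as an alternative, but I have not verified it for this 1.6 to 1.2 case.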
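Note:
The argument --tf_checkpoint_path=./model.ckpt-best is neither a directory name nor a file name. It is the prefix shared by the checkpoint files (model.ckpt-best.index, model.ckpt-best.data-00000-of-00001, model.ckpt-best.meta). You do not need to rename any of the extracted ALBERT files; with this setting the conversion runs normally.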
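TensorFlow names a checkpoint's files by appending suffixes (.index, .meta, .data-00000-of-00001) to a shared prefix, and --tf_checkpoint_path expects that prefix. Here is a small sketch that recovers the prefix from any of those file names. The helper name checkpoint_prefix is hypothetical, and the suffix pattern only covers the file types seen in this download.

```python
import re

# Suffixes TensorFlow appends to a checkpoint prefix (the patterns seen in the file list above)
_CKPT_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(filename):
    """Strip a checkpoint file suffix, e.g. 'model.ckpt-best.index' -> 'model.ckpt-best'."""
    return _CKPT_SUFFIX.sub("", filename)
```

For example, checkpoint_prefix("./model.ckpt-best.data-00000-of-00001") gives "./model.ckpt-best", which is exactly the value used in the command above.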
4. If you get the following error:
tensorflow.python.framework.errors_impl.DataLossError: Unable to open table file /Users/wxs/AI/NLP/wordvector/fGat_bert/data/albert_base/model.ckpt-best.data-00000-of-00001: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
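the most likely cause is a wrong --tf_checkpoint_path argument. It should be the prefix ./model.ckpt-best, not an individual file such as model.ckpt-best.data-00000-of-00001 (the file named in the error message).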
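To catch this mistake before TensorFlow reports it, you can check the argument first. It should not point at one of the individual checkpoint files, and <prefix>.index should exist next to it. The sketch below is a minimal pre-flight check based on those two assumptions. validate_tf_checkpoint_path is a hypothetical name and is not part of the conversion script.

```python
import os
import re

# Suffixes of the individual checkpoint files (patterns seen in this ALBERT download)
_CKPT_FILE = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def validate_tf_checkpoint_path(path):
    """Raise ValueError if path does not look like a usable TF checkpoint prefix."""
    if _CKPT_FILE.search(path):
        suggested = _CKPT_FILE.sub("", path)
        raise ValueError(f"{path} is a checkpoint file; pass the prefix {suggested} instead")
    if not os.path.isfile(path + ".index"):
        raise ValueError(f"{path}.index not found; check the prefix and the working directory")
    return path
```

One option is to call it on args.tf_checkpoint_path right after parser.parse_args() in the script from step 2, before convert_tf_checkpoint_to_pytorch runs.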
5. Other references: