摘要:
本文尝试将用中文拼音预训练的Fairseq的wav2vec2模型转换为transformers模型(以下简写trms),因为汉语拼音的label数量与英文不同,所以本文需要进行模型转换函数的修改。
自己预训练和finetune的模型没有稳定输出,但是应该是label转换的问题
本文可能对“复现党”有一定的参考价值
文章目录
1.分析transofrmers模型的结构
huggingface下载的模型默认保存在~/.cache/huggingface下面,如果需要离线使用,则需要将其保存到一个常见可见的目录,方便手动管理。
在模型目录下一般包括如下的文件:
- config.json 模型配置文件,项目配置文件
- vocab.json 编解码器的字典文件,json格式,字典:key是label,值是id
- pytorch_model.bin trms转换后的pytorch模型文件
- special_tokens_map.json 编解码器的特数据字符
- tokenizer_config.json 编解码器的配置文件
使用fairseq.checkpoint_utils.load_model_ensemble_and_task([fname])尝试导入pytorch_model.bin,报错,分析是从huggingface下载的模型是没有fairseq的task/args/cfg等信息。
with open(local_path, "rb") as f:
state = torch.load(f, map_location=torch.device("cpu"))
#分析类型state是<class 'collections.OrderedDict'>
Tips: 这里有一个方便的下载各个模型的小工具,下载模型到具体目录保存。
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
import argparse
import os
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="facebook/wav2vec2-base-100h", help="pretrained model name")
parser.add_argument("--save_dir", type=str, default="openai-gpt", help="pretrained model name")
args = parser.parse_args()
print(args)
#save model
save_dir = os.path.expanduser(args.save_dir)
# load model and tokenizer
tokenizer = Wav2Vec2Tokenizer.from_pretrained(args.model_name)
tokenizer.save_pretrained(save_dir)
model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
model.save_pretrained(save_dir)
if __name__=="__main__":
"""
$prj=~/Documents/projects/transformers/facebook/wav2vec2-base-100h
python ~/Documents/workspace/fairseq2trms/downloadModel.py --model_name="facebook/wav2vec2-base-100h" --save_dir=$prj
"""
main()
2.使用transformers的工具进行导入
**说明:**为什么不直接Fairseq,而是要用transformers呢?
- fairseq存在比较严重的过渡封装问题,接口复杂,omgaconf传参工具不容易迁移,不适合作生产环境部署
- fairseq做评估和ASR应用需要flashlight,由于防火墙的存在,基本上是无法按照官方教程安装的(vcpkg和编译都不容易)
- trms的接口比较直接明确,工具链比较简单
2.1 导入工具参数说明
trms本身提供了从fairseq导入wav2vec2模型的工具:transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
使用如下脚本可以从自己训练的模型转换trms用的模型:
python -m transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch --pytorch_dump_folder_path "~/Documents/projects/transformers/bostenai/960h-zh_CN" --checkpoint_path "~/Documents/projects/Fairseq/ModelSerial/CTC-softlink-slr18_bst-0310/ctc_d1/outputs/checkpoints2/checkpoint_best.pt" --dict_path "~/Documents/projects/transformers/bostenai/960h-zh_CN/dict.ltr.txt"
#具体参数含义如下
p