Fairseq的wav2vec2的踩坑之旅4:如何手动将一个Fairseq的wav2vec2模型转换为transformers的模型

本文介绍了如何将Fairseq的wav2vec2模型转换为Transformers模型,详细讨论了转换过程中遇到的问题,如路径处理、配置文件、Omegaconf错误以及模型维度不一致等,并提供了转换和测试模型的步骤。最后,文章提到了在汉语拼音finetune上的尝试。
摘要由CSDN通过智能技术生成

摘要:

本文尝试将用中文拼音预训练的Fairseq的wav2vec2模型转换为transformers模型(以下简写trms),因为汉语拼音的label数量与英文不同,所以本文需要进行模型转换函数的修改。

自己预训练和finetune的模型没有稳定输出,但是应该是label转换的问题
本文可能对“复现党”有一定的参考价值

1.分析transofrmers模型的结构

huggingface下载的模型默认保存在~/.cache/huggingface下面,如果需要离线使用,则需要将其保存到一个常见可见的目录,方便手动管理。

在模型目录下一般包括如下的文件:

  • config.json 模型配置文件,项目配置文件
  • vocab.json 编解码器的字典文件,json格式,字典:key是label,值是id
  • pytorch_model.bin trms转换后的pytorch模型文件
  • special_tokens_map.json 编解码器的特数据字符
  • tokenizer_config.json 编解码器的配置文件

使用fairseq.checkpoint_utils.load_model_ensemble_and_task([fname])尝试导入pytorch_model.bin,报错,分析是从huggingface下载的模型是没有fairseq的task/args/cfg等信息。

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))
        
    #分析类型state是<class 'collections.OrderedDict'>

Tips: 这里有一个方便的下载各个模型的小工具,下载模型到具体目录保存。

from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
import argparse
import os

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="facebook/wav2vec2-base-100h", help="pretrained model name")
    parser.add_argument("--save_dir", type=str, default="openai-gpt", help="pretrained model name")
    args = parser.parse_args()
    print(args)    
    #save model
    save_dir = os.path.expanduser(args.save_dir)    
    # load model and tokenizer
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(args.model_name) 
    tokenizer.save_pretrained(save_dir)
    model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
    model.save_pretrained(save_dir)
    
if __name__=="__main__":
    """
    $prj=~/Documents/projects/transformers/facebook/wav2vec2-base-100h
    python ~/Documents/workspace/fairseq2trms/downloadModel.py --model_name="facebook/wav2vec2-base-100h"  --save_dir=$prj
    """
    main()

2.使用transformers的工具进行导入

**说明:**为什么不直接Fairseq,而是要用transformers呢?

  1. fairseq存在比较严重的过渡封装问题,接口复杂,omgaconf传参工具不容易迁移,不适合作生产环境部署
  2. fairseq做评估和ASR应用需要flashlight,由于防火墙的存在,基本上是无法按照官方教程安装的(vcpkg和编译都不容易)
  3. trms的接口比较直接明确,工具链比较简单

2.1 导入工具参数说明

trms本身提供了从fairseq导入wav2vec2模型的工具:transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py

使用如下脚本可以从自己训练的模型转换trms用的模型:

python -m  transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch  --pytorch_dump_folder_path "~/Documents/projects/transformers/bostenai/960h-zh_CN" --checkpoint_path "~/Documents/projects/Fairseq/ModelSerial/CTC-softlink-slr18_bst-0310/ctc_d1/outputs/checkpoints2/checkpoint_best.pt" --dict_path "~/Documents/projects/transformers/bostenai/960h-zh_CN/dict.ltr.txt" 
#具体参数含义如下
p
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值