Fairseq的wav2vec2的踩坑之旅4:如何手动将一个Fairseq的wav2vec2模型转换为transformers的模型

摘要:

本文尝试将用中文拼音预训练的Fairseq的wav2vec2模型转换为transformers模型(以下简写trms),因为汉语拼音的label数量与英文不同,所以本文需要进行模型转换函数的修改。

自己预训练和finetune的模型没有稳定输出,但是应该是label转换的问题
本文可能对“复现党”有一定的参考价值

1.分析transofrmers模型的结构

huggingface下载的模型默认保存在~/.cache/huggingface下面,如果需要离线使用,则需要将其保存到一个常见可见的目录,方便手动管理。

在模型目录下一般包括如下的文件:

  • config.json 模型配置文件,项目配置文件
  • vocab.json 编解码器的字典文件,json格式,字典:key是label,值是id
  • pytorch_model.bin trms转换后的pytorch模型文件
  • special_tokens_map.json 编解码器的特数据字符
  • tokenizer_config.json 编解码器的配置文件

使用fairseq.checkpoint_utils.load_model_ensemble_and_task([fname])尝试导入pytorch_model.bin,报错,分析是从huggingface下载的模型是没有fairseq的task/args/cfg等信息。

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))
        
    #分析类型state是<class 'collections.OrderedDict'>

Tips: 这里有一个方便的下载各个模型的小工具,下载模型到具体目录保存。

from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
import argparse
import os

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="facebook/wav2vec2-base-100h", help="pretrained model name")
    parser.add_argument("--save_dir", type=str, default="openai-gpt", help="pretrained model name")
    args = parser.parse_args()
    print(args)    
    #save model
    save_dir = os.path.expanduser(args.save_dir)    
    # load model and tokenizer
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(args.model_name) 
    tokenizer.save_pretrained(save_dir)
    model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
    model.save_pretrained(save_dir)
    
if __name__=="__main__":
    """
    $prj=~/Documents/projects/transformers/facebook/wav2vec2-base-100h
    python ~/Documents/workspace/fairseq2trms/downloadModel.py --model_name="facebook/wav2vec2-base-100h"  --save_dir=$prj
    """
    main()

2.使用transformers的工具进行导入

**说明:**为什么不直接Fairseq,而是要用transformers呢?

  1. fairseq存在比较严重的过渡封装问题,接口复杂,omgaconf传参工具不容易迁移,不适合作生产环境部署
  2. fairseq做评估和ASR应用需要flashlight,由于防火墙的存在,基本上是无法按照官方教程安装的(vcpkg和编译都不容易)
  3. trms的接口比较直接明确,工具链比较简单

2.1 导入工具参数说明

trms本身提供了从fairseq导入wav2vec2模型的工具:transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py

使用如下脚本可以从自己训练的模型转换t

  • 6
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
### 回答1: 以下是基于Hugging Face Transformers库的代码示例,使用wav2vec2模型提取音频特征: ```python import torch import torchaudio from transformers import Wav2Vec2Processor, Wav2Vec2Model # 加载音频文件 audio_file, sr = torchaudio.load("audio_file.wav") # 调整采样率 if sr != 16000: resampler = torchaudio.transforms.Resample(sr, 16000) audio_file = resampler(audio_file) sr = 16000 # 初始化Wav2Vec2模型和处理器 processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # 提取特征 input_values = processor(audio_file, sampling_rate=sr, return_tensors="pt").input_values with torch.no_grad(): features = model(input_values).last_hidden_state ``` 上述代码将加载音频文件,并使用`torchaudio`库将采样率调整为16000。然后,使用Hugging Face Transformers库中的`Wav2Vec2Processor`和`Wav2Vec2Model`类来初始化模型和处理器。最后,使用处理器对音频文件进行编码,并将编码后的张量输入到模型中,以提取音频特征。 ### 回答2: 以下是使用transformers库的wav2vec2模型提取音频特征的代码示例: ```python import torch from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2Model # 读取音频 waveform, sample_rate = torchaudio.load("your_audio_file.wav") # 将音频转换模型接受的输入特征 processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=sample_rate, padding_value=0.0, do_normalize=True) inputs = feature_extractor(waveform, sampling_rate=sample_rate, return_tensors="pt") # 使用wav2vec2模型提取音频特征 model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") outputs = model(inputs.input_values, attention_mask=inputs.attention_mask) # 输出音频特征 features = outputs.last_hidden_state print(features) ``` 在上述代码中,你需要将"your_audio_file.wav"替换为你要读取的音频文件的路径。这段代码将读取该音频文件并将其采样率设置为16000。 然后,将使用`Wav2Vec2FeatureExtractor`将音频转换模型的输入特征。在这里,我们使用的是wav2vec2模型的960小时预训练模型,提供了默认的处理器(processor)和特征提取器(feature_extractor)。 最后,使用`Wav2Vec2Model`加载预训练的wav2vec2模型,并将输入特征传递给模型。输出中的`features`将包含提取的音频特征。 请注意,以上代码需要依赖`transformers`和`torch`库,你可以在运行代码之前使用以下命令进行安装: ``` pip install transformers torch torchaudio ``` ### 回答3: 要使用transformerswav2vec2模型提取音频特征,可以按照以下步骤进行: 1. 安装所需的库和模型:首先需要安装transformers库和torch库,然后下载wav2vec2模型。 ```python !pip install torch !pip install transformers from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") ``` 2. 读取音频文件:将音频文件加载到内存中。 ```python import soundfile as sf audio, _ = sf.read('audio.wav') ``` 3. 预处理音频:对音频进行预处理,包括重采样和归一化。 ```python import torch resampler = torch.nn.Upsample(1, 16000, 16000, mode='linear') audio_resampled = resampler(torch.tensor(audio)).numpy() audio_normalized = audio_resampled / np.max(np.abs(audio_resampled)) ``` 4. 特征提取:使用wav2vec2模型提取音频特征。 ```python input_values = tokenizer(audio_normalized, return_tensors='pt').input_values logits = model(input_values).logits ``` 5. 处理输出结果:根据需要处理特征提取的输出结果。 ```python predicted_ids = torch.argmax(logits, dim=-1) transcription = tokenizer.batch_decode(predicted_ids)[0] ``` 以上就是使用transformerswav2vec2模型提取音频特征的代码。使用这些代码,你可以读取一个采样率为16000的音频,并提取音频特征。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值