Distil-Whisper模型训练全流程指南

Distil-Whisper模型训练全流程指南

distil-whisper Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate. distil-whisper 项目地址: https://gitcode.com/gh_mirrors/di/distil-whisper

前言

Distil-Whisper是基于OpenAI Whisper模型的蒸馏版本,通过知识蒸馏技术将大型Whisper模型压缩为更小、更高效的版本,同时保持较高的语音识别准确率。本文将详细介绍使用PyTorch框架训练Distil-Whisper模型的全流程,包括环境准备、数据预处理、模型初始化、训练和评估等关键步骤。

环境准备

硬件要求

训练Distil-Whisper模型建议使用配备高性能GPU的机器,如NVIDIA A100或V100显卡。显存容量至少需要16GB,推荐使用24GB或以上的显卡以获得更好的训练效率。

软件依赖

训练过程需要以下主要软件包:

  • PyTorch:深度学习框架
  • Transformers:提供Whisper模型实现
  • Datasets:数据处理库
  • Accelerate:分布式训练工具

安装命令如下:

pip install torch transformers datasets accelerate

环境验证

安装完成后,可以通过以下代码验证环境是否配置正确:

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio

# 加载模型和处理器
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

# 加载示例音频数据
common_voice = load_dataset("mozilla-foundation/common_voice_16_1", "en", split="validation", streaming=True)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

# 执行推理
inputs = processor(next(iter(common_voice))["audio"]["array"], sampling_rate=16000, return_tensors="pt")
generated_ids = model.generate(inputs.input_features)
pred_text = processor.decode(generated_ids[0], skip_special_tokens=True)

print("预测结果:", pred_text)

训练流程概述

完整的Distil-Whisper训练包含四个主要阶段:

  1. 伪标签生成:使用大型Whisper模型为训练数据生成转录文本
  2. 学生模型初始化:从教师模型中提取关键层构建学生模型
  3. 模型训练:执行知识蒸馏训练
  4. 性能评估:评估蒸馏后模型的识别准确率

下面将详细介绍每个阶段的具体操作。

1. 伪标签生成

伪标签生成是知识蒸馏的第一步,使用大型Whisper模型为训练数据集生成转录文本。这些转录将作为学生模型训练的目标。

关键参数说明

  • model_name_or_path: 指定使用的Whisper模型版本
  • dataset_name: 使用的语音数据集
  • output_dir: 伪标签输出目录
  • concatenate_audio: 是否将音频拼接为30秒片段(推荐启用)
  • language: 指定转录语言(对非英语模型很重要)

示例命令

以下是为印地语(hi)Common Voice数据集生成伪标签的示例:

accelerate launch run_pseudo_labelling.py \
  --model_name_or_path "openai/whisper-large-v3" \
  --dataset_name "mozilla-foundation/common_voice_16_1" \
  --dataset_config_name "hi" \
  --output_dir "./common_voice_16_1_hi_pseudo_labelled" \
  --per_device_eval_batch_size 64 \
  --dtype "bfloat16" \
  --language "hi" \
  --task "transcribe" \
  --concatenate_audio \
  --preprocessing_batch_size 500

音频拼接技术

启用concatenate_audio参数会将短音频片段拼接为接近30秒的长片段,这带来两个主要优势:

  1. 减少训练时的padding,提高计算效率
  2. 使模型更好地学习长序列依赖关系

preprocessing_batch_size控制拼接操作的批大小,较大的值可以提高拼接效率但会增加内存消耗。

2. 学生模型初始化

学生模型是从教师模型中提取关键层构建的较小模型。初始化策略对最终模型性能有重要影响。

层选择策略

Distil-Whisper采用最大化间距策略选择学生模型的层:

  • 编码器:默认保留全部32层(可调整)
  • 解码器:从教师模型的32层中选择间距最大的层

例如,当指定2层解码器时,会选择第1层和第32层。

初始化示例

python create_student_model.py \
  --teacher_checkpoint "openai/whisper-large-v3" \
  --encoder_layers 32 \
  --decoder_layers 2 \
  --save_dir "./distil-large-v3-init"

语言迁移技巧

可以通过指定已蒸馏的多语言模型作为教师模型来利用语言迁移:

--teacher_checkpoint "distil-whisper/distil-large-v3"

3. 模型训练

训练阶段使用伪标签数据和教师模型对学生模型进行知识蒸馏。

损失函数

训练使用复合损失函数:

  • 交叉熵损失:学生输出与伪标签的差异
  • KL散度损失:学生与教师输出分布的差异

关键训练参数

  • learning_rate: 学习率(建议2e-5到5e-5)
  • warmup_steps: 学习率预热步数
  • max_steps: 最大训练步数
  • gradient_checkpointing: 梯度检查点(节省显存)

训练示例

accelerate launch run_distillation.py \
  --model_name_or_path "./distil-large-v3-init" \
  --teacher_model_name_or_path "openai/whisper-large-v3" \
  --train_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
  --train_split "train+validation" \
  --eval_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
  --eval_split "test" \
  --output_dir "./distil-large-v3-hi" \
  --per_device_train_batch_size 32 \
  --per_device_eval_batch_size 16 \
  --dtype "bfloat16" \
  --learning_rate 3e-5 \
  --lr_scheduler_type "linear" \
  --warmup_steps 100 \
  --max_steps 5000 \
  --gradient_checkpointing \
  --push_to_hub

多数据集训练

为提高模型鲁棒性,可以组合多个数据集进行训练:

--train_dataset_name "dataset1+dataset2" \
--train_split "train+train" \

4. 模型评估

训练完成后,需要评估模型在测试集上的表现,主要指标为词错误率(WER)。

评估脚本关键参数

  • model_name_or_path: 要评估的模型路径
  • dataset_name: 评估数据集
  • metric: 评估指标(通常为"wer")
  • language: 目标语言

评估示例

python eval_whisper.py \
  --model_name_or_path "./distil-large-v3-hi" \
  --dataset_name "mozilla-foundation/common_voice_16_1" \
  --dataset_config_name "hi" \
  --split "test" \
  --metric "wer" \
  --language "hi"

训练技巧与建议

  1. 数据量建议:至少使用1000小时数据以获得良好性能
  2. 多语言训练:组合多种语言数据可提高模型鲁棒性
  3. 超参数调优:尝试不同学习率和训练步数组合
  4. 混合精度训练:使用bfloat16或float16加速训练
  5. 梯度累积:在小批量情况下模拟大批量训练

常见问题解决

  1. 显存不足:减小批大小,启用梯度检查点
  2. 训练不稳定:降低学习率,增加预热步数
  3. 过拟合:增加数据量,使用早停策略
  4. 性能不理想:检查伪标签质量,调整模型结构

结语

本文详细介绍了Distil-Whisper模型的完整训练流程。通过合理配置各阶段参数,开发者可以训练出适用于特定语言和场景的高效语音识别模型。蒸馏技术能够在保持较高准确率的同时显著减小模型规模,使其更适合资源受限的应用场景。

distil-whisper Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate. distil-whisper 项目地址: https://gitcode.com/gh_mirrors/di/distil-whisper

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

### 使用 Whisper 语音识别模型的实现与应用 #### 安装依赖库 为了使用 `openai-whisper` 及其支持的各种模型,包括 Distil-Whisper,在开始之前需安装必要的 Python 库。通常情况下,这可以通过 pip 来完成。 ```bash pip install git+https://github.com/openai/whisper.git ``` #### 加载预训练模型 加载特定版本的模型非常简单。对于希望使用的 Distil-Whisper 模型而言,代码如下所示: ```python import whisper model = whisper.load_model("distil-medium") # 或者 "distil-large" ``` 此处 `"distil-medium"` 和 `"distil-large"` 是两种不同大小的 Distil-Whisper 模型名称[^1]。 #### 执行音频转文字任务 一旦选择了合适的模型并成功加载之后,就可以调用该模型来处理实际的任务——将输入的声音文件转换成相应的文本描述。下面是一个简单的例子说明如何做到这一点: ```python audio_file_path = "./example_audio.mp3" result = model.transcribe(audio_file_path) print(result["text"]) ``` 这段程序会读取指定路径下的 MP3 文件作为输入源,并输出由模型推测出来的对应的文字内容。 #### 处理多语言环境中的音频数据 值得注意的是,OpenAI 的 Whisper 不仅限于英语,还能够很好地适应其他多种自然语言。这意味着即使面对非英文发音的内容也能保持较高的准确性。如果想要让系统自动检测所给定录音片段的语言种类,则可以在调用 transcribe 方法时加入额外参数 language=None 即可。 ```python result_auto_lang_detect = model.transcribe(audio_file_path, language=None) detected_language_code = result_auto_lang_detect['language'] transcribed_text = result_auto_lang_detect["text"] print(f"Detected Language Code: {detected_language_code}") print(transcribed_text) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

袁菲李

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值