深入剖析 Wenet 下 AISHELL-1 的 run.sh 脚本:从数据准备到模型训练

深入剖析 Wenet 下 AISHELL-1 的 run.sh 脚本:从数据准备到模型训练

在本教程中,我们将详细剖析 Wenet 框架下 AISHELL-1 数据集的 run.sh 脚本。这份脚本涵盖了数据准备、模型训练、解码等多个步骤,帮助你理解每个参数和变量的用途,确保你能够顺利地运行和修改脚本以符合你的需求。

脚本概述

run.sh 脚本用于处理 AISHELL-1 数据集的整个训练和测试流程。我们将逐步分析脚本中的每个部分,包括环境变量、参数设置、各个阶段的执行等。


1. 脚本头部

#!/bin/bash

# Copyright 2019 Mobvoi Inc. All Rights Reserved.
. ./path.sh || exit 1;
  • #!/bin/bash: 指定该脚本使用 Bash 作为解释器。
  • # 开头的行是注释,解释内容。
  • . ./path.sh || exit 1;: 加载 path.sh 文件,设置必要的环境变量和路径配置,若加载失败则退出脚本。

2. 自动检测 GPU

if command -v nvidia-smi &> /dev/null; then
  num_gpus=$(nvidia-smi -L | wc -l)
  gpu_list=$(seq -s, 0 $((num_gpus-1)))
else
  num_gpus=-1
  gpu_list="-1"
fi
  • 检查系统是否安装 NVIDIA 的 GPU 驱动 (nvidia-smi)。
  • num_gpus: 存储可用 GPU 的数量。
  • gpu_list: 存储 GPU 的编号(如 “0,1,2”),用于后续训练时指定使用的 GPU。

3. CUDA 设备配置

export CUDA_VISIBLE_DEVICES="0,2,4"
echo "CUDA_VISIBLE_DEVICES is ${CUDA_VISIBLE_DEVICES}"
  • CUDA_VISIBLE_DEVICES: 手动指定要使用的 GPU 设备,例子中指定使用第 0、2、4 号 GPU。
  • 打印当前的 CUDA 可见设备,以便用户确认。

4. 定义训练阶段和参数

stage=0 # start from 0 if you need to start from data preparation
stop_stage=5
  • stage: 设置脚本的起始阶段(如数据准备)。
  • stop_stage: 设置脚本的结束阶段。

5. 其他训练参数设置

变量解释

HOST_NODE_ADDR="localhost:0"
num_nodes=1
job_id=2023

data=~/project/data/aishell
data_url=www.openslr.org/resources/33

nj=16
dict=data/dict/lang_char.txt

data_type=raw
num_utts_per_shard=1000

train_set=train
train_config=conf/train_conformer.yaml
dir=exp/conformer
tensorboard_dir=tensorboard
checkpoint=
num_workers=8
prefetch=10

average_checkpoint=true
decode_checkpoint=$dir/final.pt
average_num=9
decode_modes="ctc_prefix_beam_search attention attention_rescoring"
train_engine=torch_ddp

deepspeed_config=conf/ds_stage2.json
deepspeed_save_states="model_only"
  • HOST_NODE_ADDR: 主机节点的地址,通常为 localhost:0
  • num_nodes: 训练的节点数,默认为 1(即单节点训练)。
  • job_id: 任务 ID,用于标识不同的训练任务。

数据参数

  • data: 指向 AISHELL 数据集的路径。
  • data_url: 数据集下载链接。

训练参数

  • nj: 设置并行处理的工作线程数,通常为 16。
  • dict: 字典文件的路径,用于存储字符与索引的映射。
  • data_type: 数据类型(rawshard)。
  • num_utts_per_shard: 每个数据分片的语音样本数量(如 1000)。

训练和输出配置

  • train_set: 训练集名称,通常为 train
  • train_config: 模型训练的配置文件路径。
  • dir: 输出模型和实验结果的目录。
  • tensorboard_dir: 用于存储 TensorBoard 日志的目录。
  • checkpoint: 可选的检查点路径,用于恢复训练。
  • num_workers: 数据加载时的工作线程数(如 8)。
  • prefetch: 预读取的样本数量。

模型相关参数

  • average_checkpoint: 是否启用模型平均,默认为 true
  • decode_checkpoint: 解码模型的检查点路径。
  • average_num: 模型平均时使用的模型数量(如 9)。
  • decode_modes: 解码模式列表。
  • train_engine: 指定使用的训练引擎(如 torch_ddp)。
  • deepspeed_config: DeepSpeed 配置文件路径。
  • deepspeed_save_states: DeepSpeed 保存的状态设置。

6. 解析命令行参数

. tools/parse_options.sh || exit 1;
  • 加载 parse_options.sh 脚本,解析命令行传入的参数,以更灵活地控制脚本行为。

7. 各个阶段的执行

7.1 阶段 -1: 数据下载

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "stage -1: Data Download"
  local/download_and_untar.sh ${data} ${data_url} data_aishell
  local/download_and_untar.sh ${data} ${data_url} resource_aishell
fi
  • 在阶段 -1,下载 AISHELL 数据集及其资源。

7.2 阶段 0: 数据准备

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  local/aishell_data_prep.sh ${data}/data_aishell/wav ${data}/data_aishell/transcript
fi
  • 调用 local/aishell_data_prep.sh 脚本进行数据准备,将音频和转录文本整理成指定格式。

7.3 阶段 1: 处理文本标签

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  for x in train dev test; do
    cp data/${x}/text data/${x}/text.org
    paste -d " " <(cut -f 1 -d" " data/${x}/text.org) \
      <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
      > data/${x}/text
    rm data/${x}/text.org
  done

  tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \
    --in_scp data/${train_set}/wav.scp \
    --out_cmvn data/$train_set/global_cmvn
fi
  • 处理文本标签,去掉空格,生成新的文本文件。
  • 计算 CMVN(均值和方差归一化)统计信息。

7.4 阶段 2: 创建字典

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Make a dictionary"
  mkdir -p $(dirname $dict)
  echo "<blank> 0" > ${dict}  # 0 is for "blank" in CTC
  echo "<unk> 1"  >> ${dict}  # <unk> must be 1
  echo "<sos/eos> 2" >> $dict
  tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \
    | tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \
    awk '{print $0 " " NR+2}' >> ${dict}
fi
  • 创建字典文件,定义 <blank><unk><sos/eos>
  • 使用 text2token.py 工具生成字典中的其他字符。

7.5 阶段 3: 准备数据格式

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Prepare data, prepare required format"
  for x in dev test ${train_set}; do
    if [ $data_type == "shard" ]; then
      tools/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \
        --num_threads 16 data/$x/wav.scp data/$x/text \
        $(realpath data/$x/shards) data/$x/data.list
    else
      tools/make_raw_list.py data/$x/wav.scp data/$x/text \
        data/$x/data.list
    fi
  done
fi
  • 生成数据列表,准备训练所需的格式。根据 data_type 的设置选择处理方式。

7.6 阶段 4: 训练模型

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  mkdir -p $dir
  num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
  dist_backend="nccl"

  echo "$0: num_nodes is $num_nodes, proc_per_node is $num_gpus"
  torchrun --nnodes=$num_nodes --nproc_per_node=$num_gpus \
           --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint=$HOST_NODE_ADDR \
    wenet/bin/train.py \
      --train_engine ${train_engine} \
      --config $train_config \
      --data_type  $data_type \
      --train_data data/$train_set/data.list \
      --cv_data data/dev/data.list \
      ${checkpoint:+--checkpoint $checkpoint} \
      --model_dir $dir \
      --tensorboard_dir ${tensorboard_dir} \
      --ddp.dist_backend $dist_backend \
      --num_workers ${num_workers} \
      --prefetch ${prefetch} \
      --pin_memory \
      --deepspeed_config ${deepspeed_config} \
      --deepspeed.save_states ${deepspeed_save_states}
fi
  • 创建输出目录并设置训练参数。
  • 使用 torchrun 启动训练,传递必要的参数。

7.7 阶段 5: 模型测试

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  if [ ${average_checkpoint} == true ]; then
    decode_checkpoint=$dir/avg_${average_num}.pt
    echo "do model average and final checkpoint is $decode_checkpoint"
    python wenet/bin/average_model.py \
      --dst_model $decode_checkpoint \
      --src_path $dir  \
      --num ${average_num} \
      --val_best
  fi
  decoding_chunk_size=
  ctc_weight=0.3
  reverse_weight=0.5
  python wenet/bin/recognize.py --gpu 0 \
    --modes $decode_modes \
    --config $dir/train.yaml \
    --data_type $data_type \
    --test_data data/test/data.list \
    --checkpoint $decode_checkpoint \
    --beam_size 10 \
    --batch_size 32 \
    --blank_penalty 0.0 \
    --ctc_weight $ctc_weight \
    --reverse_weight $reverse_weight \
    --result_dir $dir \
    ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
  for mode in ${decode_modes}; do
    python tools/compute-wer.py --char=1 --v=1 \
      data/test/text $dir/$mode/text > $dir/$mode/wer
  done
fi
  • 如果启用模型平均,执行模型平均操作,生成平均模型的检查点。
  • 使用 wenet/bin/recognize.py 进行模型解码测试并计算字错误率(WER)。

7.8 阶段 6: 导出模型

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  python wenet/bin/export_jit.py \
    --config $dir/train.yaml \
    --checkpoint $dir/avg_${average_num}.pt \
    --output_file $dir/final.zip \
    --output_quant_file $dir/final_quant.zip
fi
  • 导出训练好的模型,以便于后续使用和推理。

7.9 阶段 7: 语言模型准备(可选)

if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
  # 7.1 Prepare dict
  unit_file=$dict
  mkdir -p data/local/dict
  cp $unit_file data/local/dict/units.txt
  tools/fst/prepare_dict.py $unit_file ${data}/resource_aishell/lexicon.txt \
    data/local/dict/lexicon.txt
  # 7.2 Train lm
  lm=data/local/lm
  mkdir -p $lm
  tools/filter_scp.pl data/train/text \
    $data/data_aishell/transcript/aishell_transcript_v0.8.txt > $lm/text
  local/aishell_train_lms.sh
  # 7.3 Build decoding TLG
  tools/fst/compile_lexicon_token_fst.sh \
    data/local/dict data/local/tmp data/local/lang
  tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
  # 7.4 Decoding with runtime
  chunk_size=-1
  ./tools/decode.sh --nj 16 \
    --beam 15.0 --lattice_beam 7.5 --max_active 7000 \
    --blank_skip_thresh 0.98 --ctc_weight 0.5 --rescoring_weight 1.0 \
    --chunk_size $chunk_size \
    --fst_path data/lang_test/TLG.fst \
    --dict_path data/lang_test/words.txt \
    data/test/wav.scp data/test/text $dir/final.zip \
    data/lang_test/units.txt $dir/lm_with_runtime
fi
  • 准备语言模型,构建解码时所需的 TLG(转移图)。

7.10 阶段 8: 使用 k2 HLG 解码(可选)

if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
  if [ ! -f data/local/lm/lm.arpa ]; then
    echo "Please run prepare dict and train lm in Stage 7" || exit 1;
  fi

  # 8.1 Build decoding HLG
  required="data/local/hlg/HLG.pt data/local/hlg/words.txt"
  for f in $required; do
    if [ ! -f $f ]; then
      tools/k2/make_hlg.sh data/local/dict/ data/local/lm/ data/local/hlg
      break
    fi
  done

  # 8.2 Decode using HLG
  decoding_chunk_size=
  lm_scale=0.7
  decoder_scale=0.1
  r_decoder_scale=0.7
  decode_modes="hlg_onebest hlg_rescore"
  python wenet/bin/recognize.py --gpu 0 \
    --modes $decode_modes \
    --config $dir/train.yaml \
    --data_type $data_type \
    --test_data data/test/data.list \
    --checkpoint $decode_checkpoint \
    --beam_size 10 \
    --batch_size 16 \
    --blank_penalty 0.0 \
    --dict $dict \
    --word data/local/hlg/words.txt \
    --hlg data/local/hlg/HLG.pt \
    --lm_scale $lm_scale \
    --decoder_scale $decoder_scale \
    --r_decoder_scale $r_decoder_scale \
    --result_dir $dir \
    ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
  for mode in ${decode_modes}; do
    python tools/compute-wer.py --char=1 --v=1 \
      data/test/text $dir/$mode/text > $dir/$mode/wer
  done
fi
  • 构建 HLG(隐状态图)的解码图,并执行解码操作。

结语

通过对 Wenet 下 AISHELL-1 的 run.sh 脚本的详细剖析,我们希望你能更好地理解每个步骤的功能及其背后的原理。此脚本不仅为模型训练提供了便利,也为用户自定义和扩展训练流程提供了基础。希望本教程能帮助你顺利完成语音识别项目!

  • 21
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

帅小柏

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值