WeNet训练流程整理

关于WeNet的训练,官方已经给出了比较详细的步骤,不过实际在使用的时候,还是多少会踩几个坑,所以这里就结合官网推荐步骤以及自己的经验,来整理一下训练流程。

官网的训练步骤可参考:How to train models?

我用的是AIShell数据集,所以这里就以AIShell为例来说明。

example/aishell/s0/run.sh中给出了全部的步骤,理论上可以通过如下命令行一键执行:

bash run.sh --stage -1 --stop_stage 6

但由于数据集过大,并且中间一定会出错,因此建议大家还是一步一步分开执行。

如果你使用Linux操作系统,在开始一步步执行之前,如果出现格式错误,可以通过在vim中执行如下命令,将Windows上的文件格式转换成Linux格式:

:set fileformat=unix

第一步:下载数据

执行命令:

bash run.sh --stage -1 --stop_stage -1

这一步对应的是run.sh中的Stage -1,脚本内容如下:

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "stage -1: Data Download"
  local/download_and_untar.sh ${data} ${data_url} data_aishell
  local/download_and_untar.sh ${data} ${data_url} resource_aishell
fi

由于数据文件太大,这一步建议大家到网址Index of /resources/33手动下载下载并解压,然后在run.sh中根据数据存放位置,修改data变量,改成自己实际存放数据的文件目录即可。

第二步:准备训练数据

执行命令:

bash run.sh --stage 0 --stop_stage 0

这一步的脚本如下:

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  # Data preparation
  local/aishell_data_prep.sh ${data}/data_aishell/wav ${data}/data_aishell/transcript
fi

通过调用local/aishell_data_prep.sh,将原始aishell数据集划分成train、dev、test三个子集。每个子集包含两个文件:wav.scp和text。wav.scp内容如下,第一列为wav_id,第二列为该wav所在的文件路径。

BAC009S0002W0122 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
BAC009S0002W0124 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
BAC009S0002W0125 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0125.wav
...

text内容如下,第一列为wav_id,第二列为该wav对应的文字转录。

BAC009S0002W0259 并举 办 为期 三天 的 会议
BAC009S0002W0261 从 最 新机 器 人 到 物 联网 娱乐 可 穿戴 设备 等
BAC009S0002W0262 听取 部分 最 具 创新 思维 的人 士 提供 的 深刻 见解
BAC009S0002W0263 从而 推动 我们 全球 行业 的 发展
BAC009S0002W0264 大会 会议 具体 包括
......

以上文字转录中包含空格,会在下一步中对空格进行处理。 

第三步:去空格,计算cmvn

这一步执行命令:

bash run.sh --stage 1 --stop_stage 1

脚本内容如下:

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  # remove the space between the text labels for Mandarin dataset
  for x in train dev test; do
    cp data/${x}/text data/${x}/text.org
    paste -d " " <(cut -f 1 -d" " data/${x}/text.org) \
      <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
      > data/${x}/text
    rm data/${x}/text.org
  done

  ../../../tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \
    --in_scp data/${train_set}/wav.scp \
    --out_cmvn data/$train_set/global_cmvn
fi

该步骤实现两个功能,首先,对train、dev、test三个数据集中的text进行去空格操作,接着,执行../../../tools/compute_cmvn_stats.py文件,计算训练集的cmvn(Cepstral Mean and Variance Normalization)。计算完成后,在训练集文件夹下,会看到一个叫global_cmvn的文件。

第四步:生成label token字典

字典是中文字符向整数的映射。这一步的执行命令行为:

bash run.sh --stage 2 --stop_stage 2

脚本内容如下:

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Make a dictionary"
  mkdir -p $(dirname $dict)
  echo "<blank> 0" > ${dict}  # 0 is for "blank" in CTC
  echo "<unk> 1"  >> ${dict}  # <unk> must be 1
  echo "<sos/eos> 2" >> $dict
  ../../../tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \
    | tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \
    awk '{print $0 " " NR+2}' >> ${dict}
fi

大体原理为:从训练集text中抽取所有的中文汉字,经过排序、去重之后,给每个汉字进行整数编码,生成的最终字典命名为lang_char.txt,内容前几行如下:

在aishell的训练上,最终生成了4233个字符到数字的映射关系。其中,前三个字符固定不变,分别代表(CTC输出的)空字符<blank>、字典范围之外的字符<unk>以及语音开始及结束符号<sos/eos>,二者共享一个id。

第五步:准备WeNet数据格式

执行命令:

bash run.sh --stage 3 --stop_stage 3

对应的脚本内容如下:

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Prepare data, prepare required format"
  for x in dev test ${train_set}; do
    if [ $data_type == "shard" ]; then
      ../../../tools/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \
        --num_threads 16 data/$x/wav.scp data/$x/text \
        $(realpath data/$x/shards) data/$x/data.list
    else
      ../../../tools/make_raw_list.py data/$x/wav.scp data/$x/text \
        data/$x/data.list
    fi
  done
fi

这一步用来生成WeNet需要的数据格式data.list,该文件以json格式存储,包含三个字段:

  • key: 语音文件的key;
  • wav:key对应的语音文件的路径;
  • txt:key对应的语音文件的文字转录。

生成的data.list内容示例如下:

{"key": "BAC009S0002W0122", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0122.wav", "txt": "而对楼市成交抑制作用最大的限购"}
{"key": "BAC009S0002W0123", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0123.wav", "txt": "也成为地方政府的眼中钉"}
{"key": "BAC009S0002W0124", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0124.wav", "txt": "自六月底呼和浩特市率先宣布取消限购后"}

注:aishell默认使用的数据格式为raw,对于1万小时以上的超大数据集,可以使用shard格式,此处略过,有兴趣的同学可以去官网查看具体用法。 

第六步:模型训练

这一步的命令行为:

bash run.sh --stage 4 --stop_stage 4

对应的脚本内容:

train_config=conf/train_conformer.yaml
dir=exp/conformer
tensorboard_dir=tensorboard
checkpoint=
num_workers=8
prefetch=10

# ......

train_engine=torch_ddp

deepspeed_config=conf/ds_stage2.json
deepspeed_save_states="model_only"

# ......

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  mkdir -p $dir
  num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
  # Use "nccl" if it works, otherwise use "gloo"
  # NOTE(xcsong): deepspeed fails with gloo, see
  #   https://github.com/microsoft/DeepSpeed/issues/2818
  dist_backend="nccl"

  # train.py rewrite $train_config to $dir/train.yaml with model input
  # and output dimension, and $dir/train.yaml will be used for inference
  # and export.
  if [ ${train_engine} == "deepspeed" ]; then
    echo "$0: using deepspeed"
  else
    echo "$0: using torch ddp"
  fi

  # NOTE(xcsong): Both ddp & deepspeed can be launched by torchrun
  # NOTE(xcsong): To unify single-node & multi-node training, we add
  #               all related args. You should change `nnodes` &
  #               `rdzv_endpoint` for multi-node, see
  #               https://pytorch.org/docs/stable/elastic/run.html#usage
  #               https://github.com/wenet-e2e/wenet/pull/2055#issuecomment-1766055406
  #               `rdzv_id` - A user-defined id that uniquely identifies the worker group for a job.
  #                           This id is used by each node to join as a member of a particular worker group.
  #               `rdzv_endpoint` - The rendezvous backend endpoint; usually in form <host>:<port>.
  # NOTE(xcsong): In multi-node training, some clusters require special NCCL variables to set prior to training.
  #               For example: `NCCL_IB_DISABLE=1` + `NCCL_SOCKET_IFNAME=enp` + `NCCL_DEBUG=INFO`
  #               without NCCL_IB_DISABLE=1
  #                   RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1269, internal error, NCCL Version xxx
  #               without NCCL_SOCKET_IFNAME=enp  (IFNAME could be get by `ifconfig`)
  #                   RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:xxx
  #               ref: https://github.com/google/jax/issues/13559#issuecomment-1343573764
  echo "$0: num_nodes is $num_nodes, proc_per_node is $num_gpus"
  torchrun --nnodes=$num_nodes --nproc_per_node=$num_gpus \
           --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint=$HOST_NODE_ADDR \
    ../../../wenet/bin/train.py \
      --train_engine ${train_engine} \
      --config $train_config \
      --data_type  $data_type \
      --train_data data/$train_set/data.list \
      --cv_data data/dev/data.list \
      ${checkpoint:+--checkpoint $checkpoint} \
      --model_dir $dir \
      --tensorboard_dir ${tensorboard_dir} \
      --ddp.dist_backend $dist_backend \
      --num_workers ${num_workers} \
      --prefetch ${prefetch} \
      --pin_memory \
      --deepspeed_config ${deepspeed_config} \
      --deepspeed.save_states ${deepspeed_save_states}
fi

torch_ddp作为train_engine可以直接训练,deepspeed需要有一个文件ds_stage2.json,由于WeNet源代码中缺少这个文件,因此我就使用默认的torch_ddp来训的。我的训练机器有两张GPU卡,训练时全部用上。

以上脚本默认每次从epoch0开始训练,如果想要从某个已有的checkpoint开始训练,可以在以上脚本中指定checkpoint=exp/your_exp/$n.pt,这样在训练时就可以从$n+1.pt开始生成checkpoint了。

如果想要修改神经网络结构、优化参数、损失参数以及数据集,可以通过更改训练配置文件来实现,比如我使用的时conf/train_conformer.yaml文件,可通过修改该文件来实现训练模型相关参数的更改。

第七步:模型验证

训练完成之后,生成一系列的pt文件。该步完成两个操作:一是对在交叉验证中最好的${average_num}个模型进行average;二是使用average的输出模型对test数据集进行推理验证。

命令行如下:

bash run.sh --stage 5 --stop_stage 5

对应的脚本内容如下:


# use average_checkpoint will get better result
average_checkpoint=true
decode_checkpoint=$dir/final.pt
average_num=30
decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring"

# ......

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  # Test model, please specify the model you want to test by --checkpoint
  if [ ${average_checkpoint} == true ]; then
    decode_checkpoint=$dir/avg_${average_num}.pt
    echo "do model average and final checkpoint is $decode_checkpoint"
    python ../../../wenet/bin/average_model.py \
      --dst_model $decode_checkpoint \
      --src_path $dir  \
      --num ${average_num} \
      --val_best
  fi
  # Please specify decoding_chunk_size for unified streaming and
  # non-streaming model. The default value is -1, which is full chunk
  # for non-streaming inference.
  decoding_chunk_size=
  ctc_weight=0.3
  reverse_weight=0.5
  python ../../../wenet/bin/recognize.py --gpu 0 \
    --modes $decode_modes \
    --config $dir/train.yaml \
    --data_type $data_type \
    --test_data data/test/data.list \
    --checkpoint $decode_checkpoint \
    --beam_size 10 \
    --batch_size 32 \
    --blank_penalty 0.0 \
    --ctc_weight $ctc_weight \
    --reverse_weight $reverse_weight \
    --result_dir $dir \
    ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
  for mode in ${decode_modes}; do
    python ../../../tools/compute-wer.py --char=1 --v=1 \
      data/test/text $dir/$mode/text > $dir/$mode/wer
  done
fi

其中,模型的解码模式可通过decode_modes变量设置,默认四种模式都开启:ctc_greedy_search,ctc_prefix_beam_search,attention,attention_rescoring。

推理完成后,可到对应的解码模式命名的文件夹下去查看模型的wer,例如,我的attention_rescoring的评估结果在exp/conformer/attention_rescoring下。

第八步:模型导出

这一步是通过Libtorch导出训练好的模型,以方便产品化的应用,导出的模型可用于其他变成语言,比如C++。

该步的命令行:

bash run.sh --stage 6 --stop_stage 6

对应的脚本内容:

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  # Export the best model you want
  python ../../../wenet/bin/export_jit.py \
    --config $dir/train.yaml \
    --checkpoint $dir/avg_${average_num}.pt \
    --output_file $dir/final.zip \
    --output_quant_file $dir/final_quant.zip
fi

该步最终输出final.zip和final_quant.zip两个模型文件。

好了,对于普通的模型开发工作,以上几个步骤已经足够了。对于想要做更多实验的同学,可以对照论文,以及run.sh内容做更多尝试,我在这里就不多赘述了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值