关于WeNet的训练,官方已经给出了比较详细的步骤,不过实际在使用的时候,还是多少会踩几个坑,所以这里就结合官网推荐步骤以及自己的经验,来整理一下训练流程。
官网的训练步骤可参考:How to train models?
我用的是AIShell数据集,所以这里就以AIShell为例来说明。
example/aishell/s0/run.sh中给出了全部的步骤,理论上可以通过如下命令行一键执行:
bash run.sh --stage -1 --stop_stage 6
但由于数据集过大,并且中间一定会出错,因此建议大家还是一步一步分开执行。
如果你使用Linux操作系统,在开始一步步执行之前,如果出现格式错误,可以通过在vim中执行如下命令,将Windows上的文件格式转换成Linux格式:
:set fileformat=unix
第一步:下载数据
执行命令:
bash run.sh --stage -1 --stop_stage -1
这一步对应的是run.sh中的Stage -1,脚本内容如下:
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "stage -1: Data Download"
local/download_and_untar.sh ${data} ${data_url} data_aishell
local/download_and_untar.sh ${data} ${data_url} resource_aishell
fi
由于数据文件太大,这一步建议大家到网址Index of /resources/33手动下载下载并解压,然后在run.sh中根据数据存放位置,修改data变量,改成自己实际存放数据的文件目录即可。
第二步:准备训练数据
执行命令:
bash run.sh --stage 0 --stop_stage 0
这一步的脚本如下:
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# Data preparation
local/aishell_data_prep.sh ${data}/data_aishell/wav ${data}/data_aishell/transcript
fi
通过调用local/aishell_data_prep.sh,将原始aishell数据集划分成train、dev、test三个子集。每个子集包含两个文件:wav.scp和text。wav.scp内容如下,第一列为wav_id,第二列为该wav所在的文件路径。
BAC009S0002W0122 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
BAC009S0002W0124 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
BAC009S0002W0125 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0125.wav
...
text内容如下,第一列为wav_id,第二列为该wav对应的文字转录。
BAC009S0002W0259 并举 办 为期 三天 的 会议
BAC009S0002W0261 从 最 新机 器 人 到 物 联网 娱乐 可 穿戴 设备 等
BAC009S0002W0262 听取 部分 最 具 创新 思维 的人 士 提供 的 深刻 见解
BAC009S0002W0263 从而 推动 我们 全球 行业 的 发展
BAC009S0002W0264 大会 会议 具体 包括
......
以上文字转录中包含空格,会在下一步中对空格进行处理。
第三步:去空格,计算cmvn
这一步执行命令:
bash run.sh --stage 1 --stop_stage 1
脚本内容如下:
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# remove the space between the text labels for Mandarin dataset
for x in train dev test; do
cp data/${x}/text data/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) \
<(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
> data/${x}/text
rm data/${x}/text.org
done
../../../tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \
--in_scp data/${train_set}/wav.scp \
--out_cmvn data/$train_set/global_cmvn
fi
该步骤实现两个功能,首先,对train、dev、test三个数据集中的text进行去空格操作,接着,执行../../../tools/compute_cmvn_stats.py文件,计算训练集的cmvn(Cepstral Mean and Variance Normalization)。计算完成后,在训练集文件夹下,会看到一个叫global_cmvn的文件。
第四步:生成label token字典
字典是中文字符向整数的映射。这一步的执行命令行为:
bash run.sh --stage 2 --stop_stage 2
脚本内容如下:
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Make a dictionary"
mkdir -p $(dirname $dict)
echo "<blank> 0" > ${dict} # 0 is for "blank" in CTC
echo "<unk> 1" >> ${dict} # <unk> must be 1
echo "<sos/eos> 2" >> $dict
../../../tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \
| tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \
awk '{print $0 " " NR+2}' >> ${dict}
fi
大体原理为:从训练集text中抽取所有的中文汉字,经过排序、去重之后,给每个汉字进行整数编码,生成的最终字典命名为lang_char.txt,内容前几行如下:
在aishell的训练上,最终生成了4233个字符到数字的映射关系。其中,前三个字符固定不变,分别代表(CTC输出的)空字符<blank>、字典范围之外的字符<unk>以及语音开始及结束符号<sos/eos>,二者共享一个id。
第五步:准备WeNet数据格式
执行命令:
bash run.sh --stage 3 --stop_stage 3
对应的脚本内容如下:
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare data, prepare required format"
for x in dev test ${train_set}; do
if [ $data_type == "shard" ]; then
../../../tools/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \
--num_threads 16 data/$x/wav.scp data/$x/text \
$(realpath data/$x/shards) data/$x/data.list
else
../../../tools/make_raw_list.py data/$x/wav.scp data/$x/text \
data/$x/data.list
fi
done
fi
这一步用来生成WeNet需要的数据格式data.list,该文件以json格式存储,包含三个字段:
- key: 语音文件的key;
- wav:key对应的语音文件的路径;
- txt:key对应的语音文件的文字转录。
生成的data.list内容示例如下:
{"key": "BAC009S0002W0122", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0122.wav", "txt": "而对楼市成交抑制作用最大的限购"}
{"key": "BAC009S0002W0123", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0123.wav", "txt": "也成为地方政府的眼中钉"}
{"key": "BAC009S0002W0124", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0124.wav", "txt": "自六月底呼和浩特市率先宣布取消限购后"}
注:aishell默认使用的数据格式为raw,对于1万小时以上的超大数据集,可以使用shard格式,此处略过,有兴趣的同学可以去官网查看具体用法。
第六步:模型训练
这一步的命令行为:
bash run.sh --stage 4 --stop_stage 4
对应的脚本内容:
train_config=conf/train_conformer.yaml
dir=exp/conformer
tensorboard_dir=tensorboard
checkpoint=
num_workers=8
prefetch=10
# ......
train_engine=torch_ddp
deepspeed_config=conf/ds_stage2.json
deepspeed_save_states="model_only"
# ......
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
mkdir -p $dir
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
# Use "nccl" if it works, otherwise use "gloo"
# NOTE(xcsong): deepspeed fails with gloo, see
# https://github.com/microsoft/DeepSpeed/issues/2818
dist_backend="nccl"
# train.py rewrite $train_config to $dir/train.yaml with model input
# and output dimension, and $dir/train.yaml will be used for inference
# and export.
if [ ${train_engine} == "deepspeed" ]; then
echo "$0: using deepspeed"
else
echo "$0: using torch ddp"
fi
# NOTE(xcsong): Both ddp & deepspeed can be launched by torchrun
# NOTE(xcsong): To unify single-node & multi-node training, we add
# all related args. You should change `nnodes` &
# `rdzv_endpoint` for multi-node, see
# https://pytorch.org/docs/stable/elastic/run.html#usage
# https://github.com/wenet-e2e/wenet/pull/2055#issuecomment-1766055406
# `rdzv_id` - A user-defined id that uniquely identifies the worker group for a job.
# This id is used by each node to join as a member of a particular worker group.
# `rdzv_endpoint` - The rendezvous backend endpoint; usually in form <host>:<port>.
# NOTE(xcsong): In multi-node training, some clusters require special NCCL variables to set prior to training.
# For example: `NCCL_IB_DISABLE=1` + `NCCL_SOCKET_IFNAME=enp` + `NCCL_DEBUG=INFO`
# without NCCL_IB_DISABLE=1
# RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1269, internal error, NCCL Version xxx
# without NCCL_SOCKET_IFNAME=enp (IFNAME could be get by `ifconfig`)
# RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:xxx
# ref: https://github.com/google/jax/issues/13559#issuecomment-1343573764
echo "$0: num_nodes is $num_nodes, proc_per_node is $num_gpus"
torchrun --nnodes=$num_nodes --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint=$HOST_NODE_ADDR \
../../../wenet/bin/train.py \
--train_engine ${train_engine} \
--config $train_config \
--data_type $data_type \
--train_data data/$train_set/data.list \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--tensorboard_dir ${tensorboard_dir} \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states}
fi
torch_ddp作为train_engine可以直接训练,deepspeed需要有一个文件ds_stage2.json,由于WeNet源代码中缺少这个文件,因此我就使用默认的torch_ddp来训的。我的训练机器有两张GPU卡,训练时全部用上。
以上脚本默认每次从epoch0开始训练,如果想要从某个已有的checkpoint开始训练,可以在以上脚本中指定checkpoint=exp/your_exp/$n.pt,这样在训练时就可以从$n+1.pt开始生成checkpoint了。
如果想要修改神经网络结构、优化参数、损失参数以及数据集,可以通过更改训练配置文件来实现,比如我使用的时conf/train_conformer.yaml文件,可通过修改该文件来实现训练模型相关参数的更改。
第七步:模型验证
训练完成之后,生成一系列的pt文件。该步完成两个操作:一是对在交叉验证中最好的${average_num}个模型进行average;二是使用average的输出模型对test数据集进行推理验证。
命令行如下:
bash run.sh --stage 5 --stop_stage 5
对应的脚本内容如下:
# use average_checkpoint will get better result
average_checkpoint=true
decode_checkpoint=$dir/final.pt
average_num=30
decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring"
# ......
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# Test model, please specify the model you want to test by --checkpoint
if [ ${average_checkpoint} == true ]; then
decode_checkpoint=$dir/avg_${average_num}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python ../../../wenet/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path $dir \
--num ${average_num} \
--val_best
fi
# Please specify decoding_chunk_size for unified streaming and
# non-streaming model. The default value is -1, which is full chunk
# for non-streaming inference.
decoding_chunk_size=
ctc_weight=0.3
reverse_weight=0.5
python ../../../wenet/bin/recognize.py --gpu 0 \
--modes $decode_modes \
--config $dir/train.yaml \
--data_type $data_type \
--test_data data/test/data.list \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size 32 \
--blank_penalty 0.0 \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_dir $dir \
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
for mode in ${decode_modes}; do
python ../../../tools/compute-wer.py --char=1 --v=1 \
data/test/text $dir/$mode/text > $dir/$mode/wer
done
fi
其中,模型的解码模式可通过decode_modes变量设置,默认四种模式都开启:ctc_greedy_search,ctc_prefix_beam_search,attention,attention_rescoring。
推理完成后,可到对应的解码模式命名的文件夹下去查看模型的wer,例如,我的attention_rescoring的评估结果在exp/conformer/attention_rescoring下。
第八步:模型导出
这一步是通过Libtorch导出训练好的模型,以方便产品化的应用,导出的模型可用于其他变成语言,比如C++。
该步的命令行:
bash run.sh --stage 6 --stop_stage 6
对应的脚本内容:
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# Export the best model you want
python ../../../wenet/bin/export_jit.py \
--config $dir/train.yaml \
--checkpoint $dir/avg_${average_num}.pt \
--output_file $dir/final.zip \
--output_quant_file $dir/final_quant.zip
fi
该步最终输出final.zip和final_quant.zip两个模型文件。
好了,对于普通的模型开发工作,以上几个步骤已经足够了。对于想要做更多实验的同学,可以对照论文,以及run.sh内容做更多尝试,我在这里就不多赘述了。