数据集下载
下载地址
并解压
命令:unzip AISHELL-1_sample.zip
1.准备 wav.scp text
#!/bin/bash
. ./path.sh
# Root of the AISHELL-1_sample dataset.
sample_data=/home/asr/data/wenet/examples/aishell/s0/datasets/AISHELL-1_sample
# Output directory for the generated Kaldi-style files.
data=/home/asr/data/wenet/examples/aishell/s0/data_
if [ ! -d "$data" ]; then
  mkdir -p "$data"
fi
# Start from a clean slate so a rerun does not append duplicate lines.
rm -f "$data/wav.scp" "$data/text"
# 1. Build wav.scp ("<utt-id> <wav-path>") and text ("<utt-id> <transcript>").
# Expected layout: $sample_data/<spk>/<spk>_mic/{<utt>.wav,<utt>.txt} -- TODO confirm.
for spk_dir in "$sample_data"/*; do
  wav_txt_dir="$spk_dir/$(basename "$spk_dir")_mic"
  echo "$wav_txt_dir"
  for wav_file in "$wav_txt_dir"/*.wav; do
    [ -e "$wav_file" ] || continue   # glob matched nothing; skip
    utt=$(basename "$wav_file" .wav)
    # wav.scp entry: utterance key + wav path.
    echo "$utt $wav_txt_dir/$utt.wav" >> "$data/wav.scp"
    # text entry: utterance key + transcript read from the sibling .txt file.
    txt=$(cat "$wav_txt_dir/$utt.txt")
    echo "$utt $txt" >> "$data/text"
  done
done
echo "wav.scp and text done!"
2.准备data.list
使用同时读取两个文件,生成data.list
# 2. Build data.list: one JSON object per utterance, joining wav.scp and
# text line-by-line (both files were generated in the same order).
exec 3<"$data/wav.scp"
exec 4<"$data/text"
rm -f "$data/data.list"
# `read key rest` splits off the first field and keeps the remainder of the
# line intact, so multi-word transcripts are not truncated.
while read -r key wav <&3 && read -r tkey txt <&4; do
  # NOTE(review): assumes wav.scp and text are aligned, i.e. $tkey == $key.
  echo "{\"key\":\"${key}\",\"wav\":\"${wav}\",\"txt\":\"${txt}\"}" >> "$data/data.list"
done
exec 3<&- 4<&-
echo "data.list done!"
3.准备dict
使用python脚本生成dict
python get_lang_char.py > $data/lang_char.txt
get_lang_char.py
import os

# Path to the "<utt-id> <transcript>" file produced by the data-prep step.
TEXT_PATH = "./data_/text"


def collect_chars(lines):
    """Return the set of characters appearing in the transcripts.

    Each line is expected to look like "<utt-id> <transcript>". The
    utterance id (first space-separated field) is skipped and every
    non-whitespace character of the FULL transcript is collected
    (splitting with maxsplit=1 keeps words after the first space).
    Lines with no transcript are ignored instead of raising IndexError.
    """
    chars = set()
    for line in lines:
        parts = line.split(" ", 1)
        if len(parts) < 2:
            continue  # malformed / empty-transcript line
        for ch in parts[1].strip("\n"):
            if not ch.isspace():
                chars.add(ch)
    return chars


def main():
    """Print a WeNet dict: <blank> 0, <unk> 1, chars, then <sos/eos>."""
    with open(TEXT_PATH, "r", encoding="utf-8") as rfile:
        chars = collect_chars(rfile.readlines())
    print("<blank> 0")
    print("<unk> 1")
    next_id = 2
    # Sort for a deterministic dictionary across runs (set order is not stable).
    for ch in sorted(chars):
        print(ch, next_id)
        next_id += 1
    print("<sos/eos>", next_id)


if __name__ == "__main__":
    main()
4.计算CMVN
config=conf/train_u2++_conformer.yaml
4.cmvn
# 4. Compute global CMVN statistics over the training wavs.
tools/compute_cmvn_stats.py \
  --num_workers 8 \
  --train_config "$config" \
  --in_scp data_/wav.scp \
  --out_cmvn data_/global_cmvn
5.训练
# 5. Train: raw-format data.list is used for both train and cv sets here.
python3 wenet/bin/train.py \
  --config "$config" \
  --data_type raw \
  --symbol_table data_/lang_char.txt \
  --cmvn data_/global_cmvn \
  --train_data data_/data.list \
  --cv_data data_/data.list \
  --model_dir data_/model \
  --num_workers 8 \
  --pin_memory
6.合并模型
# 6.合并模型
# 6. Average the 30 best checkpoints (by validation loss) into one model.
python wenet/bin/average_model.py \
  --src_path data_/model \
  --dst_model data_/average.pt \
  --num 30 \
  --val_best
7.测试模型
# 7. Decode the test set with attention rescoring.
# Fixed: the flag was misspelled "--chechpoint", which recognize.py rejects.
python3 wenet/bin/recognize.py \
  --mode "attention_rescoring" \
  --config data_/model/train.yaml \
  --data_type raw \
  --test_data data_/data.list \
  --checkpoint data_/model/final.pt \
  --beam_size 10 \
  --batch_size 1 \
  --penalty 0.0 \
  --dict data_/lang_char.txt \
  --ctc_weight 1.0 \
  --reverse_weight 0 \
  --result_file data_/result.txt
# Score: character error rate of the hypotheses against the reference text.
python tools/compute-wer.py --char=1 --v=1 data_/text data_/result.txt > data_/wer.txt
采用更小的数据集
如果自己只是想简单地跑通流程,可以缩小数据集
# Keep only the first 100 utterances of each file to get a tiny dataset
# for a quick end-to-end dry run.
head -n 100 data.list > data.list.100
head -n 100 text > text.100
head -n 100 wav.scp > wav.scp.100
查看训练过程
# Visualize training curves; --bind_all makes the UI reachable from other hosts.
tensorboard --logdir tensorboard --port 12598 --bind_all
代码汇总
#!/bin/bash
. ./path.sh
# Root of the AISHELL-1_sample dataset.
sample_data=/home/asr/data/wenet/examples/aishell/s0/datasets/AISHELL-1_sample
# Output directory for the generated files.
data=/home/asr/data/wenet/examples/aishell/s0/data_
if [ ! -d "$data" ]; then
  mkdir -p "$data"
fi
# Clean previous outputs so a rerun does not append duplicate lines.
rm -f "$data/wav.scp" "$data/text"
# 1. Build wav.scp ("<utt-id> <wav-path>") and text ("<utt-id> <transcript>").
# Expected layout: $sample_data/<spk>/<spk>_mic/{<utt>.wav,<utt>.txt} -- TODO confirm.
for spk_dir in "$sample_data"/*; do
  wav_txt_dir="$spk_dir/$(basename "$spk_dir")_mic"
  echo "$wav_txt_dir"
  for wav_file in "$wav_txt_dir"/*.wav; do
    [ -e "$wav_file" ] || continue   # glob matched nothing; skip
    utt=$(basename "$wav_file" .wav)
    # wav.scp entry: utterance key + wav path.
    echo "$utt $wav_txt_dir/$utt.wav" >> "$data/wav.scp"
    # text entry: utterance key + transcript from the sibling .txt file.
    txt=$(cat "$wav_txt_dir/$utt.txt")
    echo "$utt $txt" >> "$data/text"
  done
done
echo "wav.scp and text done!"
# 2. Build data.list: one JSON object per utterance, joining wav.scp and
# text line-by-line (both files were generated in the same order).
exec 3<"$data/wav.scp"
exec 4<"$data/text"
rm -f "$data/data.list"
# `read key rest` keeps the remainder of the line intact, so multi-word
# transcripts are not truncated at the first space.
while read -r key wav <&3 && read -r tkey txt <&4; do
  # NOTE(review): assumes wav.scp and text are aligned, i.e. $tkey == $key.
  echo "{\"key\":\"${key}\",\"wav\":\"${wav}\",\"txt\":\"${txt}\"}" >> "$data/data.list"
done
exec 3<&- 4<&-
echo "data.list done!"
# 3. Generate the character dictionary from the prepared text file.
python get_lang_char.py > "$data/lang_char.txt"
config=conf/train_u2++_conformer.yaml
# 4. Compute global CMVN statistics over the training wavs.
tools/compute_cmvn_stats.py \
  --num_workers 8 \
  --train_config "$config" \
  --in_scp data_/wav.scp \
  --out_cmvn data_/global_cmvn
# 5. Train the model; the same tiny data.list serves as train and cv set.
python3 wenet/bin/train.py \
  --config "$config" \
  --data_type raw \
  --symbol_table data_/lang_char.txt \
  --cmvn data_/global_cmvn \
  --train_data data_/data.list \
  --cv_data data_/data.list \
  --model_dir data_/model \
  --num_workers 8 \
  --pin_memory
# To resume training, rerun with --checkpoint pointing at a saved epoch:
# python3 wenet/bin/train.py \
#   --config $config \
#   --data_type raw \
#   --symbol_table data_/lang_char.txt \
#   --train_data data_/data.list \
#   --cv_data data_/data.list \
#   --checkpoint data_/model/271.pt \
#   --model_dir data_/model \
#   --num_workers 8 \
#   --cmvn data_/global_cmvn \
#   --pin_memory
# 6. Average the 30 best checkpoints (by validation loss) into one model.
python wenet/bin/average_model.py \
  --src_path data_/model \
  --dst_model data_/average.pt \
  --num 30 \
  --val_best
# 7. Decode the test set with attention rescoring, then score it.
python3 wenet/bin/recognize.py \
  --mode "attention_rescoring" \
  --config data_/model/train.yaml \
  --data_type raw \
  --test_data data_/data.list.1 \
  --checkpoint data_/model/final.pt \
  --beam_size 10 \
  --batch_size 1 \
  --penalty 0.0 \
  --dict data_/lang_char.txt \
  --ctc_weight 1.0 \
  --reverse_weight 0 \
  --result_file data_/result.txt
# Character error rate against the matching reference text.
# NOTE(review): data.list.1 / text.1 are pre-made subsets (cf. the head -n
# trick for .100 files) -- confirm the suffix actually used.
python tools/compute-wer.py --char=1 --v=1 \
  data_/text.1 data_/result.txt > data_/wer.txt
# 8. Export the averaged model as TorchScript, plus a quantized variant.
python wenet/bin/export_jit.py \
  --config data_/model/train.yaml \
  --checkpoint data_/average.pt \
  --output_file data_/final.zip \
  --output_quant_file data_/final_quant.zip