Chinese Automatic Speech Recognition
https://download.csdn.net/download/qq_41854731/13108319
Models
CNN + CTC (cnn_ctc.py)
GRU + CTC (gru_ctc.py)
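For orientation, here is a minimal sketch of the GRU + CTC pattern in the Keras functional API. This is a generic illustration, not the repo's gru_ctc.py (which, judging by the 4-D padded input in test_am.py, also has a convolutional front end); vocab_size, feat_dim, and the layer widths are placeholders:

from keras import backend as K
from keras.layers import Input, Dense, GRU, Bidirectional, Lambda
from keras.models import Model

vocab_size = 1000  # placeholder: set from the training vocab
feat_dim = 80      # placeholder: acoustic feature dimension

inputs = Input(shape=(None, feat_dim))                # (time, features)
x = Bidirectional(GRU(128, return_sequences=True))(inputs)
x = Bidirectional(GRU(128, return_sequences=True))(x)
y_pred = Dense(vocab_size, activation='softmax')(x)   # per-frame label posteriors

# CTC loss is attached as a layer whose output Keras minimizes directly
labels = Input(shape=(None,))
input_length = Input(shape=(1,))
label_length = Input(shape=(1,))
loss = Lambda(lambda a: K.ctc_batch_cost(a[1], a[0], a[2], a[3]),
              output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])

model = Model(inputs, y_pred)       # inference model (cf. am.model in test_am.py)
ctc_model = Model([inputs, labels, input_length, label_length], loss)
ctc_model.compile(optimizer='adam', loss={'ctc': lambda y_true, out: out})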
Training environment
TensorFlow 1.x
Keras
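For example (the exact versions are illustrative; any TensorFlow 1.x release with a matching Keras 2.x should work):

pip install "tensorflow-gpu==1.15.*" "keras==2.3.1"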
Data
Organize the training data into the Kaldi format.
At a minimum, two files are required: text and wav.scp.
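Both follow the standard Kaldi conventions: each line starts with an utterance ID. In text the ID is followed by the transcript (for this acoustic model, typically space-separated pinyin tokens), and in wav.scp by the path to the audio file. The IDs and paths below are illustrative:

text:
BAC009S0002W0122 lv4 shi4 yang2 chun1 yan1 jing3

wav.scp:
BAC009S0002W0122 /data/aishell/wav/train/S0002/BAC009S0002W0122.wav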
Training
python train_am.py
import os
import tensorflow as tf
from utils_am import get_data, data_hparams, dic2args
from keras.callbacks import ModelCheckpoint
from keras.backend import tensorflow_backend
import yaml
import time
import argparse
# from cnn_ctc import Am, am_hparams  # swap in to train the CNN + CTC model
from gru_ctc import Am, am_hparams
# Let GPU memory grow on demand so the process does not run out of GPU memory
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tensorflow_backend.set_session(tf.Session(config=config))
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
    '-c', '--config', type=str, default='./config/conf_am.yaml')
parser.add_argument(
    '-s', '--save_path', type=str, default='save_models/'+time.strftime("%m-%d-%H-%M-%S")+"-logs")
cmd_args = parser.parse_args()
## Load the config that drives the whole training run
with open(cmd_args.config, 'r', encoding='utf-8') as f:
    parms_dict = yaml.load(f, Loader=yaml.FullLoader)

## Training data parameters
data_args = dic2args(parms_dict['data'])
train_data = get_data(data_args)
batch_num = train_data.batch_num
train_batch = train_data.get_am_batch()
## Prepare validation data if a dev set is configured
validation_data = None
validation_steps = None
if parms_dict['data']["dev_path"]:
    dev_args = dic2args(parms_dict['data'])
    dev_args.data_path = parms_dict['data']["dev_path"]
    dev_data = get_data(dev_args)
    dev_batch = dev_data.get_am_batch()
    validation_data = dev_batch
    validation_steps = 50
## Build the acoustic model
am_args = dic2args(parms_dict['model'])
am_args.vocab_size = len(train_data.am_vocab)
am = Am(am_args)

## Training parameters
epochs = parms_dict['train']['epochs']
# Record the actual save path back into the config before it is dumped below
save_path = parms_dict['train']['save_path'] = cmd_args.save_path
retrain_dir = parms_dict['train']['retrain_dir']
## Save the vocab and the config alongside the model
os.makedirs(save_path, exist_ok=True)
# Save the vocab
am_vocab = train_data.am_vocab
with open(os.path.join(save_path, "vocab"), "w", encoding='utf-8') as f:
    f.write("\n".join(am_vocab))
# Save the config, pointing it at the vocab file just written
parms_dict["data"]["am_vocab_file"] = os.path.join(save_path, "vocab")
with open(os.path.join(save_path, "config.yaml"), "w", encoding='utf-8') as f:
    yaml.dump(parms_dict, f)
## Optionally resume from a pretrained model
if retrain_dir:
    print('load acoustic model...')
    am.ctc_model.load_weights(retrain_dir)
## Checkpoint: save the model weights after every epoch
ckpt = "model_{epoch:02d}-{loss:.2f}.h5"
checkpoint = ModelCheckpoint(os.path.join(save_path, ckpt), monitor='val_loss', save_weights_only=True, verbose=1, save_best_only=False)
## Start training
am.ctc_model.fit_generator(
    train_batch,                 # generator yielding training batches
    steps_per_epoch=batch_num,   # batches per epoch
    epochs=epochs,               # number of epochs to train
    callbacks=[checkpoint],      # per-epoch weight checkpoints
    workers=1,
    use_multiprocessing=False,
    validation_data=validation_data,
    validation_steps=validation_steps)

## Save the weights from the final epoch
am.ctc_model.save_weights(os.path.join(save_path, 'model_'+str(epochs).zfill(2)+'.h5'))
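The script expects the YAML config to have data, model, and train sections. Below is a hypothetical conf_am.yaml sketch based only on the keys the script itself reads (data_path, dev_path, epochs, save_path, retrain_dir); the remaining data and model keys depend on utils_am and gru_ctc:

data:
  data_path: data/train      # training set in Kaldi format (text + wav.scp)
  dev_path: data/dev         # leave empty to skip validation
model:
  # hyperparameters forwarded to Am via am_hparams / dic2args
train:
  epochs: 10
  save_path: save_models/    # overridden by -s on the command line
  retrain_dir:               # path to .h5 weights to resume from, or empty

A run with an explicit config and save directory then looks like:

python train_am.py -c config/conf_am.yaml -s save_models/my-run-logs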
Testing
python test_am.py
import numpy as np
from utils_am import decode_ctc, compute_mfcc
from gru_ctc import Am, am_hparams
import math
import argparse
# Command-line arguments
parser = argparse.ArgumentParser()
# The vocab file is the one train_am.py writes next to the model weights
parser.add_argument(
    '-v', '--vocab_file', type=str, default='save_models/11-12-14-14-34-logs/vocab')
parser.add_argument(
    '-m', '--model_file', type=str, default='save_models/11-12-14-14-34-logs/model_05.h5')
parser.add_argument(
    '-w', '--wav_file', type=str, default='test.wav')
cmd_args = parser.parse_args()
# Load the vocab
am_vocab = []
for s in open(cmd_args.vocab_file, encoding='utf-8'):
    am_vocab.append(s.strip())
## Load the acoustic model
am_args = am_hparams()
am_args.vocab_size = len(am_vocab)
am = Am(am_args)
print('loading acoustic model...')
am.ctc_model.load_weights(cmd_args.model_file)
# Compute MFCC features and expand them into a 4-D tensor for the model;
# the time axis is zero-padded to a multiple of 8 to match the model's downsampling
mfcc = compute_mfcc(cmd_args.wav_file)
x = np.zeros((1, 8*math.ceil(mfcc.shape[0]/8), mfcc.shape[1], 1))
x[0, :mfcc.shape[0], :, 0] = mfcc
# Predict and decode
result = am.model.predict(x, steps=1)
_, text = decode_ctc(result, am_vocab)
text = ' '.join(text)
print('Prediction:', text)
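decode_ctc comes from utils_am. For reference, a minimal best-path (greedy) CTC decoder with the same calling convention might look like the sketch below, assuming the blank symbol is the last class, as in Keras's ctc_batch_cost:

import numpy as np

def decode_ctc_greedy(result, vocab):
    """Best-path CTC decode: argmax per frame, collapse repeats, drop blanks."""
    best_path = np.argmax(result[0], axis=-1)  # most likely class per frame
    blank = result.shape[-1] - 1               # blank assumed to be the last class
    tokens, prev = [], -1
    for idx in best_path:
        if idx != prev and idx != blank:       # collapse repeats, skip blanks
            tokens.append(vocab[idx])
        prev = idx
    return best_path, tokens

The repo's own decode_ctc may instead use beam search (e.g. via K.ctc_decode), which usually gives better results than best-path decoding.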