Chinese speech recognition with GRU + CTC

Chinese automatic speech recognition (ASR)

https://download.csdn.net/download/qq_41854731/13108319

Model

cnn + ctc
gru + ctc
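
The model definitions in cnn_ctc.py and gru_ctc.py are not reproduced in this post. Purely as an illustration of the GRU + CTC structure listed above (the layer sizes, the 26-dimensional MFCC input and the convolutional front-end that downsamples time by 8, matching the padding to a multiple of 8 in test_am.py below, are assumptions, not the repository's actual configuration), an acoustic model with the same external interface, a .model for inference and a .ctc_model wrapped with the CTC loss for training, could be sketched like this:

from keras.layers import (Input, Conv2D, MaxPooling2D, Reshape, Dense,
                          Bidirectional, GRU, Lambda)
from keras.models import Model
from keras.optimizers import Adam
import keras.backend as K

def _ctc_loss(args):
    # Keras convention: ctc_batch_cost(labels, y_pred, input_length, label_length)
    labels, y_pred, input_length, label_length = args
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

class Am:
    def __init__(self, args):
        self.vocab_size = args.vocab_size
        # Inference model: (time, n_mfcc, 1) features -> per-frame softmax.
        # Three 2x2 poolings shrink the time axis by 8, which is why the
        # MFCC matrix is padded to a multiple of 8 before prediction.
        inputs = Input(shape=(None, 26, 1), name='the_inputs')
        x = Conv2D(32, 3, padding='same', activation='relu')(inputs)
        x = MaxPooling2D(2)(x)
        x = Conv2D(64, 3, padding='same', activation='relu')(x)
        x = MaxPooling2D(2)(x)
        x = Conv2D(128, 3, padding='same', activation='relu')(x)
        x = MaxPooling2D(2)(x)
        x = Reshape((-1, 3 * 128))(x)  # 26 freq bins -> 13 -> 6 -> 3
        x = Bidirectional(GRU(128, return_sequences=True))(x)
        x = Bidirectional(GRU(128, return_sequences=True))(x)
        outputs = Dense(self.vocab_size, activation='softmax')(x)
        self.model = Model(inputs=inputs, outputs=outputs)

        # Training model: the same network wrapped with the CTC loss.
        # input_length must hold the downsampled (T/8) output lengths.
        labels = Input(shape=(None,), name='the_labels')
        input_length = Input(shape=(1,), name='input_length')
        label_length = Input(shape=(1,), name='label_length')
        loss_out = Lambda(_ctc_loss, output_shape=(1,), name='ctc')(
            [labels, outputs, input_length, label_length])
        self.ctc_model = Model(
            inputs=[inputs, labels, input_length, label_length],
            outputs=loss_out)
        self.ctc_model.compile(
            loss={'ctc': lambda y_true, y_pred: y_pred},
            optimizer=Adam(lr=0.0008))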

Training environment

tensorflow 1.x
keras

Data

Organize the training data into the Kaldi format.
At minimum, the two files text and wav.scp are required (a small format sketch follows).
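
For reference, a minimal sketch of the two files. The utterance IDs, paths and transcripts are placeholders; whether the transcript column holds characters, segmented words or pinyin syllables depends on how utils_am builds the acoustic-model vocabulary.

text (one line per utterance, "utt_id transcript"):
utt_001 今天 天气 很 好
utt_002 欢迎 使用 语音 识别

wav.scp (one line per utterance, "utt_id /path/to/wav"):
utt_001 /path/to/wav/utt_001.wav
utt_002 /path/to/wav/utt_002.wav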

Training

python train_am.py

import os
import tensorflow as tf
from utils_am import get_data, data_hparams, dic2args
from keras.callbacks import ModelCheckpoint
from keras.backend import tensorflow_backend
import yaml
import time
import argparse
# from cnn_ctc import Am, am_hparams
from gru_ctc import Am, am_hparams

# Allocate GPU memory on demand to avoid exhausting GPU memory
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tensorflow_backend.set_session(tf.Session(config=config))

# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
    '-c', '--config', type=str, default='./config/conf_am.yaml')
parser.add_argument(
    '-s', '--save_path', type=str, default='save_models/'+time.strftime("%m-%d-%H-%M-%S")+"-logs")
cmd_args = parser.parse_args()

## Load the config required for the whole training run
f = open(cmd_args.config, 'r', encoding='utf-8')
parms_dict = yaml.load(f, Loader=yaml.FullLoader)
f.close()

## Training data parameters
data_args = dic2args(parms_dict['data'])
train_data = get_data(data_args)
batch_num = train_data.batch_num
train_batch = train_data.get_am_batch()

## Prepare the data needed for validation
validation_data = None
validation_steps = None
if parms_dict['data']["dev_path"]:
    dev_args = dic2args(parms_dict['data'])
    dev_args.data_path = parms_dict['data']["dev_path"]
    dev_data = get_data(dev_args)
    dev_batch = dev_data.get_am_batch()
    validation_data = dev_batch
    validation_steps = 50

## Build the acoustic model
am_args = dic2args(parms_dict['model'])
am_args.vocab_size = len(train_data.am_vocab)
am = Am(am_args) 

## Training parameters
epochs = parms_dict['train']['epochs']
save_path = parms_dict['train']['save_path'] = cmd_args.save_path
retrain_dir = parms_dict['train']['retrain_dir']

## save vocab and config
os.makedirs(save_path,exist_ok=True)
# Save the vocab (UTF-8 so that Chinese tokens round-trip correctly)
am_vocab = train_data.am_vocab
f = open(os.path.join(save_path,"vocab"),"w",encoding='utf-8')
f.write("\n".join(am_vocab))
f.close()
# Save the config
parms_dict["data"]["am_vocab_file"] = os.path.join(save_path,"vocab")
f = open(os.path.join(save_path,"config.yaml"),"w",encoding='utf-8')
yaml.dump(parms_dict,f)
f.close()

## Optionally warm-start from a pre-trained model
if retrain_dir:
    print('load acoustic model...')
    am.ctc_model.load_weights(retrain_dir)

## Checkpoint callback that saves the model weights each epoch
ckpt = "model_{epoch:02d}-{loss:.2f}.h5"
checkpoint = ModelCheckpoint(os.path.join(save_path, ckpt), monitor='val_loss', save_weights_only=True, verbose=1, save_best_only=False)

## Start training
am.ctc_model.fit_generator(
    train_batch, # generator yielding packed training batches
    steps_per_epoch=batch_num, # number of batches per epoch
    epochs=epochs, # number of epochs to train
    callbacks=[checkpoint], # checkpoint callback for saving models
    workers=1, 
    use_multiprocessing=False, 
    validation_data=validation_data, 
    validation_steps=validation_steps)

## Save the weights from the final training epoch
am.ctc_model.save_weights(os.path.join(save_path,'model_'+str(epochs).zfill(2)+'.h5'))
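
The YAML config passed via -c is not included in this post. A minimal conf_am.yaml consistent with the keys train_am.py reads (data_path and dev_path under data; epochs, save_path and retrain_dir under train) might look like the sketch below; the remaining fields under data and model are placeholders, since they are consumed by utils_am.get_data and gru_ctc.Am rather than by this script:

data:
  data_path: data/train      # directory containing text and wav.scp
  dev_path: data/dev         # set to "" to skip validation
  batch_size: 16             # placeholder: extra fields are forwarded via dic2args
model:
  gru_units: 128             # placeholder: model hyper-parameters (vocab_size is filled in at runtime)
train:
  epochs: 10
  save_path: save_models/    # overridden by the -s command-line flag
  retrain_dir: ""            # path to pre-trained weights; leave empty to train from scratch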

Testing

python test_am.py

import numpy as np
from utils_am import decode_ctc, compute_mfcc
from gru_ctc import Am, am_hparams
import math
import argparse

# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
    '-v', '--vocab_file', type=str, default='save_models/11-12-14-14-34-logs/vocab')  # vocab written next to the model by train_am.py
parser.add_argument(
    '-m', '--model_file', type=str, default='save_models/11-12-14-14-34-logs/model_05.h5')
parser.add_argument(
    '-w', '--wav_file', type=str, default='test.wav')
cmd_args = parser.parse_args()

# Load the vocab (one token per line, as written by train_am.py)
am_vocab = []
for s in open(cmd_args.vocab_file, encoding='utf-8'):
    am_vocab.append(s.strip())

## Load the acoustic model
am_args = am_hparams()
am_args.vocab_size = len(am_vocab)
am = Am(am_args) 
print('loading acoustic model...')
am.ctc_model.load_weights(cmd_args.model_file)

# Compute MFCC features and expand them into a 4-D array (1, padded_T, n_mfcc, 1) for the model
mfcc = compute_mfcc(cmd_args.wav_file)
x = np.zeros((1,8*math.ceil(mfcc.shape[0]/8),mfcc.shape[1],1))
x[0,:mfcc.shape[0],:,0] = mfcc

# Predict
result = am.model.predict(x, steps=1)
_, text = decode_ctc(result, am_vocab)
text = ' '.join(text)
print('Prediction:', text)
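
decode_ctc from utils_am is not shown here. As a rough sketch of what greedy CTC decoding does (assuming the model returns a (1, T, vocab_size) softmax matrix and that the last index is the CTC blank, which is the Keras ctc_batch_cost convention), it takes the argmax per frame, collapses repeats and drops blanks:

import numpy as np

def greedy_ctc_decode(probs, vocab):
    # probs: (1, T, vocab_size) softmax output of the acoustic model
    best = np.argmax(probs[0], axis=-1)      # best label index per frame
    blank = probs.shape[-1] - 1              # assume blank is the last index
    tokens, prev = [], -1
    for idx in best:
        if idx != prev and idx != blank:     # collapse repeats, drop blanks
            tokens.append(vocab[idx])
        prev = idx
    return best, tokens

With a trained run, the script would be invoked along the lines of:

python test_am.py -v save_models/<run>/vocab -m save_models/<run>/model_05.h5 -w test.wav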