声学模型训练(前向-后向算法)
前文讲述了语音识别声学模型训练算法,主要基于Viterbi-EM算法来估计模型中参数,但是该方法对于计算语料中帧对应状态的弧号存在计算复杂度指数级增加的问题,为解决上述问题,有学者提出用前向后向算法来估计模型中参数,其可以解决复杂度指数级增长的问题,主要理论及工程实现如下。
本文语音识别算法主要参考哥伦比亚大学语音识别课程提供的源码
首先给出整体模型训练流程,其代码如下:
#ifndef NO_MAIN_LOOP
void main_loop(const char** argv) {
map<string, string> params;
process_cmd_line(argv, params);
Lab2FbMain mainObj(params);
GmmStats gmmStats(mainObj.get_gmm_set(), params);
while (mainObj.init_iter()) {
gmmStats.clear();
while (mainObj.init_utt()) {
double logProb = forward_backward(
mainObj.get_graph(), mainObj.get_gmm_probs(), mainObj.get_chart(),
mainObj.get_gmm_counts(), mainObj.get_trans_counts());
mainObj.finish_utt(logProb);
gmmStats.update(mainObj.get_gmm_counts(), mainObj.get_feats());
}
mainObj.finish_iter();
gmmStats.reestimate();
}
mainObj.finish();
}
#endif
上述代码给出了整体训练流程,现对其进行逐个讲解:
1.代码初始化
map<string, string> params;
process_cmd_line(argv, params);
上述代码主要读取输入超参数,并对其进行处理,最后结果如下所示:
上述参数以此表示输入语音位置,chart图,解码图,输入gmm模型,迭代次数以及经过更新后gmm模型参数。
2. 前后向算法初始化
Lab2FbMain mainObj(params); 表示实例化出前后向算法,其中前后向算法构造函数输入为:
Lab2FbMain::Lab2FbMain(const map<string, string>& params)
: m_params(params),
m_frontEnd(m_params),
m_gmmSet(get_required_string_param(m_params, "in_gmm")),
m_outGmmFile(get_required_string_param(m_params, "out_gmm")),
m_transCountsFile(get_string_param(params, "trans_counts")),
m_iterCnt(get_int_param(m_params, "iters", 1)),
m_iterIdx(1),
m_totFrmCnt(0),
m_totLogProb(0.0) {
if (!m_transCountsFile.empty()) {
m_graph.read_word_sym_table(
get_required_string_param(params, "trans_syms"));
}
}
上述构造函数主要对传入参数进行处理并对其进行赋值操作;
m_frontEnd(m_params),主要用于对语音数据进行特征提取;
m_gmmSet(get_required_string_param(m_params, "in_gmm"))用于读取原始GMM参数,其中GMM可以对其进行初始化,也可以不使用初始化操作,通常情况下,我们对GMM中均值与方差进行初始化有以下策略:因为语音识别中常用对角矩阵表示均值与方差矩阵,然而方差矩阵对角线元素均为非负,均值元素可正可负,因此可以尝试使用单位对角阵初始化方差矩阵,使用零对角阵初始化均值矩阵,至于为什么使用不同的元素初始化该参数:个人理解主要是调试过程中便于对其进行区分且符合上述对角矩阵理论部分;
其他参数依次表示经过参数训练后gmm模型参数存储位置、转移概率存储位置、迭代次数、总帧数以及总似然值,至此实例化对象参数初始化完毕。
3.gmm模型状态初始化
GmmStats gmmStats(mainObj.get_gmm_set(), params);类主要对前向后向算法实例化对象与函数参数作为输入,将前后向算法结果输入至GmmStats类实例化对象gmmStats中,因为该类对于状态计算很重要,因此将该类主要函数与参数展示如下:
class GmmStats {
public:
GmmStats(GmmSet& gmmSet, const map<string, string>& params = ParamsType());
void clear();
double update(const vector<GmmCount>& gmmCountList,
const matrix<double>& feats);
double add_gmm_count(unsigned gmmIdx, double posterior,
const vector<double>& feats);
void reestimate() const;
private:
map<string, string> m_params;
/** Reference to associated GmmSet. **/
GmmSet& m_gmmSet;
/** Total counts of each Gaussian. **/
vector<double> m_gaussCounts;
/** First-order stats for each dim of each Gaussian. **/
matrix<double> m_gaussStats1;
/** Second-order stats for each dim of each Gaussian. **/
matrix<double> m_gaussStats2;
};
前文关于声学模型训练部分对该类中三个主要参数进行了说明,其分别统计语料库中所有状态对应出现次数,均值统计以及方差参数进行统计,其中对gmmStats实例化对象初始化如下:
GmmStats::GmmStats(GmmSet& gmmSet, const map<string, string>& params)
: m_params(params),
m_gmmSet(gmmSet),
m_gaussCounts(m_gmmSet.get_gaussian_count()),
m_gaussStats1(m_gmmSet.get_gaussian_count(), m_gmmSet.get_dim_count()),
m_gaussStats2(m_gmmSet.get_gaussian_count(), m_gmmSet.get_dim_count()) {
clear();
}
void GmmStats::clear() {
fill(m_gaussCounts.begin(), m_gaussCounts.end(), 0.0);
fill(m_gaussStats1.data().begin(), m_gaussStats1.data().end(), 0.0);
fill(m_gaussStats2.data().begin(), m_gaussStats2.data().end(), 0.0);
}
上述gmm状态初始化主要用输入gmm模型统计量结果对其进行复制与初始化。
4.开始迭代
用while循环开始迭代更新gmm模型中参数,其初始化迭代代码如下:
bool Lab2FbMain::init_iter() {
if (m_iterIdx > m_iterCnt) return false;
m_transCounts.clear();
m_audioStrm.clear();
m_audioStrm.open(get_required_string_param(m_params, "audio_file").c_str());
m_graphStrm.clear();
m_graphStrm.open(get_required_string_param(m_params, "graph_file").c_str());
m_totFrmCnt = 0;
m_totLogProb = 0.0;
return true;
}
该部分主要是对音频文件、解码图文件以及一些超参数进行初始化操作。
5.开始处理语料
同理用while循环遍历语料,并对其进行处理,其初始化操作代码如下:
bool Lab2FbMain::init_utt() {
if (m_audioStrm.peek() == EOF) return false;
m_idStr = read_float_matrix(m_audioStrm, m_inAudio);
cout << "Processing utterance ID: " << m_idStr << endl;
m_frontEnd.get_feats(m_inAudio, m_feats);
if (m_feats.size2() != m_gmmSet.get_dim_count())
throw runtime_error("Mismatch in GMM and feat dim.");
if (m_graphStrm.peek() == EOF)
throw runtime_error("Mismatch in number of audio files and FSM's.");
m_graph.read(m_graphStrm, m_idStr);
if (m_graph.get_gmm_count() > m_gmmSet.get_gmm_count())
throw runtime_error(
"Mismatch in number of GMM's between "
"FSM and GmmSet.");
m_gmmSet.calc_gmm_probs(m_feats, m_gmmProbs);
m_chart.resize(m_feats.size1() + 1, m_graph.get_state_count());
m_chart.clear();
if (m_graph.get_start_state() < 0)
throw runtime_error("Graph has no start state.");
m_gmmCountList.clear();
return true;
}
上述代码核心部分前文已经有介绍,主要是用于语料库中语句的特征提取、解码图读取、计算当前帧属于各个状态的概率密度函数以及对chart格子图进行初始化操作,不懂的可以看前文博客对计算pdf部分与初始化格子图为什么多一帧的介绍。
6.前向后向算法
接下来则是前向后向算法核心部分,先将算法代码列出如下所示:
double forward_backward(const Graph& graph, const matrix<double>& gmmProbs,
matrix<FbCell>& chart, vector<GmmCount>& gmmCountList,
map<int, double>& transCounts) {
int frmCnt = chart.size1() - 1;
int stateCnt = chart.size2();
{
for (int frmIdx = 0; frmIdx < (int)chart.size1(); ++frmIdx) {
for (int stateIdx = 0; stateIdx < (int)chart.size2(); ++stateIdx) {
chart(frmIdx, stateIdx).set_forw_log_prob(g_zeroLogProb);
chart(frmIdx, stateIdx).set_back_log_prob(g_zeroLogProb);
}
}
}
int startState = graph.get_start_state();
chart(0, startState).set_forw_log_prob(0);
for (int frmIdx = 1; frmIdx <= frmCnt; ++frmIdx) {
for (int stateIdx = 0; stateIdx < stateCnt; ++stateIdx) {
int arcCnt = graph.get_arc_count(stateIdx);
int arcId = graph.get_first_arc_id(stateIdx);
for (int arcIdx = 0; arcIdx < arcCnt; ++arcIdx) {
Arc arc;
arcId = graph.get_arc(arcId, arc);
int dstState = arc.get_dst_state();
//arc.get_log_prob(),chart(frmIdx - 1, stateIdx).get_forw_log_prob(),
//gmmProbs(frmIdx - 1, arc.get_gmm())三者分别表示为状态转移概率,
//子图初始概率以及状态发射概率
double logProb = arc.get_log_prob() +
chart(frmIdx - 1, stateIdx).get_forw_log_prob() +
gmmProbs(frmIdx - 1, arc.get_gmm());
logProb = add_log_probs(vector<double>{
logProb, chart(frmIdx, dstState).get_forw_log_prob()});
chart(frmIdx, dstState).set_forw_log_prob(logProb);
}
}
}
//for (int frmidx = 0; frmidx <= frmCnt; ++frmidx) {
// for (int srcidx = 0; srcidx < stateCnt; ++srcidx) {
// cout << format(" %d") % chart(frmidx, srcidx).get_forw_log_prob();
// }
// cout << endl;
//}
//得到概率最大的终止状态的似然值及其终止状态序号;
double uttLogProb = init_backward_pass(graph, chart);
if (uttLogProb == g_zeroLogProb) return uttLogProb;
for (int frmIdx = frmCnt - 1; frmIdx >= 0; --frmIdx) {
for (int stateIdx = 0; stateIdx < stateCnt; ++stateIdx) {
int arcCnt = graph.get_arc_count(stateIdx);
int arcId = graph.get_first_arc_id(stateIdx);
for (int arcIdx = 0; arcIdx < arcCnt; ++arcIdx) {
Arc arc;
arcId = graph.get_arc(arcId, arc);
int dstState = arc.get_dst_state();
double logProb = arc.get_log_prob() + gmmProbs(frmIdx, arc.get_gmm()) +
chart(frmIdx + 1, dstState).get_back_log_prob();
// NOTE!!! They are log prob but not regular prob, so use add_log_probs
// but not +.
// logProb += chart(frmIdx, stateIdx).get_back_log_prob();
logProb = add_log_probs(vector<double>{
logProb, chart(frmIdx, stateIdx).get_back_log_prob()});
chart(frmIdx, stateIdx).set_back_log_prob(logProb);
}
}
}
//for (int frmIdx = 0; frmIdx <= frmCnt; ++frmIdx) {
// for (int srcIdx = 0; srcIdx < stateCnt; ++srcIdx) {
// cout << format(" %d") % chart(frmIdx, srcIdx).get_back_log_prob();
// }
// cout << endl;
//}
for (int frmIdx = frmCnt; frmIdx > 0; --frmIdx) {
for (int stateIdx = 0; stateIdx < stateCnt; ++stateIdx) {
int arcCnt = graph.get_arc_count(stateIdx);
int arcId = graph.get_first_arc_id(stateIdx);
for (int arcIdx = 0; arcIdx < arcCnt; ++arcIdx) {
Arc arc;
arcId = graph.get_arc(arcId, arc);
int dstState = arc.get_dst_state();
//logProb表示任意时刻到达某一帧某条弧上的概率,其采用前向后向算法进行计算;
double logProb =
chart(frmIdx - 1, stateIdx).get_forw_log_prob() + // alpha_t-1_i
arc.get_log_prob() + // a_i_j
gmmProbs(frmIdx - 1, arc.get_gmm()) + // b_j_(ot)
chart(frmIdx, dstState).get_back_log_prob(); // beta_t_j
//exp(logProb - uttLogProb)表示状态转移至终止状态时概率,即为弧上概率;
gmmCountList.push_back(
GmmCount(arc.get_gmm(), frmIdx - 1, exp(logProb - uttLogProb)));
}
}
}
return uttLogProb;
}
笔者对于前向后向算法的理解如下:
(1)该算法可以称之为评估问题,即已知声学模型参数(gmm参数)以及其观测序列(解码图),如何基于此计算该模型产生该序列的产产生的概率,即为对该声学模型结果进行打分;
(2)此时的chart格子图与之前Viterbi-EM不一致,其每个元素对应的数据类型为自定义格式类型,主要用于存储以及读取前向概率与后向概率,该变量具体参数如下:
class FbCell {
public:
FbCell() : m_forwLogProb(g_zeroLogProb), m_backLogProb(g_zeroLogProb) {}
explicit FbCell(int)
: m_forwLogProb(g_zeroLogProb), m_backLogProb(g_zeroLogProb) {}
void set_forw_log_prob(double logProb) { m_forwLogProb = logProb; }
void set_back_log_prob(double logProb) { m_backLogProb = logProb; }
double get_forw_log_prob() const { return m_forwLogProb; }
double get_back_log_prob() const { return m_backLogProb; }
void printLogprobs() {
cout << m_forwLogProb << " "
<< m_backLogProb << endl;
}
private:
double m_forwLogProb;
double m_backLogProb;
};
(3)add_log_probs()是存储前向参数与向参数的核心所在,虽然笔者在前期讲解过该参数的具体实现方法,其主要用于计算到此状态时最大的似然概率,不管是前向计算还是后向计算,均是如此,计算的结果均为到此状态时最大的概率,其存储的分别为前向概率与后巷概率,其变量类型如上所述;
(4)实际上前向算法计算至最后与前文Viterbi-EM算法是一致的,但是降低模型的复杂度进而引入了后向算法
7.终止语料读取
通过mainObj.finish_utt(logProb);函数终止前后向算法的语料计算,其具体代码实现如下:
void Lab2FbMain::finish_utt(double logProb) {
m_totFrmCnt += m_feats.size1();
m_totLogProb += logProb;
double minPosterior = get_float_param(m_params, "min_posterior", 0.001);
//m_gmmCountList存储结果为所有帧对应弧,包括弧序号,弧所属帧以及对应转移概率,转移概率由前后向算法进行计算;
if (minPosterior > 0.0) {
m_gmmCountListThresh.clear();
for (int cntIdx = 0; cntIdx < (int)m_gmmCountList.size(); ++cntIdx) {
if (m_gmmCountList[cntIdx].get_count() >= minPosterior) {
m_gmmCountListThresh.push_back(m_gmmCountList[cntIdx]);
}
//m_gmmCountListThresh过滤掉m_gmmCountList中转移概率太小的弧单元,即为对弧进行剪枝;
}
m_gmmCountList.swap(m_gmmCountListThresh);
}
//sort(m_gmmCountList.begin(), m_gmmCountList.end());
string chartFile = get_string_param(m_params, "chart_file");
if (!chartFile.empty()) {
ofstream chartStrm(chartFile.c_str());
int frmCnt = m_feats.size1();
int stateCnt = m_graph.get_state_count();
matrix<double> matForwProbs(frmCnt + 1, stateCnt);
matrix<double> matBackProbs(frmCnt + 1, stateCnt);
for (int frmIdx = 0; frmIdx <= frmCnt; ++frmIdx) {
for (int srcIdx = 0; srcIdx < stateCnt; ++srcIdx) {
matForwProbs(frmIdx, srcIdx) =
m_chart(frmIdx, srcIdx).get_forw_log_prob();
matBackProbs(frmIdx, srcIdx) =
m_chart(frmIdx, srcIdx).get_back_log_prob();
}
}
write_float_matrix(chartStrm, matForwProbs, m_idStr + "_forw");
write_float_matrix(chartStrm, matBackProbs, m_idStr + "_back");
matrix<double> matPost(frmCnt, m_gmmSet.get_gmm_count());
matPost.clear();
int gmmCountCnt = m_gmmCountList.size();
for (int cntIdx = 0; cntIdx < gmmCountCnt; ++cntIdx) {
const GmmCount& gmmCount = m_gmmCountList[cntIdx];
matPost(gmmCount.get_frame_index(), gmmCount.get_gmm_index()) +=
gmmCount.get_count();
}
write_float_matrix(chartStrm, matPost, m_idStr + "_post");
chartStrm.close();
}
//for (int i = 0; i < m_gmmCountList.size(); i++) {
// m_gmmCountList[i].printGmmCount();
//}
}
必须说明的是为了剪枝gmm模型中弧上概率较小的部分,设置阈值(0.001)来控制gmm模型的权重,对于权重较小的gmm模型不计算统计量,这样可以大幅度降低模型中的参数数量,笔者曾对此进行测试过(未剪枝是2516条弧,剪枝后仅为168条弧),这样可以最大幅度降低模型中参数数量,而且对最后结果影响很小。
最终将前向概率与后向概率以及gmm模型的权重写入值chart图中,其中gmm权重在代码中表示为后验概率,最后后验概率表示如下图所示:
从上图可知,gmm权重大部分为0,这样大大减少了模型计算量且便于参数计算,非常值得推荐。
8. 声学模型参数赋值
用gmmStats.update(mainObj.get_gmm_counts(), mainObj.get_feats());函数基于特征用gmm统计量对gmm模型进行状态个数、均值以及方差进行复制,前文声学模型Viterbi-EM对此进行详细的说明,不懂的读者可以参考下。
9.终止迭代
mainObj.finish_iter();函数控制迭代次数,其具体实现如下:
void Lab2FbMain::finish_iter() {
m_audioStrm.close();
m_graphStrm.close();
cout << format("Iteration %d: %.6f logprob/frame (%d frames)") % m_iterIdx %
(m_totFrmCnt ? m_totLogProb / m_totFrmCnt : 0.0) % m_totFrmCnt
<< endl;
++m_iterIdx;
}
该函数主要用于打印总体似然概率以及文件流关闭。
10.参数更新
前文对此亦进行介绍,现给出参数更新代码如下:
void GmmStats::reestimate() const {
int gaussCnt = m_gmmSet.get_gaussian_count();
int dimCnt = m_gmmSet.get_dim_count();
double occupancy, mean, var;
for (int gaussIdx = 0; gaussIdx < gaussCnt; ++gaussIdx) {
occupancy = m_gaussCounts[gaussIdx];
for (int dimIdx = 0; dimIdx < dimCnt; ++dimIdx) {
//均值与方差重新估计,
mean = m_gaussStats1(gaussIdx, dimIdx) / occupancy;
var = m_gaussStats2(gaussIdx, dimIdx) / occupancy - mean * mean;
m_gmmSet.set_gaussian_mean(gaussIdx, dimIdx, mean);
m_gmmSet.set_gaussian_var(gaussIdx, dimIdx, var);
}
}
}
11.终止参数估计
最终将更新的参数存储至输出的gmm模型中,其代码如下:
void Lab2FbMain::finish() {
m_gmmSet.write(m_outGmmFile);
if (!m_transCountsFile.empty()) {
ofstream countStrm(m_transCountsFile.c_str());
for (map<int, double>::const_iterator elemIter = m_transCounts.begin();
elemIter != m_transCounts.end(); ++elemIter)
countStrm << format("%s %.3f\n") %
m_graph.get_word_sym_table().get_str(elemIter->first) %
elemIter->second;
countStrm.close();
}
}
前文对此都介绍过,本文对此进行代码说明,
至此语音识别基于前向后向算法介绍完毕。