区分性训练和mmi


搜集资料的思路:mmi -> DT -> mle -> ce -> 熵 -> 互信息

互信息

首先连接机器学习中的熵、条件熵、相对熵(KL散度)和交叉熵的概念:传送门

信息熵是衡量随机变量分布的混乱程度,是随机分布各事件发生的信息量的期望值,随机变量的取值个数越多,状态数也就越多,信息熵就越大,混乱程度就越大。

当随机分布为均匀分布时,熵最大;信息熵推广到多维领域,则可得到联合信息熵;条件熵表示的是在 X 给定条件下,Y 的条件概率分布的熵对 X的期望。

相对熵可以用来衡量两个概率分布之间的差异。

交叉熵可以来衡量在给定的真实分布下,使用非真实分布所指定的策略消除系统的不确定性所需要付出的努力的大小。

互信息(mutual information):

在这里插入图片描述

mmi准则实际上就是最大化互信息的缩写

最大似然估计MLE(缺点)

最大似然估计,即假设一种分布,用现在已知的数据去估计这种分布的参数。

然后看交叉熵,我们用估计的分布算出一个交叉熵,然后让交叉熵尽可能小达到估计真实的分布。因此我们之前的训练是属于MLE,用的是交叉熵。交叉熵也就是MLE的一种实现,同样如果是KL散度,也是属于MLE的实现,正如我们前面说的这样。

我们语音识别的目的是最大化P(W∣O). w是words ,o是指观测序列

P ( W ∣ O ) = P ( O ∣ W ) P ( W ) P ( O ) P(W|O)=\frac{P(O|W)P(W)}{P(O)} P(WO)=P(O)P(OW)P(W),其中 P ( O ) P(O) P(O)对于每个任务都是一样的, P ( W ) P(W) P(W) 是由语言模型确定, P ( O ∣ W ) P(O∣W) P(OW)是由声学模型确定,也称之为后验概率。

HMM也就是对 P ( O ∣ W ) P(O∣W) P(OW)进行的建模,我们所用的神经网络实际上也就是学习hmm的发射矩阵。

MLE缺点:

  • 我们所学习模拟的分布是已知的。要求模型假设必须正确。
  • 训练时数据应趋向于无穷多。
  • 在解码时语言模型是趋向于真实的语言分布。

对于第一点,我们是通过GMM去模拟语音的真实分布,因为我们认为GMM可以模拟出任意一种分布。然后在此基础上去对齐训练。但是如果GMM训练的不到位呢。第二点第三点就不用再说了。

于是提出了区分性训练

区分性训练DT和最大互信息MMI

区分性训练

于是对于这些问题,就有了区分性训练,区分性训练实际上就是希望通过设置一个目标函数达到奖励正确的同时处罚错误的这样一个目的,来进行训练。

它有几种实现方法,其中一种是MMI,最大化互信息,其他的还有BMMI/MPE/sMBR之类的。

我们前面说到,互信息是描述两个随机变量的关联程度,于是在这里就是描述观测序列和文本的关联程度。

再来看一下互信息公式,这里只介绍一下主要的公式, I ( x , y ) = H ( x ) − H ( x ∣ y ) I(x,y)=H(x)−H(x∣y) I(x,y)=H(x)H(xy)。因此最大化互信息,就等于最小化条件熵 H ( x ∣ y ) H(x∣y) H(xy)

MMI

公式:
在这里插入图片描述
MMI公式本身它可以看成,正确路径得分与所有路径得分的比值。当正确路径得分提升的同时,错误路径得分会减少,因此是一种区分性训练。

实际上我们就可以看成

在这里插入图片描述
也就是我们是以 P ( W ∣ O ) P(W|O) P(WO)作为目标函数,而不是以 P ( O ∣ W ) P(O|W) P(OW)作为目标函数了。可以看到,实际上,MMI把语言模型的东西也考虑进去了。这样做相对于MLE的好处就是,即使假设的分布不是很好,得到的结果也不会太差。

SDT(MMI/BMMI/MPE/sMBR)的详细公式推理可以参考一下博客:传送门

区分性训练缺点

训练的数据泛化能力不强。 我的理解是,我们把一个正确的路径学的太多了,导致相对于这个训练集的错误路径的分数太低太低,因此为了增强它的泛化能力,我们适当增加错误路径的得分。或者说让声学模型在其中更具主导性,而不是一味地让语言模型进行主导。

MMI训练过程
我们在对一些参数进行更新的时候,需要用到的是前向后向算法,并且需要对分子和分母分别进行前向后向算法。这个是hmm中使用的一种方法。

Lattice

对于分子来说,这个算法是可行的,但对于分母来说,这个算法的计算量就太大了。所以为了计算可行性,我们使用lattice来进行计算,lattice实际上就是词格,词网络。我们通常使用的网络是状态级别的网络,而这里使用的词网络,这就是我们在训练前要进行对齐lats。然后我们为了方便统计信息,我们在每个词节点中加入状态信息。
所以,我们在训练chain-model的时候,使用的是wer,而不是cer。同时,值得注意的是,我们在实际使用的时候,wer也是更值得关注一些。
当然,对分母的优化只有这一点是远远不够的,在lf-mmi中会大量提及对分母的优化,包括hmm的优化,解码图的优化,以及各种tricks。这里先只提及从状态级别的转向lattice级别的优化。

在这里插入图片描述
注意:这里要强调一点,正确的路径可能不止一条,因为可能会有多音字的情况。

对于它的训练:

  • 我们还是使用MLE(CE)准则训练的模型去进行对齐得到lats。
  • 一般来说,在整个训练过程中这些lats不会改变的,但也有随着训练过程而变化的情况,这样得到的效果也是不错的。
  • 解码图,每个分子的解码图是根据当前句子定的,但分母的解码图都是一样的。
  • 但实际上我们在kaldi训练的过程中不是这样的,kaldi中实现的是lf-mmi

MMI的问题:

对于MMI的一个问题就是,训练的数据泛化能力不强。我的理解是,我们把一个正确的路径学的太多了,导致相对于这个训练集的错误路径的分数太低太低,因此为了增强它的泛化能力,我们适当增加错误路径的得分。或者说让声学模型在其中更具主导性,而不是一味地让语言模型进行主导。
因此有几种解决方法,这里只列举几种:

  • 声学参数k,增加声学参数的影响
  • lat使用一个简单的语言模型,一般来说,在分母的解码图的语言模型实际上使用的是一个word-level的bigram
  • boosted mmi
    在这里插入图片描述

我们可以看到,和MMI相比,BMMI在分母上增加了一项, e x p ( − b A ( s , s r ) ) exp(-bA(s,sr)) exp(bA(s,sr)) 其中A(s,sr)表示的是用来描述s和sr的准确性的一个参数。

  • 包括在lf-mmi中也有一些方法去解决泛化能力

MMI代码分析

输入数据的介绍,以及特征的转化

train_mmi.sh:
输入输出文件:
data/train_si84 输入数据
data/lang 词典等信息
exp/tri2b_ali_si84 对齐
exp/tri2b_denlats_si84 lattice
exp/tri2b_mmi 输出文件

Usage: steps/train_mmi_sgmm2.sh <data> <lang> <ali> <denlats> <exp>
e.g.: steps/train_mmi_sgmm2.sh data/train_si84 
data/lang exp/tri2b_ali_si84 exp/tri2b_denlats_si84 exp/tri2b_mmi"

比较了两个音子标号的文本是否是一样的,不一样的话就直接报错停止
每个音子都唯一的对应于一个整数

utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1;
cp $lang/phones.txt $dir || exit 1;

data/feats.scp 每句话所对应的提取出来的二进制存储的特征他的存放位置
alidir/{tree,final.mdl,ali.1.gz} 对齐时候的决策树,模型,对齐的信息。
denlatdir/lat.1.gz 存放词图

for f in $data/feats.scp $alidir/{tree,final.mdl,ali.1.gz} $denlatdir/lat.1.gz; do
  [ ! -f $f ] && echo "$0: no such file $f" && exit 1;
done

接下就是对于特征的转换。
final.mat 是用来进行特征转换的,
transform-feats,使用transform来进行特征转换,为了解码调用。之后可以对该步生成的ark文件,进行解码的操作,得到一个lattice文件。可以参考博客:传送门
transform-feats final.mat ark:splice.ark ark:transform.ark


# Set up features

if [ -f $alidir/final.mat ]; then feat_type=lda; else feat_type=delta; fi
echo "$0: feature type is $feat_type"

case $feat_type in
  delta) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas ark:- ark:- |";;
  lda) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |"
    cp $alidir/final.mat $dir    
    ;;
  *) echo "Invalid feature type $feat_type" && exit 1;
esac

if [ ! -z "$transform_dir" ]; then
  echo "$0: using transforms from $ "
  [ ! -f $transform_dir/trans.1 ] && echo "$0: no such file $transform_dir/trans.1" \
    && exit 1;
  feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$transform_dir/trans.JOB ark:- ark:- |"
else
  echo "$0: no fMLLR transforms."
fi

lattice boost

lats="ark:gunzip -c $denlatdir/lat.JOB.gz|"
if [[ "$boost" != "0.0" && "$boost" != 0 ]]; then
  lats="$lats lattice-boost-ali --b=$boost --silence-phones=$silphonelist $alidir/final.mdl ark:- 'ark,s,cs:gunzip -c $alidir/ali.JOB.gz|' ark:- |"
fi

lattice-boost-ali.cc
通过lattice中每个弧上的帧错误来提高图的相似性(降低图的成本)。

有助于鉴别训练,例如Boost MMI。它主要是修改lattice
此版本采用对齐形式的引用。
需要模型(只是transition)来将pdf_id转换为音子
使用–silence phones选项,lattice中出现的这些静音音子始终被指定为零错误。或者使用–max silence error选项,最多每帧的错误计数(–max silence error=1相当于未指定–silence phones)。
在这里插入图片描述
其中这个文件主要是调用 LatticeBoost()函数对lattice进行了修改

LatticeBoost()
传送门

     po.Register("b", &b, 
                 "Boosting factor (more -> more boosting of errors / larger margin)");
     po.Register("max-silence", &max_silence_error,
                 "Maximum error assigned to silence phones [c.f. --silence-phones option]."
                 "0.0 -> original BMMI paper, 1.0 -> no special silence treatment.");
     po.Register("silence-phones", &silence_phones_str,
                 "Colon-separated list of integer id's of silence phones, e.g. 46:47");
bool LatticeBoost	(	const TransitionModel & 	trans,
const std::vector< int32 > & 	alignment,
const std::vector< int32 > & 	silence_phones,
BaseFloat 	b,
BaseFloat 	max_silence_error,
Lattice * 	lat 
)	

提高 LM语言模型的概率,通过b,也就是将每帧的错误的数量乘以b,然后添加到lattice图中弧的代价中去

如果特定帧上的特定transion_id所对应的音子与该帧alignment对齐的音子不匹配的话,则存在帧错误。

参数中TransitionModel 的作用是将 在lattice输入端中 transition-ids 映射成 phones;

在“silence_phones”中出现的phones被特殊处理,我们用minimum of f 或max_silence_error来替换一个帧的帧错误f(0或1)。

对于正常情况,max_silence_error将为零。成功时返回true,如果存在某种不匹配,则返回false。输入时,silence_phones必须分类和唯一。

                 {
   TopSortLatticeIfNeeded(lat);
 
   // get all stored properties (test==false means don't test if not known).
   uint64 props = lat->Properties(fst::kFstProperties,
                                  false);
   //对静音的音子要确保里面的音子没有重复
   KALDI_ASSERT(IsSortedAndUniq(silence_phones));
   //max_silence_erroe的值为0-1之间
   KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
   vector<int32> state_times;
   int32 num_states = lat->NumStates();
   //num_frames 有多少帧,state_times数组指的是每个state对应的是哪一帧
   int32 num_frames = LatticeStateTimes(*lat, &state_times);
   KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size()));
   for (int32 state = 0; state < num_states; state++) {
     int32 cur_time = state_times[state];
     for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
          aiter.Next()) {
       LatticeArc arc = aiter.Value();
       if (arc.ilabel != 0) {  // Non-epsilon arc
         if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
           KALDI_WARN << "Lattice has out-of-range transition-ids: "
                      << "lattice/model mismatch?";
           return false;
         }
         //phone 是指这条弧对应的输入transition_id所对应的音子
         //ref_phone 是指当前state->当前的帧cur_time->transition_id->phone音子
         int32 phone = trans.TransitionIdToPhone(arc.ilabel),
             ref_phone = trans.TransitionIdToPhone(alignment[cur_time]);
         BaseFloat frame_error;
         //如果弧输入对应的音子和当前state对应的音子一样的话,就是正确的,error=0
         if (phone == ref_phone) {
           frame_error = 0.0;
         } else { // an error...
         //接下就是判断是否是静音音子,如果损失静音音子错误的指就是预设的max_silence_error,否则就为1
           if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone))
             frame_error = max_silence_error;
           else
             frame_error = 1.0;
         }
         BaseFloat delta_cost = -b * frame_error; // negative cost if
         // frame is wrong, to boost likelihood of arcs with errors on them.
         // Add this cost to the graph part.
         arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
         aiter.SetValue(arc);
       }
     }
   }
   // All we changed is the weights, so any properties that were
   // known before, are still known, except for whether or not the
   // lattice was weighted.
   lat->SetProperties(props,
                      ~(fst::kWeighted|fst::kUnweighted));
 
   return true;
 }

接下是MMI的核心部分:

  if [ $stage -le $x ]; then
    $cmd JOB=1:$nj  \
      //test命令用于检查某个条件是否成立 -s 文件名,如果文件存在且至少有一个字符则为真
      test -s $dir/den_acc.$x.JOB.gz -a -s $dir/num_acc.$x.JOB.gz '||' \
      //使用一个新的模型去更改lattice上的声学分数。
      sgmm2-rescore-lattice --speedup=true "$gselect_opt" $spkvecs_opt $dir/$x.mdl "$lats" "$feats" ark:- \| \
      lattice-to-post --acoustic-scale=$acwt ark:- ark:- \| \
      sum-post --drop-frames=$drop_frames --merge=$cancel --scale1=-1 \
      ark:- "ark,s,cs:gunzip -c $alidir/ali.JOB.gz | ali-to-post ark:- ark:- |" ark:- \| \
      sgmm2-acc-stats2 "$gselect_opt" $spkvecs_opt $dir/$x.mdl "$feats" ark,s,cs:- \
      "|gzip -c >$dir/num_acc.$x.JOB.gz" "|gzip -c >$dir/den_acc.$x.JOB.gz" || exit 1;

    n=`echo $dir/{num,den}_acc.$x.*.gz | wc -w`;
    [ "$n" -ne $[$nj*2] ] && \
      echo "Wrong number of MMI accumulators $n versus 2*$nj" && exit 1;
    num_acc_sum="sgmm2-sum-accs - ";
    den_acc_sum="sgmm2-sum-accs - ";
    for j in `seq $nj`; do 
      num_acc_sum="$num_acc_sum 'gunzip -c $dir/num_acc.$x.$j.gz|'"; 
      den_acc_sum="$den_acc_sum 'gunzip -c $dir/den_acc.$x.$j.gz|'"; 
    done
    $cmd $dir/log/update.$x.log \
     sgmm2-est-ebw $update_opts $dir/$x.mdl "$num_acc_sum |" "$den_acc_sum |" \
      $dir/$[$x+1].mdl || exit 1;
    rm $dir/*_acc.$x.*.gz 
  fi

sgmm2-rescore-lattice.cc :

Replace the acoustic scores on a lattice using a new model
使用一个新的模型去更改lattice上的声学分数。

参数为模型,更改前的lattice,声学特征,更改后的lattice

Usage: sgmm2-rescore-lattice [options] <model-in> <lattice-rspecifier> 
<feature-rspecifier> <lattice-wspecifier>
e.g.: sgmm2-rescore-lattice 1.mdl ark:1.lats scp:trn.scp ark:2.lats
     kaldi::ParseOptions po(usage);
     po.Register("old-acoustic-scale", &old_acoustic_scale,
                 "Add the current acoustic scores with some scale.");
     po.Register("log-prune", &log_prune,
                 "Pruning beam used to reduce number of exp() evaluations.");
     po.Register("spk-vecs", &spkvecs_rspecifier, "Speaker vectors (rspecifier)");
     po.Register("utt2spk", &utt2spk_rspecifier,
                 "rspecifier for utterance to speaker map");
     po.Register("gselect", &gselect_rspecifier,
                 "Precomputed Gaussian indices (rspecifier)预计算高斯指数");
     po.Register("speedup", &speedup,
                 "If true, enable a faster version of the computation that "
                 "saves times when there is only one pdf-id on a single frame "
                 "by only sometimes (randomly) computing the probabilities, and "
                 "then scaling them up to preserve corpus-level diagnostics.");
 std::string model_filename = po.GetArg(1),
         lats_rspecifier = po.GetArg(2),
         feature_rspecifier = po.GetArg(3),
         lats_wspecifier = po.GetArg(4);
 
     AmSgmm2 am_sgmm;
     TransitionModel trans_model;
     {
       bool binary;
       Input ki(model_filename, &binary);
       trans_model.Read(ki.Stream(), binary);
       am_sgmm.Read(ki.Stream(), binary);
     }
 
     RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
     RandomAccessBaseFloatVectorReaderMapped spkvecs_reader(spkvecs_rspecifier,
                                                            utt2spk_rspecifier);
     RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
     // Read as compact lattice
     SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier);
     // Write as compact lattice.
     CompactLatticeWriter compact_lattice_writer(lats_wspecifier);
 
     int32 num_done = 0, num_err = 0;
     for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {
       std::string utt = compact_lattice_reader.Key();
       //查找这句所对应的特征
       if (!feature_reader.HasKey(utt)) {
         KALDI_WARN << "No feature found for utterance " << utt;
         num_err++;
         continue;
       }
 
       CompactLattice clat = compact_lattice_reader.Value();
       compact_lattice_reader.FreeCurrent();
       if (old_acoustic_scale != 1.0)
         //通过old_acoustic_scale这个声学尺度,对clat这个lattice上的弧的权重以及final_weight进行缩放
         fst::ScaleLattice(fst::AcousticLatticeScale(old_acoustic_scale), &clat);
       //utt这句话所对应的特征
       const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
 
       // Get speaker vectors
       Sgmm2PerSpkDerivedVars spk_vars;
       if (spkvecs_reader.IsOpen()) {
         if (spkvecs_reader.HasKey(utt)) {
           spk_vars.SetSpeakerVector(spkvecs_reader.Value(utt));
           am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
         } else {
           KALDI_WARN << "Cannot find speaker vector for " << utt;
           num_err++;
           continue;
         }
       }  // else spk_vars is "empty"
      //从这里可以看出高斯指数,特征的每一帧对应一个高斯指数
       if (!gselect_reader.HasKey(utt) ||
           gselect_reader.Value(utt).size() != feats.NumRows()) {
         KALDI_WARN << "No Gaussian-selection info available for utterance "
                    << utt << " (or wrong size)";
         num_err++;
         continue;
       }
       const std::vector<std::vector<int32> > &gselect =
           gselect_reader.Value(utt);
 
       DecodableAmSgmm2 sgmm2_decodable(am_sgmm, trans_model, feats,
                                        gselect, log_prune, &spk_vars);
 
       if (!speedup) {
         if (kaldi::RescoreCompactLattice(&sgmm2_decodable, &clat)) {
           compact_lattice_writer.Write(utt, clat);
           num_done++;
         } else num_err++;
       } else {
         BaseFloat speedup_factor = 100.0; 
         if (kaldi::RescoreCompactLatticeSpeedup(trans_model, speedup_factor,
                                                 &sgmm2_decodable,
                                                 &clat)) {
           compact_lattice_writer.Write(utt, clat);
           num_done++;
         } else num_err++;
       }        
     }

最后一次更新时间2019-11-14 ,之后继续更新

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值