前言
虽然还没整体看完kaldi的C++源码,但已经发现kaldi真的是一个复杂而庞大的开源项目,想找到并理解某一种功能的具体实现,比如对齐,解码,GMM迭代… 等等,都绕不开其繁杂而完备的数据结构设计,当然也可能是因为笔者的C++功底一般,还在kaldi源码的苦海中挣扎。
本文是秉着总结,学习和经验分享的目的,利用现有的kaldi数据结构和类型来实现一种功能,整体并不复杂,但有助于对源码的进一步理解,具体来说是编写一个新的kaldi程序。
目标功能
首先定位一下要实现的功能是解码,利用现有的基于nnet3训练的chain模型,这里用到的egs是hi_mia的语音唤醒词识别,比起语音识别,不管是数据还是模型都会更精简一些。一般chain模型的解码都在steps/decode.sh 中实现,具体来讲用到的是 nnet3-latgen-faster 这个程序,他是读取整个音频特征,生成词图再从词图中找一条最优路径作为解码结果。
该解码程序主要是用在测试集上的,需要读取预先生成的特征,具体来讲需要的输入文件是feats.scp,utt2spk以及cmvn.scp,个人认为比较麻烦的是需要预先生成特征保存起来,不方便直接对某一条音频直接进行解码。当然kaldi也有online decode一系列的程序,支持音频文件输入,但他们都是以chunk为单位的流式识别,而现在笔者就是想实现一种读取完整音频文件之后,再进行特征提取和解码的功能。
参考源码:
- kaldi/src/nnet3bin/nnet3-latgen-faster.cc
- kaldi/src/featbin/compute-fbank-feats.cc
自定义程序的路径:kaldi/src/nnet3bin/nnet3-latgen-faster-single.cc
kaldi源码解析
下面直接一部分一部分地对代码进行说明,最后会附上完整代码
首先是inclue:
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "tree/context-dep.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
#include "decoder/decoder-wrappers.h"
#include "nnet3/nnet-am-decodable-simple.h"
#include "nnet3/nnet-utils.h"
#include "base/timer.h"
#include "feat/feature-mfcc.h"
#include "feat/feature-fbank.h"
#include "feat/wave-reader.h"
和 nnet3-latgen-faster.cc 文件相比可以看到多出来了对feat路径的引用,因为我们要进行特征的计算,mfcc或者fbank根据训练情况选一个就行
下面是名称空间的引用:
using namespace kaldi;
using namespace kaldi::nnet3;
typedef kaldi::int32 int32;
using fst::SymbolTable;
using fst::Fst;
using fst::StdArc;
using fst::VectorFst;
基本上凡是涉及解码都会用到fst的名称空间,这个不是kaldi实现的,而是直接利用了OpenFST这个开源库
下面就是对输入参数的处理:
const char *usage = "Generate lattices using nnet3 with single wav input.\n";
ParseOptions po(usage);
Timer timer;
// MfccOptions mfcc_opts;
FbankOptions fbank_opts;
LatticeFasterDecoderConfig lattice_opts;
NnetSimpleComputationOptions decodable_opts;
std::string word_sym_rxfilename;
BaseFloat min_duration = 0.5;
int32 channel = -1;
int32 online_ivector_period = 0;
// mfcc_opts.Register(&po);
fbank_opts.Register(&po);
lattice_opts.Register(&po);
decodable_opts.Register(&po);
po.Register("word-sym-table", &word_sym_rxfilename, "Symbol table for words.");
po.Register("min-duration", &min_duration, "Minimum duration of input wav (default to 0.5).");
po.Register("channel", &channel, "Channel to extract (-1, 0, 1).");
po.Read(argc, argv);
if (po.NumArgs() != 3) {
po.PrintUsage();
return 1;
}
std::string nnet3_rxfilename = po.GetArg(1);
std::string fst_rxfilename = po.GetArg(2);
std::string wav_rspecifier = po.GetArg(3);
// Mfcc mfcc(mfcc_opts);
Fbank fbank(fbank_opts);
kaldi中对输入参数的读取和注册都是通过 util/parse-options.cc 以及相关头文件实现的,感兴趣可以去看看,这个功能不怎么需要其他数据结构的依赖,可以单独实现。同时可以看到在注册该程序的输入参数之前,ParseOptions 的对象会被用来初始化其他配置类型,比如这里的lattice decoder的配置类和网络计算的配置类,同样和 nnet3-latgen-faster 相比,我们加入了Fbank的配置类。 程序的输入文件通过 ParseOptions 类的 GetArg函数获取,可以看到这里我们设置了3个输入文件,分别是网络模型文件,解码图文件以及音频路径文件,熟悉kaldi的话都清楚这里指的是wav.scp。 最后声名1个Fbank的对象用来执行具体的特征计算,这里参考的是 compute-fbank-feats 这个程序
下面是通过模型文件读取模型:
TransitionModel trans_model;
nnet3::AmNnetSimple am_nnet;
{
bool binary;
Input ki(nnet3_rxfilename, &binary);
trans_model.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
SetDropoutTestMode(true, &(am_nnet.GetNnet()));
nnet3