Kaldi中CMVN处理过程
Author: Xin Pan
Date: 2020.01.03
因为一直好奇CMVN(Cepstral Mean and Variance Normalization,倒谱均值方差归一化)是怎么处理的,正好在服务器安上了gdb就跟着单步调试了一下。
结果我发现compute-cmvn-stats这个命令指示统计了feats中的每个cols的统计值(包括均值和方差,这个cols表示特征的维度,比如16维的feats,cols就是16),但是并没有进行归一化的计算。归一化的计算是通过apply-cmvn进行的。
先看下Kaldi官网对于CMVN的解释,以下内容来自compute-cmvn-stats这个工具的Usage
usage
如果你做了逐句的CMVN就没必要做逐说话人的fMLLR(因为你将会在不同的offset上做fMLLR)。因此这个情况下的说话人信息怎么用呢?这个情况下你就应该让说话人的id(speaker-ids)等于语句id(utterance-ids)。说话人的信息不用等于真正说话的人的数量,这个就是你想要适应的层次。
过程
这次实验的时候我是用aishell 1 的train set 进行的。使用的完整命令如下:
compute-cmvn-stats --spk2utt=ark:data/train/spk2utt scp:data/train/feats.scp ark,scp:/home/panxin/kaldi/egs/aishell/s5/mfcc/cmvn_train.ark,/home/panxin/kaldi/egs/aishell/s5/mfcc/cmvn_train.scp
这次的train特征是16维的mfcc。
feat-to-dim scp:data/train/feats.scp -
16
假设我的spk2utt文件是这样的
S0002 BAC009S0002W0122 BAC009S0002W0123 BAC009S0002W0124
feats.scp是这样的
BAC009S0002W0122 mfcc_noise/raw_mfcc_pitch_train.1.ark:17
BAC009S0002W0123 mfcc_noise/raw_mfcc_pitch_train.1.ark:9751
BAC009S0002W0124 mfcc_noise/raw_mfcc_pitch_train.1.ark:16077
那现在我们进入compute-cmvn-stats.cc的main去看一看,其中对源代码有部分删减
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
using kaldi::int32;
/*此处有删减*/
int32 num_done = 0, num_err = 0;
std::string rspecifier = po.GetArg(1); //rspecifier此处就是data/train/feats.scp
std::string wspecifier_or_wxfilename = po.GetArg(2);
RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier);
if (ClassifyWspecifier(wspecifier_or_wxfilename, NULL, NULL, NULL)
!= kNoWspecifier) { // writing to a Table: per-speaker or per-utt CMN/CVN.
std::string wspecifier = wspecifier_or_wxfilename; //wspecifier是保存的位置
DoubleMatrixWriter writer(wspecifier);
if (spk2utt_rspecifier != "") {
SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); //此处spk2utt_rspecifier就是ark:data/train/spk2utt
// spk2utt_reader保存的是spk作为key,utt作为value的信息。
RandomAccessBaseFloatMatrixReader feat_reader(rspecifier);
// feat_reader保存的是uttid 和特征值之间的关系。
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
std::string spk = spk2utt_reader.Key(); //spk就是S0002
const std::vector<std::string> &uttlist = spk2utt_reader.Value(); //uttlist 就是[BAC009S0002W0122,BAC009S0002W0123,BAC009S0002W0124]
bool is_init = false;
Matrix<double> stats; //stats就是我们想要的结果了,就是存出去的东西,他是一个Matrix
for (size_t i = 0; i < uttlist.size(); i++) {
std::string utt = uttlist[i]; //utt就是从BAC009S0002W0122迭代到BAC009S0002W0124
if (!feat_reader.HasKey(utt)) {
// 判断feat文件是否有utt(如BAC009S0002W0122)这个key,如果没有则进入这里
KALDI_WARN << "Did not find features for utterance " << utt;
num_err++;
continue;
}
const Matrix<BaseFloat> &feats = feat_reader.Value(utt);
//feats保存的是 utt这个音频的特征值,实实在在的值
if (!is_init) {
InitCmvnStats(feats.NumCols(), &stats); //这个函数用于对stats进行初始化,代码见下文
//这里的feasts.NumCols()是feats的列数,也就是feats的维度数,这里是16
is_init = true;
}
// 在AccCmvnStatsWrapper是实际进行cmvn统计量累积的过程
if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
num_err++;
}
else {
//如果成功处理了一个utt就num_done+1
num_done++;
}
}
if (stats.NumRows() == 0) {
KALDI_WARN << "No stats accumulated for speaker " << spk;
}
else {
// 将stats也就是cmvn结果写进文件中
writer.Write(spk, stats);
}
}
}
else { // per-utterance normalization
SequentialBaseFloatMatrixReader feat_reader(rspecifier);
for (; !feat_reader.Done(); feat_reader.Next()) {
std::string utt = feat_reader.Key();
Matrix<double> stats;
const Matrix<BaseFloat> &feats = feat_reader.Value();
InitCmvnStats(feats.NumCols(), &stats);
if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
num_err++;
continue;
}
writer.Write(feat_reader.Key(), stats);
num_done++;
}
}
}
else { // accumulate global stats
if (spk2utt_rspecifier != "")
KALDI_ERR << "--spk2utt option not compatible with wxfilename as output "
<< "(did you forget ark:?)";
std::string wxfilename = wspecifier_or_wxfilename;
bool is_init = false;
Matrix<double> stats;
SequentialBaseFloatMatrixReader feat_reader(rspecifier);
for (; !feat_reader.Done(); feat_reader.Next()) {
std::string utt = feat_reader.Key();
const Matrix<BaseFloat> &feats = feat_reader.Value();
if (!is_init) {
InitCmvnStats(feats.NumCols(), &stats);
is_init = true;
}
if (!AccCmvnStatsWrapper(utt, feats, &weights_reader, &stats)) {
num_err++;
}
else {
num_done++;
}
}
Matrix<float> stats_float(stats);
WriteKaldiObject(stats_float, wxfilename, binary);
KALDI_LOG << "Wrote global CMVN stats to "
<< PrintableWxfilename(wxfilename);
}
KALDI_LOG << "Done accumulating CMVN stats for " << num_done
<< " utterances; " << num_err << " had errors.";
return (num_done != 0 ? 0 : 1);
}
catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
在compute-cmvn-stats.cc中定义了AccCmvnStatsWrapper这个函数
bool AccCmvnStatsWrapper(std::string utt,
const MatrixBase<BaseFloat> &feats,
RandomAccessBaseFloatVectorReader *weights_reader,
Matrix<double> *cmvn_stats) {
// weights_reader除非作为参数特别设置否则都是空的
if (!weights_reader->IsOpen()) {
AccCmvnStats(feats, NULL, cmvn_stats);
return true;
}
else {
if (!weights_reader->HasKey(utt)) {
KALDI_WARN << "No weights available for utterance " << utt;
return false;
}
const Vector<BaseFloat> &weights = weights_reader->Value(utt);
if (weights.Dim() != feats.NumRows()) {
KALDI_WARN << "Weights for utterance " << utt << " have wrong dimension "
<< weights.Dim() << " vs. " << feats.NumRows();
return false;
}
AccCmvnStats(feats, &weights, cmvn_stats);
return true;
}
}
在$KALDI_HOME/src/transform/cmvn.cc中定义了InitCmvnStats以及AccCmvnStats这两个函数
void InitCmvnStats(int32 dim, Matrix<double> *stats) {
KALDI_ASSERT(dim > 0); //如果dim<=0这里就会报错
stats->Resize(2, dim + 1); //将stats的rows cols变为[2,dim+1]
//本次实验中dim=16,因为feats是16维的,stats最终就是[2,17]维的,这里增加的第dim+1维存储的就是帧数
}
void AccCmvnStats(const VectorBase<BaseFloat> &feats, BaseFloat weight, MatrixBase<double> *stats) {
int32 dim = feats.Dim(); // 这里dim是特征的维度就是16
KALDI_ASSERT(stats != NULL);
KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() == dim + 1);
// Remove these __restrict__ modifiers if they cause compilation problems.
// It's just an optimization.
double *__restrict__ mean_ptr = stats->RowData(0), // mean_ptr是一个指针指向stats的第一行,也就是说stats的第一行是均值
*__restrict__ var_ptr = stats->RowData(1), // var_ptr是一个指针指向stats的第二行,第二行是方差
*__restrict__ count_ptr = mean_ptr + dim; // count_ptr也是一个指针,指向了最后一个一维所在的位置,mean_ptr+dim是一个地址,就是当前的第17维。count_ptr统计的是帧数(count of frames)
const BaseFloat * __restrict__ feats_ptr = feats.Data(); // feats_ptr也是一个指针,指向了feats中的特征本身
*count_ptr += weight; // 这里count_ptr的值是1,在开始的时候。每处理一帧都会变化一次
// Careful-- if we change the format of the matrix, the "mean_ptr < count_ptr"
// statement below might become wrong.
for (; mean_ptr < count_ptr; mean_ptr++, var_ptr++, feats_ptr++) {
*mean_ptr += *feats_ptr * weight; //mean_ptr统计的就是特征的带权重的和
*var_ptr += *feats_ptr * *feats_ptr * weight; //var_ptr就是特征带权重的平方
}
}
void AccCmvnStats(const MatrixBase<BaseFloat> &feats,
const VectorBase<BaseFloat> *weights,
MatrixBase<double> *stats) {
int32 num_frames = feats.NumRows();
if (weights != NULL) {
KALDI_ASSERT(weights->Dim() == num_frames);
}
for (int32 i = 0; i < num_frames; i++) {
// 在计算cmvn的时候我们是在一个utt中逐帧计算,这里的frames就是特征的行数,就是多少个语音帧
SubVector<BaseFloat> this_frame = feats.Row(i);
BaseFloat weight = (weights == NULL ? 1.0 : (*weights)(i));
if (weight != 0.0)
// weight默认会是1
AccCmvnStats(this_frame, weight, stats); //这里调用的是上边的同名函数
}
}