自顶向下分析一个简单的语音识别系统(七)

上回我们分析了系统网络的基本结构,那么我们的网络又是如何训练的呢?要回答这个问题,我们先得回答我们的数据是如何获得的,这回我们就来分析一下这个过程。

1.调用关系图

数据首先通过Tf_train_ctc类中的set_up_model函数调用datasets.py中的read_datasets函数返回一个DataSet类型的对象,该对象包含处理数据相关的配置信息。然后Tf_train_ctc类使用上一步返回的self.data_sets在run_batches函数中调用DataSet类中的next_batch函数返回一次训练所需的batch。具体调用关系图如下所示:

Created with Raphaël 2.1.0 Tf_train_ctc Tf_train_ctc DataSet DataSet read_datasets self.data_sets run_batches source, source_lengths, sparse_labels

下面开始结合代码进行详细分析。

2.read_datasets函数

该函数返回一个带有数据处理配置信息的DataSet类型对象,具体代码如下:

def read_datasets(conf_path, sets, numcep, numcontext,
                  thread_count=8):
    data_dir, dataset_config = _get_data_set_dict(conf_path, sets)

    def _read_data_set(config):
        path = os.path.join(data_dir, config['dir_pattern'])
        return DataSet.from_directory(path,
                                      thread_count=thread_count,
                                      batch_size=config['batch_size'],
                                      numcep=numcep,
                                      numcontext=numcontext,
                                      start_idx=config['start_idx'],
                                      limit=config['limit'],
                                      sort=config['sort']
                                      )
    datasets = {name: _read_data_set(dataset_config[name])
                      if name in sets else None
                for name in ('train', 'dev', 'test')}
                #sets传入的值为['train', 'dev', 'test']
    return DataSets(**datasets)

由上面代码可知,其主要调用了_get_data_set_dict函数和DataSet类中的from_directory函数。
_get_data_set_dict函数代码如下图所示:

def _get_data_set_dict(conf_path, sets):
    parser = ConfigParser(os.environ)
    parser.read(conf_path)
    config_header = 'data'
    data_dir = get_data_dir(parser.get(config_header, 'data_dir'))
    data_dict = {}

    if 'train' in sets:
        d = {}
        d['dir_pattern'] = parser.get(config_header, 'dir_pattern_train')
        d['limit'] = parser.getint(config_header, 'n_train_limit')
        d['sort'] = parser.get(config_header, 'sort_train')
        d['batch_size'] = parser.getint(config_header, 'batch_size_train')
        d['start_idx'] = parser.getint(config_header, 'start_idx_init_train')
        data_dict['train'] = d
        logging.debug('Training configuration: %s', str(d))

    if 'dev' in sets:
        d = {}
        d['dir_pattern'] = parser.get(config_header, 'dir_pattern_dev')
        d['limit'] = parser.getint(config_header, 'n_dev_limit')
        d['sort'] = parser.get(config_header, 'sort_dev')
        d['batch_size'] = parser.getint(config_header, 'batch_size_dev')
        d['start_idx'] = parser.getint(config_header, 'start_idx_init_dev')
        data_dict['dev'] = d
        logging.debug('Dev configuration: %s', str(d))

    if 'test' in sets:
        d = {}
        d['dir_pattern'] = parser.get(config_header, 'dir_pattern_test')
        d['limit'] = parser.getint(config_header, 'n_test_limit')
        d['sort'] = parser.get(config_header, 'sort_test')
        d['batch_size'] = parser.getint(config_header, 'batch_size_test')
        d['start_idx'] = parser.getint(config_header, 'start_idx_init_test')
        data_dict['test'] = d
        logging.debug('Test configuration: %s', str(d))

    return data_dir, data_dict

这段代码读入了neural_network.ini配置文件中[data]相关信息,返回了data_dict对象是一个包含有数据处理相关配置信息的dict类型对象,具体如下图所示:
这里写图片描述
得到配置信息之后,_read_data_set函数通过调用DataSet.from_directory函数获得最终的self.data_sets对象。现在我们来分析from_directory函数。

3.from_directory函数

DataSet类中起主要作用的函数为from_directory函数和next_batch函数。from_directory函数用来从data目录下构建之后需要用到的DataSet对象,具体代码如下:

    def from_directory(cls, dirpath, thread_count, batch_size, numcep, numcontext, start_idx=0, limit=0, sort=None):
        if not os.path.exists(dirpath):
            raise IOError("'%s' does not exist" % dirpath)
        txt_files = txt_filenames(dirpath, start_idx=start_idx, limit=limit, sort=sort)
        if len(txt_files) == 0:
            raise RuntimeError('start_idx=%d and limit=%d arguments result in zero files' % (start_idx, limit))
        return cls(txt_files, thread_count, batch_size, numcep, numcontext)

可以看出这段代码主要是调用txt_filenames返回数据文件的命名列表,具体代码如下:

def txt_filenames(dataset_path, start_idx=0, limit=None, sort='alpha'):
        # Obtain list of txt files
        txt_files = glob(os.path.join(dataset_path, "*.txt"))
        limit = limit or len(txt_files)

        # Optional: sort files to improve padding performance
        if sort not in SORTS:
            raise ValueError('sort must be one of [%s]', SORTS)
        reverse = False
        key = None
        if 'filesize' in sort:
            key = os.path.getsize
        if sort == 'filesize_high_low':
            reverse = True
        elif sort == 'random':
            key = lambda *args: random()
        txt_files = sorted(txt_files, key=key, reverse=reverse)

        return txt_files[start_idx:limit + start_idx]

由上面可以看出,这段代码调用定义好的Sort方法,返回长度为limit的txt_files列表(文件路径)。

4.next_batch函数

通过上面的分析我们可以知道,next_batch函数被Tf_train_ctc类中的run_batches函数调用,返回具体需要训练的数据batch。由于具体的训练过程还没有分析,我们现在只是局部介绍next_batch函数,看看该函数返回了怎样的数据,具体代码如下图所示:

    def next_batch(self, batch_size=None):
        if batch_size is None:
            batch_size = self._batch_size

        end_idx = min(len(self._txt_files), self._start_idx + batch_size)
        idx_list = range(self._start_idx, end_idx)
        txt_files = [self._txt_files[i] for i in idx_list]
        wav_files = [x.replace('.txt', '.wav') for x in txt_files]
        (source, _, target, _) = get_audio_and_transcript(txt_files,
                                                          wav_files,
                                                          self._numcep,
                                                          self._numcontext)

        self._start_idx += batch_size
        # Verify that the start_idx is not larger than total available sample size
        if self._start_idx >= self.size:
            self._start_idx = 0

        # Pad input to max_time_step of this batch
        source, source_lengths = pad_sequences(source)
        sparse_labels = sparse_tuple_from(target)
        return source, source_lengths, sparse_labels

由代码中可以看出每次取出batch_size个训练数据,同时要求每个训练数据的文本txt的名字与波形文件wav的名字保持相同。该函数主要调用get_audio_and_transcript函数、pad_sequences函数和sparse_tuple_from函数,这三个函数分别属于load_audio_to_mem.py和text.py中,负责具体获得相关的输入输出向量,留待下回分解。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值