自顶向下分析一个简单的语音识别系统(九)

前面几回,我们分析完了run_model函数的configuration过程以及数据的输入输出向量的生成,本回我们继续分析一下接下来具体的训练过程。

1.run_training_epochs函数

训练主要是通过这个函数实现的,代码如下所示:

    def run_training_epochs(self):
        train_start = time.time()
        for epoch in range(self.epochs):
            # Initialize variables that can be updated
            #配置信息中读入self.epochs=2
            save_dev_model = False
            stop_training = False
            is_checkpoint_step, is_validation_step = \
                self.validation_and_checkpoint_check(epoch)

            epoch_start = time.time()

            self.train_cost, self.train_ler = self.run_batches(
                self.data_sets.train,
                is_training=True,
                decode=False,
                write_to_file=False,
                epoch=epoch)

            epoch_duration = time.time() - epoch_start

            log = 'Epoch {}/{}, train_cost: {:.3f}, train_ler: {:.3f}, time: {:.2f} sec'
            logger.info(log.format(
                epoch + 1,
                self.epochs,
                self.train_cost,
                self.train_ler,
                epoch_duration))

            summary_line = self.sess.run(
                self.train_ler_op, {self.ler_placeholder: self.train_ler})
            self.writer.add_summary(summary_line, epoch)

            summary_line = self.sess.run(
                self.train_cost_op, {self.cost_placeholder: self.train_cost})
            self.writer.add_summary(summary_line, epoch)

            # Shuffle the data for the next epoch
            if self.shuffle_data_after_epoch:
                np.random.shuffle(self.data_sets.train._txt_files)

            # Run validation if it was determined to run a validation step
            if is_validation_step:
                self.run_validation_step(epoch)

            if (epoch + 1) == self.epochs or is_checkpoint_step:
                # save the final model
                save_path = self.saver.save(self.sess, os.path.join(
                    self.SESSION_DIR, 'model.ckpt'), epoch)
                logger.info("Model saved: {}".format(save_path))

            if save_dev_model:
                # If the dev set is not improving,
                # the training is killed to prevent overfitting
                # And then save the best validation performance model
                save_path = self.saver.save(self.sess, os.path.join(
                    self.SESSION_DIR, 'model-best.ckpt'))
                logger.info(
                    "Model with best validation label error rate saved: {}".
                    format(save_path))

            if stop_training:
                break

        train_duration = time.time() - train_start
        logger.info('Training complete, total duration: {:.2f} min'.format(
            train_duration / 60))

第8-9行得到是否是check_step和validation_step;
第13-18行将data_sets.train数据给入run_batches函数中进行训练;
第30-32行调用sess.run进行计算;
第40行表示是否在一次训练之后,打乱训练数据;
第44行表示是否进行validation过程;
第46-60行表示保存训练模型参数;
可以看出,该函数的关键部分是run_batches函数,下面我们开始分析这个函数。

2.run_batches函数

    def run_batches(self, dataset, is_training, decode, write_to_file, epoch):
        n_examples = len(dataset._txt_files)

        n_batches_per_epoch = int(np.ceil(n_examples / dataset._batch_size))

        self.train_cost = 0
        self.train_ler = 0

        for batch in range(n_batches_per_epoch):
            # Get next batch of training data (audio features) and transcripts
            source, source_lengths, sparse_labels = dataset.next_batch()

            feed = {self.input_tensor: source,
                    self.targets: sparse_labels,
                    self.seq_length: source_lengths}

            # If the is_training is false, this means straight decoding without computing loss
            if is_training:
                # avg_loss is the loss_op, optimizer is the train_op;
                # running these pushes tensors (data) through graph
                batch_cost, _ = self.sess.run(
                    [self.avg_loss, self.optimizer], feed)
                self.train_cost += batch_cost * dataset._batch_size
                logger.debug('Batch cost: %.2f | Train cost: %.2f',
                             batch_cost, self.train_cost)

            self.train_ler += self.sess.run(self.ler, feed_dict=feed) * dataset._batch_size
            logger.debug('Label error rate: %.2f', self.train_ler)

            # Turn on decode only 1 batch per epoch
            if decode and batch == 0:
                d = self.sess.run(self.decoded[0], feed_dict={
                    self.input_tensor: source,
                    self.targets: sparse_labels,
                    self.seq_length: source_lengths}
                )
                dense_decoded = tf.sparse_tensor_to_dense(
                    d, default_value=-1).eval(session=self.sess)
                dense_labels = sparse_tuple_to_texts(sparse_labels)

                # only print a set number of example translations
                counter = 0
                counter_max = 4
                if counter < counter_max:
                    for orig, decoded_arr in zip(dense_labels, dense_decoded):
                        # convert to strings
                        decoded_str = ndarray_to_text(decoded_arr)
                        logger.info('Batch {}, file {}'.format(batch, counter))
                        logger.info('Original: {}'.format(orig))
                        logger.info('Decoded:  {}'.format(decoded_str))
                        counter += 1

                # save out variables for testing
                self.dense_decoded = dense_decoded
                self.dense_labels = dense_labels

        # Metrics mean
        if is_training:
            self.train_cost /= n_examples
        self.train_ler /= n_examples

        # Populate summary for histograms and distributions in tensorboard
        self.accuracy, summary_line = self.sess.run(
            [self.avg_loss, self.summary_op], feed)
        self.writer.add_summary(summary_line, epoch)

        return self.train_cost, self.train_ler

第13-15行表示sess.run时指定的feed_dict;
第18-25行表示训练并得到相应的cost;
第31-55行表示decode获得的输出序列。

3.validation_and_checkpoint_check函数

前面提到该函数是为了得到存储模型和验证模型的时间点,具体代码如下:

    def validation_and_checkpoint_check(self, epoch):
        # initially set at False unless indicated to change
        is_checkpoint_step = False
        is_validation_step = False

        # Check if the current epoch is a validation or checkpoint step
        if (epoch > 0) and ((epoch + 1) != self.epochs):
            if (epoch + 1) % self.SAVE_MODEL_EPOCH_NUM == 0:
                is_checkpoint_step = True
            if (epoch + 1) % self.VALIDATION_EPOCH_NUM == 0:
                is_validation_step = True

        return is_checkpoint_step, is_validation_step

SAVE_MODEL_EPOCH_NUM和VALIDATION_EPOCH_NUM均在配置文件中配置,该函数保证在固定的周期对网络模型进行存储和验证。

4.run_validation_step函数

上面可以看出在对模型进行一定次数的训练之后,我们可以调用run_validation_step函数对模型进行验证,具体代码如下:

    def run_validation_step(self, epoch):
        dev_ler = 0

        _, dev_ler = self.run_batches(self.data_sets.dev,
                                      is_training=False,
                                      decode=True,
                                      write_to_file=False,
                                      epoch=epoch)

        logger.info('Validation Label Error Rate: {}'.format(dev_ler))

        summary_line = self.sess.run(
            self.dev_ler_op, {self.ler_placeholder: dev_ler})
        self.writer.add_summary(summary_line, epoch)

        if dev_ler < self.min_dev_ler:
            self.min_dev_ler = dev_ler

        # average historical LER
        history_avg_ler = np.mean(self.AVG_VALIDATION_LERS)

        # if this LER is not better than average of previous epochs, exit
        if history_avg_ler - dev_ler <= self.CURR_VALIDATION_LER_DIFF:
            log = "Validation label error rate not improved by more than {:.2%} \
                  after {} epochs. Exit"
            warnings.warn(log.format(self.CURR_VALIDATION_LER_DIFF,
                                     self.AVG_VALIDATION_LER_EPOCHS))

        # save avg validation accuracy in the next slot
        self.AVG_VALIDATION_LERS[
            epoch % self.AVG_VALIDATION_LER_EPOCHS] = dev_ler

由上面代码可以看出验证主要使用self.data_sets.dev中数据,如果验证错误率不比前面的平均错误率高的话,给出相关的warning。
至此,整个训练过程的代码我们都分析完了,还有疑问的是对输出向量decode的时候调用的sparse_tuple_to_texts函数和ndarray_to_text函数还没有分析。我们留待下回细细分解。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值