自顶向下分析一个简单的语音识别系统(五)

本回我们主要分析run_model中的configuration过程的相关函数。

1.run_model函数

第二回我们简单介绍了run_model函数的结构,现在我们贴出代码如下所示:

    def run_model(self):
        self.graph = tf.Graph()
        with self.graph.as_default(), tf.device('/cpu:0'):

            with tf.device(self.tf_device):
                # Run multiple functions on the specificed tf_device
                # tf_device GPU set in configs, but is overridden if not available
                # __init__函数中调用gpu_tool.check_if_gpu_available函数,如果设备中有gpu,则self.tf_device=/gpu:0
                self.setup_network_and_graph()
                self.load_placeholder_into_network()
                self.setup_loss_function()
                self.setup_optimizer()
                self.setup_decoder()

            self.setup_summary_statistics()

            # create the configuration for the session
            tf_config = tf.ConfigProto()
            tf_config.allow_soft_placement = True
            tf_config.gpu_options.per_process_gpu_memory_fraction = \
                (1.0 / self.simultaneous_users_count)
            #设置gpu中的内存最大占用率,self.simultaneous_users_count=4

            # create the session
            self.sess = tf.Session(config=tf_config)

            # initialize the summary writer
            self.writer = tf.summary.FileWriter(
                self.SUMMARY_DIR, graph=self.sess.graph)

            # Add ops to save and restore all the variables
            self.saver = tf.train.Saver()

            # For printing out section headers
            section = '\n{0:=^40}\n'

            # If there is a model_path declared, then restore the model
            #前述self.model_path=None
            if self.model_path is not None:
                self.saver.restore(self.sess, self.model_path)
            # If there is NOT a model_path declared, build the model from scratch
            else:
                # Op to initialize the variables
                init_op = tf.global_variables_initializer()

                # Initializate the weights and biases
                self.sess.run(init_op)

                # MAIN LOGIC for running the training epochs
                logger.info(section.format('Run training epoch'))
                self.run_training_epochs()

            logger.info(section.format('Decoding test data'))
            # make the assumption for working on the test data, that the epoch here is the last epoch
            _, self.test_ler = self.run_batches(self.data_sets.test, is_training=False,
                                                decode=True, write_to_file=False, epoch=self.epochs)

            # Add the final test data to the summary writer
            # (single point on the graph for end of training run)
            summary_line = self.sess.run(
                self.test_ler_op, {self.ler_placeholder: self.test_ler})
            self.writer.add_summary(summary_line, self.epochs)

            logger.info('Test Label Error Rate: {}'.format(self.test_ler))

            # save train summaries to disk
            self.writer.flush()

            self.sess.close()

2.setup_network_and_graph函数

本函数主要定义网络模型的输入输出placeholder,代码如下:

    def setup_network_and_graph(self):
        # e.g: log filter bank or MFCC features
        # shape = [batch_size, max_stepsize, n_input + (2 * n_input * n_context)]
        # the batch_size and max_stepsize can vary along each step
        self.input_tensor = tf.placeholder(
            tf.float32, [None, None, self.n_input + (2 * self.n_input * self.n_context)], name='input')

        # Use sparse_placeholder; will generate a SparseTensor, required by ctc_loss op.
        self.targets = tf.sparse_placeholder(tf.int32, name='targets')
        # 1d array of size [batch_size]
        self.seq_length = tf.placeholder(tf.int32, [None], name='seq_length')

其中,n_input=26表示MFCC倒谱系数为26位,n_context=9表示当前25ms声音片段往前和往后分别9个声音片段做输入。MFCC将在后面详细分析。

3.load_placeholder_into_network函数

该函数调用rnn.py中的SimpleLSTM/BiRNN函数构建网路的基本结构,代码如下:

    def load_placeholder_into_network(self):
        # logits is the non-normalized output/activations from the last layer.
        # logits will be input for the loss function.
        # nn_model is from the import statement in the load_model function
        # summary_op variables are for tensorboard
        if self.network_type == 'SimpleLSTM':
            self.logits, summary_op = SimpleLSTM_model(
                self.conf_path,
                self.input_tensor,
                tf.to_int64(self.seq_length)
            )
        elif self.network_type == 'BiRNN':
            self.logits, summary_op = BiRNN_model(
                self.conf_path,
                self.input_tensor,
                tf.to_int64(self.seq_length),
                self.n_input,
                self.n_context
            )
        else:
            raise ValueError('network_type must be SimpleLSTM or BiRNN')
        self.summary_op = tf.summary.merge([summary_op])

由前面可知,neural_network.ini中network_type=BIRNN,下回我们将详细分析该网络。

4.setup_loss_function函数

本函数设置语音识别模型的loss函数为ctc_loss,代码如下:

    def setup_loss_function(self):
        with tf.name_scope("loss"):
            self.total_loss = ctc_ops.ctc_loss(
                self.targets, self.logits, self.seq_length)
            self.avg_loss = tf.reduce_mean(self.total_loss)
            self.loss_summary = tf.summary.scalar("avg_loss", self.avg_loss)

            self.cost_placeholder = tf.placeholder(dtype=tf.float32, shape=[])

            self.train_cost_op = tf.summary.scalar(
                "train_avg_loss", self.cost_placeholder)

后面将详细分析CTC损失函数。

5.setup_optimizer函数

本函数调用utils.py中的create_optimizer函数,使用AdamOptimizer对网络进行优化,代码如下:

    def setup_optimizer(self):
        # Note: The optimizer is created in models/RNN/utils.py
        with tf.name_scope("train"):
            self.optimizer = create_optimizer()
            self.optimizer = self.optimizer.minimize(self.avg_loss)

6.setup_decoder函数

本函数使用ctc中的两种策略对输出结果进行解码,代码如下:

    def setup_decoder(self):
        with tf.name_scope("decode"):
            if self.beam_search_decoder == 'default':
                self.decoded, self.log_prob = ctc_ops.ctc_beam_search_decoder(
                    self.logits, self.seq_length, merge_repeated=False)
            elif self.beam_search_decoder == 'greedy':
                self.decoded, self.log_prob = ctc_ops.ctc_greedy_decoder(
                    self.logits, self.seq_length, merge_repeated=False)
            else:
                logging.warning("Invalid beam search decoder option selected!")

7.setup_summary_statistics函数

本函数主要用于设置运行过程中产生的summary的收集点,代码如下:

    def setup_summary_statistics(self):
        # Create a placholder for the summary statistics
        with tf.name_scope("accuracy"):
            # Compute the edit (Levenshtein) distance of the top path
            distance = tf.edit_distance(
                tf.cast(self.decoded[0], tf.int32), self.targets)

            # Compute the label error rate (accuracy)
            self.ler = tf.reduce_mean(distance, name='label_error_rate')
            self.ler_placeholder = tf.placeholder(dtype=tf.float32, shape=[])
            self.train_ler_op = tf.summary.scalar(
                "train_label_error_rate", self.ler_placeholder)
            self.dev_ler_op = tf.summary.scalar(
                "validation_label_error_rate", self.ler_placeholder)
            self.test_ler_op = tf.summary.scalar(
                "test_label_error_rate", self.ler_placeholder)

本回简要分析了网络的configuration过程,下回将仔细分析网络的基本结构。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值