自顶向下分析一个简单的语音识别系统(三)

上一回我们简要分析了Tf_train_ctc这个类的主要结构,本回我们主要讲Tf_train_ctc的__init__ 函数是如何初始化的。

1. __init__ 函数

代码如下:

    def __init__(self,
                 config_file='neural_network.ini',
                 model_name=None,
                 debug=False):
        # set TF logging verbosity
               tf.logging.set_verbosity(tf.logging.INFO)

        # Load the configuration file depending on debug True/False
        self.debug = debug
        self.conf_path = get_conf_dir(debug=self.debug)
        self.conf_path = os.path.join(self.conf_path, config_file)
        self.load_configs()

        # Verify that the GPU is operational, if not use CPU
        if not gpu_tool.check_if_gpu_available(self.tf_device):
            self.tf_device = '/cpu:0'
        logging.info('Using this device for main computations: %s', self.tf_device)

        # set the directories
        self.set_up_directories(model_name)

        # set up the model
        self.set_up_model()

其中,
get_conf_dir函数属于set_dirs.py,返回neural_network.ini所在目录,留待下回分析;
check_if_gpu_available函数属于gpu.py,用于查看系统中是否存在有适合用于计算的gpu。

2.load_configs函数

上面介绍到get_config_dir返回neural_network.ini所在目录,本函数利用该路径读取neural_network.ini中[nn]参数,配置训练时所用参数,具体代码如下:

    def load_configs(self):
        parser = ConfigParser(os.environ)
        if not os.path.exists(self.conf_path):
            raise IOError("Configuration file '%s' does not exist" % self.conf_path)
        logging.info('Loading config from %s', self.conf_path)
        parser.read(self.conf_path)

        # set which set of configs to import
        config_header = 'nn'

        logger.info('config header: %s', config_header)

        self.epochs = parser.getint(config_header, 'epochs')
        logger.debug('self.epochs = %d', self.epochs)

        self.network_type = parser.get(config_header, 'network_type')

        # Number of mfcc features, 13 or 26
        self.n_input = parser.getint(config_header, 'n_input')

        # Number of contextual samples to include
        self.n_context = parser.getint(config_header, 'n_context')

        # self.decode_train = parser.getboolean(config_header, 'decode_train')
        # self.random_seed = parser.getint(config_header, 'random_seed')
        self.model_dir = parser.get(config_header, 'model_dir')

        # set the session name
        self.session_name = '{}_{}'.format(
            self.network_type, time.strftime("%Y%m%d-%H%M%S"))
        sess_prefix_str = 'develop'
        if len(sess_prefix_str) > 0:
            self.session_name = '{}_{}'.format(
                sess_prefix_str, self.session_name)

        # How often to save the model
        self.SAVE_MODEL_EPOCH_NUM = parser.getint(
            config_header, 'SAVE_MODEL_EPOCH_NUM')

        # decode dev set after N epochs
        self.VALIDATION_EPOCH_NUM = parser.getint(
            config_header, 'VALIDATION_EPOCH_NUM')

        # decide when to stop training prematurely
        self.CURR_VALIDATION_LER_DIFF = parser.getfloat(
            config_header, 'CURR_VALIDATION_LER_DIFF')

        self.AVG_VALIDATION_LER_EPOCHS = parser.getint(
            config_header, 'AVG_VALIDATION_LER_EPOCHS')
        # initialize list to hold average validation at end of each epoch
        self.AVG_VALIDATION_LERS = [
            1.0 for _ in range(self.AVG_VALIDATION_LER_EPOCHS)]

        # setup type of decoder
        self.beam_search_decoder = parser.get(
            config_header, 'beam_search_decoder')

        # determine if the data input order should be shuffled after every epic
        self.shuffle_data_after_epoch = parser.getboolean(
            config_header, 'shuffle_data_after_epoch')

        # initialize to store the minimum validation set label error rate
        self.min_dev_ler = parser.getfloat(config_header, 'min_dev_ler')

        # set up GPU if available
        self.tf_device = str(parser.get(config_header, 'tf_device'))

        # set up the max amount of simultaneous users
        # this restricts GPU usage to the inverse of self.simultaneous_users_count
        self.simultaneous_users_count = parser.getint(config_header, 'simultaneous_users_count')

neural_network.ini参数如下图所示:
这里写图片描述

3.set_up_directories函数

本函数主要实现设置session和summary的存放目录,具体代码如下:

    def set_up_directories(self, model_name):
        # Set up model directory
        self.model_dir = os.path.join(get_model_dir(), self.model_dir)
        #self.model_dir在load_configs()中返回nn/debug_models目录,故此时它的值为RNN-Tutorial/models/nn/debug_models

        # summary will contain logs
        self.SUMMARY_DIR = os.path.join(
            self.model_dir, "summary", self.session_name)
        # session will contain models
        self.SESSION_DIR = os.path.join(
            self.model_dir, "session", self.session_name)
        #self.session_name在load_configs中被配置为develop_BiRNN_(当前时间)形式的字符串

        if not os.path.exists(self.SESSION_DIR):
            os.makedirs(self.SESSION_DIR)
        if not os.path.exists(self.SUMMARY_DIR):
            os.makedirs(self.SUMMARY_DIR)

        # set the model name and restore if not None
        if model_name is not None:
            self.model_path = os.path.join(self.SESSION_DIR, model_name)
        else:
            self.model_path = None
        #默认设置时self.model_path=None

其中,get_model_dir函数属于set_dirs.py,返回RNN-Tutorial/models目录。

4.set_up_model函数

该函数读回data中的数据,代码如下:

    def set_up_model(self):
        self.sets = ['train', 'dev', 'test']

        # read data set, inherits configuration path
        # to parse the config file for where data lives
        self.data_sets = read_datasets(self.conf_path,
                                       self.sets,
                                       self.n_input,
                                       self.n_context
                                       )

        self.n_examples_train = len(self.data_sets.train._txt_files)
        self.n_examples_dev = len(self.data_sets.dev._txt_files)
        self.n_examples_test = len(self.data_sets.test._txt_files)
        self.batch_size = self.data_sets.train._batch_size
        self.n_batches_per_epoch = int(np.ceil(
            self.n_examples_train / self.batch_size))

        logger.info('''Training model: {}
        Train examples: {:,}
        Dev examples: {:,}
        Test examples: {:,}
        Epochs: {}
        Training batch size: {}
        Batches per epoch: {}'''.format(
            self.session_name,
            self.n_examples_train,
            self.n_examples_dev,
            self.n_examples_test,
            self.epochs,
            self.batch_size,
            self.n_batches_per_epoch))

其中,read_datasets函数属于datasets.py,将data目录下的数据读出,并保存到data_sets object中。后面会详细分析该函数。
以上是整个语音识别模型的初始化过程,下回我们分析一下本回反复用到的set_dirs.py文件。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值