In the previous post we briefly analyzed the overall structure of the Tf_train_ctc class. In this post we look at how its __init__ function performs initialization.
1. The __init__ function
The code is as follows:
def __init__(self,
             config_file='neural_network.ini',
             model_name=None,
             debug=False):
    # set TF logging verbosity
    tf.logging.set_verbosity(tf.logging.INFO)

    # Load the configuration file depending on debug True/False
    self.debug = debug
    self.conf_path = get_conf_dir(debug=self.debug)
    self.conf_path = os.path.join(self.conf_path, config_file)
    self.load_configs()

    # Verify that the GPU is operational, if not use CPU
    if not gpu_tool.check_if_gpu_available(self.tf_device):
        self.tf_device = '/cpu:0'
    logger.info('Using this device for main computations: %s', self.tf_device)

    # set the directories
    self.set_up_directories(model_name)

    # set up the model
    self.set_up_model()
Here,
the get_conf_dir function lives in set_dirs.py and returns the directory containing neural_network.ini; we will analyze it in the next post;
the check_if_gpu_available function lives in gpu.py and checks whether the system has a GPU suitable for computation.
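The device fallback in __init__ reduces to a simple selection: keep the configured device only when a usable GPU was detected. A minimal sketch (pick_device is a hypothetical helper name; the real check is done by gpu_tool.check_if_gpu_available):

```python
# Hypothetical helper mirroring the fallback logic in __init__:
# keep the requested device when a GPU is available, else use the CPU.
def pick_device(requested_device, gpu_available):
    return requested_device if gpu_available else '/cpu:0'
```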
2. The load_configs function
As mentioned above, get_conf_dir returns the directory containing neural_network.ini. This function uses that path to read the parameters in the [nn] section of neural_network.ini and configure the parameters used during training. The code is as follows:
def load_configs(self):
    parser = ConfigParser(os.environ)
    if not os.path.exists(self.conf_path):
        raise IOError("Configuration file '%s' does not exist" % self.conf_path)
    logger.info('Loading config from %s', self.conf_path)
    parser.read(self.conf_path)

    # set which set of configs to import
    config_header = 'nn'
    logger.info('config header: %s', config_header)

    self.epochs = parser.getint(config_header, 'epochs')
    logger.debug('self.epochs = %d', self.epochs)

    self.network_type = parser.get(config_header, 'network_type')

    # Number of mfcc features, 13 or 26
    self.n_input = parser.getint(config_header, 'n_input')

    # Number of contextual samples to include
    self.n_context = parser.getint(config_header, 'n_context')

    # self.decode_train = parser.getboolean(config_header, 'decode_train')
    # self.random_seed = parser.getint(config_header, 'random_seed')
    self.model_dir = parser.get(config_header, 'model_dir')

    # set the session name
    self.session_name = '{}_{}'.format(
        self.network_type, time.strftime("%Y%m%d-%H%M%S"))
    sess_prefix_str = 'develop'
    if len(sess_prefix_str) > 0:
        self.session_name = '{}_{}'.format(
            sess_prefix_str, self.session_name)

    # How often to save the model
    self.SAVE_MODEL_EPOCH_NUM = parser.getint(
        config_header, 'SAVE_MODEL_EPOCH_NUM')

    # decode dev set after N epochs
    self.VALIDATION_EPOCH_NUM = parser.getint(
        config_header, 'VALIDATION_EPOCH_NUM')

    # decide when to stop training prematurely
    self.CURR_VALIDATION_LER_DIFF = parser.getfloat(
        config_header, 'CURR_VALIDATION_LER_DIFF')

    self.AVG_VALIDATION_LER_EPOCHS = parser.getint(
        config_header, 'AVG_VALIDATION_LER_EPOCHS')
    # initialize list to hold average validation at end of each epoch
    self.AVG_VALIDATION_LERS = [
        1.0 for _ in range(self.AVG_VALIDATION_LER_EPOCHS)]

    # setup type of decoder
    self.beam_search_decoder = parser.get(
        config_header, 'beam_search_decoder')

    # determine if the data input order should be shuffled after every epoch
    self.shuffle_data_after_epoch = parser.getboolean(
        config_header, 'shuffle_data_after_epoch')

    # initialize to store the minimum validation set label error rate
    self.min_dev_ler = parser.getfloat(config_header, 'min_dev_ler')

    # set up GPU if available
    self.tf_device = str(parser.get(config_header, 'tf_device'))

    # set up the max amount of simultaneous users
    # this restricts GPU usage to the inverse of self.simultaneous_users_count
    self.simultaneous_users_count = parser.getint(
        config_header, 'simultaneous_users_count')
The parameters in neural_network.ini are shown in the figure below.
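To make the parsing concrete, here is a self-contained sketch of how load_configs reads such a file. The ini contents below are illustrative, not the repository's actual defaults; load_configs constructs ConfigParser(os.environ) so that ini values can interpolate environment variables, while a plain ConfigParser behaves the same when no interpolation is needed:

```python
import os
import tempfile
from configparser import ConfigParser

# A minimal, hypothetical neural_network.ini with only a few [nn] keys.
ini_text = """[nn]
epochs = 100
network_type = BiRNN
n_input = 26
n_context = 9
shuffle_data_after_epoch = True
"""

conf_path = os.path.join(tempfile.mkdtemp(), 'neural_network.ini')
with open(conf_path, 'w') as f:
    f.write(ini_text)

parser = ConfigParser()
parser.read(conf_path)
epochs = parser.getint('nn', 'epochs')
network_type = parser.get('nn', 'network_type')
shuffle = parser.getboolean('nn', 'shuffle_data_after_epoch')
```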
3. The set_up_directories function
This function sets up the directories where the session and the summary files are stored. The code is as follows:
def set_up_directories(self, model_name):
    # Set up model directory
    self.model_dir = os.path.join(get_model_dir(), self.model_dir)
    # self.model_dir was set to nn/debug_models in load_configs(), so here
    # it becomes RNN-Tutorial/models/nn/debug_models

    # summary will contain logs
    self.SUMMARY_DIR = os.path.join(
        self.model_dir, "summary", self.session_name)

    # session will contain models
    self.SESSION_DIR = os.path.join(
        self.model_dir, "session", self.session_name)
    # self.session_name was set in load_configs() to a string of the form
    # develop_BiRNN_<current time>

    if not os.path.exists(self.SESSION_DIR):
        os.makedirs(self.SESSION_DIR)
    if not os.path.exists(self.SUMMARY_DIR):
        os.makedirs(self.SUMMARY_DIR)

    # set the model name and restore if not None
    if model_name is not None:
        self.model_path = os.path.join(self.SESSION_DIR, model_name)
    else:
        self.model_path = None
        # with the default settings, self.model_path is None
Here, the get_model_dir function lives in set_dirs.py and returns the RNN-Tutorial/models directory.
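The path composition above can be sketched in isolation. In this illustrative snippet, model_root stands in for RNN-Tutorial/models/nn/debug_models (resolved at runtime via get_model_dir), and the session name follows the develop_BiRNN_<timestamp> pattern built in load_configs:

```python
import os
import tempfile
import time

# model_root is a stand-in for the real model directory.
model_root = tempfile.mkdtemp()
session_name = '{}_{}_{}'.format('develop', 'BiRNN',
                                 time.strftime('%Y%m%d-%H%M%S'))

summary_dir = os.path.join(model_root, 'summary', session_name)  # logs
session_dir = os.path.join(model_root, 'session', session_name)  # models
for d in (session_dir, summary_dir):
    if not os.path.exists(d):
        os.makedirs(d)
```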
4. The set_up_model function
This function reads the data sets back from the data directory. The code is as follows:
def set_up_model(self):
    self.sets = ['train', 'dev', 'test']

    # read data set, inherits configuration path
    # to parse the config file for where data lives
    self.data_sets = read_datasets(self.conf_path,
                                   self.sets,
                                   self.n_input,
                                   self.n_context)

    self.n_examples_train = len(self.data_sets.train._txt_files)
    self.n_examples_dev = len(self.data_sets.dev._txt_files)
    self.n_examples_test = len(self.data_sets.test._txt_files)
    self.batch_size = self.data_sets.train._batch_size
    self.n_batches_per_epoch = int(np.ceil(
        self.n_examples_train / self.batch_size))

    logger.info('''Training model: {}
    Train examples: {:,}
    Dev examples: {:,}
    Test examples: {:,}
    Epochs: {}
    Training batch size: {}
    Batches per epoch: {}'''.format(
        self.session_name,
        self.n_examples_train,
        self.n_examples_dev,
        self.n_examples_test,
        self.epochs,
        self.batch_size,
        self.n_batches_per_epoch))
Here, the read_datasets function lives in datasets.py; it reads the data under the data directory and stores it in a data_sets object. We will analyze this function in detail later.
This completes the initialization of the speech-recognition model. In the next post we will analyze set_dirs.py, which was used repeatedly here.
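The batches-per-epoch arithmetic in set_up_model is a simple ceiling division. A sketch with illustrative counts (the original uses np.ceil, but math.ceil gives the same result for these positive numbers):

```python
import math

n_examples_train = 250  # hypothetical number of training examples
batch_size = 32
# a partial final batch still counts as one batch, hence the ceiling
n_batches_per_epoch = int(math.ceil(n_examples_train / batch_size))
```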