In the previous installments we finished analyzing the configuration stage of the run_model function and the construction of the input and output vectors. In this installment we continue with the actual training process that follows.
1. The run_training_epochs function
Training is driven mainly by this function, shown below:
def run_training_epochs(self):
    train_start = time.time()
    for epoch in range(self.epochs):
        # Initialize variables that can be updated
        # self.epochs is read from the configuration file (2 in this walkthrough)
        save_dev_model = False
        stop_training = False
        is_checkpoint_step, is_validation_step = \
            self.validation_and_checkpoint_check(epoch)
        epoch_start = time.time()
        self.train_cost, self.train_ler = self.run_batches(
            self.data_sets.train,
            is_training=True,
            decode=False,
            write_to_file=False,
            epoch=epoch)
        epoch_duration = time.time() - epoch_start
        log = 'Epoch {}/{}, train_cost: {:.3f}, train_ler: {:.3f}, time: {:.2f} sec'
        logger.info(log.format(
            epoch + 1,
            self.epochs,
            self.train_cost,
            self.train_ler,
            epoch_duration))
        summary_line = self.sess.run(
            self.train_ler_op, {self.ler_placeholder: self.train_ler})
        self.writer.add_summary(summary_line, epoch)
        summary_line = self.sess.run(
            self.train_cost_op, {self.cost_placeholder: self.train_cost})
        self.writer.add_summary(summary_line, epoch)
        # Shuffle the data for the next epoch
        if self.shuffle_data_after_epoch:
            np.random.shuffle(self.data_sets.train._txt_files)
        # Run validation if it was determined to run a validation step
        if is_validation_step:
            self.run_validation_step(epoch)
        if (epoch + 1) == self.epochs or is_checkpoint_step:
            # save the final model
            save_path = self.saver.save(self.sess, os.path.join(
                self.SESSION_DIR, 'model.ckpt'), epoch)
            logger.info("Model saved: {}".format(save_path))
        if save_dev_model:
            # If the dev set is not improving,
            # the training is killed to prevent overfitting
            # And then save the best validation performance model
            save_path = self.saver.save(self.sess, os.path.join(
                self.SESSION_DIR, 'model-best.ckpt'))
            logger.info(
                "Model with best validation label error rate saved: {}".
                format(save_path))
        if stop_training:
            break
    train_duration = time.time() - train_start
    logger.info('Training complete, total duration: {:.2f} min'.format(
        train_duration / 60))
The call to validation_and_checkpoint_check determines whether the current epoch is a checkpoint step and/or a validation step;
the run_batches call feeds the data_sets.train data in for one epoch of training;
the two sess.run calls push the epoch's label error rate and cost through placeholders into TensorBoard summaries (see the sketch after this list);
shuffle_data_after_epoch controls whether the training files are shuffled after each epoch;
is_validation_step controls whether a validation pass is run;
the remaining branches save the model parameters: model.ckpt at the final epoch or at a checkpoint step, and model-best.ckpt when save_dev_model is set.
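For readers wondering where train_ler_op, train_cost_op, and the placeholders come from: they are created during the configuration stage covered in earlier installments. The following is only a minimal sketch of how such placeholder-fed scalar summaries are typically wired up in TF 1.x; the op and placeholder names are taken from their usage above, while the log directory and values are illustrative:

import tensorflow as tf

# Placeholders let Python-side metrics (computed outside the graph)
# be logged as TensorBoard scalar summaries.
ler_placeholder = tf.placeholder(tf.float32, shape=[], name='ler')
cost_placeholder = tf.placeholder(tf.float32, shape=[], name='cost')
train_ler_op = tf.summary.scalar('train_label_error_rate', ler_placeholder)
train_cost_op = tf.summary.scalar('train_cost', cost_placeholder)

sess = tf.Session()
writer = tf.summary.FileWriter('/tmp/session_dir', graph=sess.graph)  # path assumed

# Per epoch: feed the metric through the placeholder and write the summary.
summary_line = sess.run(train_ler_op, {ler_placeholder: 0.42})
writer.add_summary(summary_line, global_step=0)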
Clearly, the key part of this function is run_batches, so we analyze it next.
2. The run_batches function
def run_batches(self, dataset, is_training, decode, write_to_file, epoch):
    n_examples = len(dataset._txt_files)
    n_batches_per_epoch = int(np.ceil(n_examples / dataset._batch_size))
    self.train_cost = 0
    self.train_ler = 0
    for batch in range(n_batches_per_epoch):
        # Get next batch of training data (audio features) and transcripts
        source, source_lengths, sparse_labels = dataset.next_batch()
        feed = {self.input_tensor: source,
                self.targets: sparse_labels,
                self.seq_length: source_lengths}
        # If is_training is false, this means straight decoding without computing loss
        if is_training:
            # avg_loss is the loss_op, optimizer is the train_op;
            # running these pushes tensors (data) through graph
            batch_cost, _ = self.sess.run(
                [self.avg_loss, self.optimizer], feed)
            self.train_cost += batch_cost * dataset._batch_size
            logger.debug('Batch cost: %.2f | Train cost: %.2f',
                         batch_cost, self.train_cost)
            self.train_ler += self.sess.run(
                self.ler, feed_dict=feed) * dataset._batch_size
            logger.debug('Label error rate: %.2f', self.train_ler)
        # Turn on decode only 1 batch per epoch
        if decode and batch == 0:
            d = self.sess.run(self.decoded[0], feed_dict={
                self.input_tensor: source,
                self.targets: sparse_labels,
                self.seq_length: source_lengths}
            )
            dense_decoded = tf.sparse_tensor_to_dense(
                d, default_value=-1).eval(session=self.sess)
            dense_labels = sparse_tuple_to_texts(sparse_labels)
            # only print a set number of example translations
            counter = 0
            counter_max = 4
            for orig, decoded_arr in zip(dense_labels, dense_decoded):
                if counter >= counter_max:
                    break
                # convert to strings
                decoded_str = ndarray_to_text(decoded_arr)
                logger.info('Batch {}, file {}'.format(batch, counter))
                logger.info('Original: {}'.format(orig))
                logger.info('Decoded: {}'.format(decoded_str))
                counter += 1
            # save out variables for testing
            self.dense_decoded = dense_decoded
            self.dense_labels = dense_labels
    # Metrics mean
    if is_training:
        self.train_cost /= n_examples
        self.train_ler /= n_examples
    # Populate summary for histograms and distributions in tensorboard
    self.accuracy, summary_line = self.sess.run(
        [self.avg_loss, self.summary_op], feed)
    self.writer.add_summary(summary_line, epoch)
    return self.train_cost, self.train_ler
The feed dictionary maps the batch's audio features (source), sparse transcript labels (sparse_labels), and sequence lengths (source_lengths) onto the graph's input placeholders (see the sketch after this list);
when is_training is true, running avg_loss together with the optimizer trains on the batch and accumulates the cost and label error rate, each weighted by batch size;
when decode is true, only the first batch of each epoch is decoded, converted to dense form with tf.sparse_tensor_to_dense, and a few example transcriptions are logged.
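A detail worth pausing on: sparse_labels is fed to self.targets as a plain (indices, values, shape) tuple. This works because CTC-style targets are normally declared with tf.sparse_placeholder, which accepts exactly such a tuple in feed_dict. Below is a self-contained sketch of that mechanism with made-up label values, not the project's actual data pipeline:

import numpy as np
import tensorflow as tf

# A sparse placeholder stands in for self.targets.
targets = tf.sparse_placeholder(tf.int32)

# Two transcripts of lengths 3 and 2, encoded as (indices, values, shape).
indices = np.array([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], dtype=np.int64)
values = np.array([7, 4, 11, 7, 8], dtype=np.int32)  # per-character label indices
shape = np.array([2, 3], dtype=np.int64)             # [batch_size, max_label_length]

with tf.Session() as sess:
    # The same conversion run_batches applies to the decoder output.
    dense = sess.run(tf.sparse_tensor_to_dense(targets, default_value=-1),
                     feed_dict={targets: (indices, values, shape)})
    print(dense)  # [[ 7  4 11] [ 7  8 -1]]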
3. The validation_and_checkpoint_check function
As mentioned above, this function determines the points at which the model is saved and validated:
def validation_and_checkpoint_check(self, epoch):
    # initially set at False unless indicated to change
    is_checkpoint_step = False
    is_validation_step = False
    # Check if the current epoch is a validation or checkpoint step
    if (epoch > 0) and ((epoch + 1) != self.epochs):
        if (epoch + 1) % self.SAVE_MODEL_EPOCH_NUM == 0:
            is_checkpoint_step = True
        if (epoch + 1) % self.VALIDATION_EPOCH_NUM == 0:
            is_validation_step = True
    return is_checkpoint_step, is_validation_step
Both SAVE_MODEL_EPOCH_NUM and VALIDATION_EPOCH_NUM are set in the configuration file; the function guarantees that the model is saved and validated at fixed epoch intervals.
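To make the schedule concrete, the standalone sketch below replicates the flag logic; SAVE_MODEL_EPOCH_NUM = 10, VALIDATION_EPOCH_NUM = 5, and 30 epochs are assumed values for illustration, not the project's defaults:

def schedule(epochs, save_every, validate_every):
    # Mirrors validation_and_checkpoint_check for every epoch index.
    for epoch in range(epochs):
        is_checkpoint = is_validation = False
        if epoch > 0 and (epoch + 1) != epochs:
            is_checkpoint = (epoch + 1) % save_every == 0
            is_validation = (epoch + 1) % validate_every == 0
        yield epoch + 1, is_checkpoint, is_validation

# Checkpoints land on epochs 10 and 20; validations on 5, 10, 15, 20, 25.
# The final epoch (30) sets neither flag: run_training_epochs saves it
# through its (epoch + 1) == self.epochs branch instead.
for epoch, ckpt, val in schedule(30, 10, 5):
    if ckpt or val:
        print(epoch, ckpt, val)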
4. The run_validation_step function
As seen above, after a certain number of training epochs we call run_validation_step to validate the model:
def run_validation_step(self, epoch):
    dev_ler = 0
    _, dev_ler = self.run_batches(self.data_sets.dev,
                                  is_training=False,
                                  decode=True,
                                  write_to_file=False,
                                  epoch=epoch)
    logger.info('Validation Label Error Rate: {}'.format(dev_ler))
    summary_line = self.sess.run(
        self.dev_ler_op, {self.ler_placeholder: dev_ler})
    self.writer.add_summary(summary_line, epoch)
    if dev_ler < self.min_dev_ler:
        self.min_dev_ler = dev_ler
    # average historical LER
    history_avg_ler = np.mean(self.AVG_VALIDATION_LERS)
    # if this LER is not better than average of previous epochs, exit
    if history_avg_ler - dev_ler <= self.CURR_VALIDATION_LER_DIFF:
        log = "Validation label error rate not improved by more than {:.2%} \
              after {} epochs. Exit"
        warnings.warn(log.format(self.CURR_VALIDATION_LER_DIFF,
                                 self.AVG_VALIDATION_LER_EPOCHS))
    # save avg validation accuracy in the next slot
    self.AVG_VALIDATION_LERS[
        epoch % self.AVG_VALIDATION_LER_EPOCHS] = dev_ler
As the code shows, validation runs on the data in self.data_sets.dev. If the new validation label error rate fails to beat the mean of the previous validation LERs by more than CURR_VALIDATION_LER_DIFF, a warning is issued (despite the "Exit" in the message, the code shown only warns); the new LER then overwrites the oldest slot of the AVG_VALIDATION_LERS ring buffer.
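To see the stopping criterion in action, here is a numeric walk-through of the ring-buffer logic; the buffer length and threshold are assumed values standing in for AVG_VALIDATION_LER_EPOCHS and CURR_VALIDATION_LER_DIFF, which actually come from the configuration file:

import numpy as np

AVG_VALIDATION_LER_EPOCHS = 3     # assumed buffer length
CURR_VALIDATION_LER_DIFF = 0.005  # assumed improvement threshold
avg_validation_lers = np.ones(AVG_VALIDATION_LER_EPOCHS)  # initialized high

for epoch, dev_ler in enumerate([0.80, 0.70, 0.65, 0.72]):
    history_avg = np.mean(avg_validation_lers)
    # Warn when the new LER fails to beat the historical mean by the threshold.
    if history_avg - dev_ler <= CURR_VALIDATION_LER_DIFF:
        print('epoch {}: LER {:.2f} not improving enough, would warn'
              .format(epoch, dev_ler))
    avg_validation_lers[epoch % AVG_VALIDATION_LER_EPOCHS] = dev_ler

# Only the final epoch triggers the warning: 0.72 is worse than the
# running mean (about 0.717) of the three previous validation LERs.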
With that, we have analyzed all of the code for the training process. What remains open is the pair of helpers used when decoding the output vectors, sparse_tuple_to_texts and ndarray_to_text, which we have not yet examined. We leave those for the next installment.