菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(八)—— 模型训练-训练

系列目录:

  1. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(一)——数据
  2. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(二)——
    介绍及分词
  3. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(三)—— 预处理
  4. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(四)—— 段落抽取
  5. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(五)—— 准备数据
  6. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(六)—— 模型构建
  7. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(七)—— 模型训练-数据准备

未完待续 … …

上一篇文章对模型的训练时的数据准备进行了介绍,将数据彻底转换为模型输入的批次后,就可以调用模型输入数据进行训练了。

rc_model.train

由上篇文章介绍的run.py文件中的train函数可以看到,其在准备好数据之后,调用了rc_model.train函数对模型进行了训练,具体实现如下:

def train(self, data, epochs, batch_size, save_dir, save_prefix,
          dropout_keep_prob=1.0, evaluate=True):
    """
    使用数据训练模型
    参数:
        data: BRCDataset类
        epochs: 训练次数
        batch_size:批次大小
        save_dir: 模型保存目录
        save_prefix: 前缀代表了模型类别
        dropout_keep_prob: 代表dropout保留概率的浮点数
        evaluate: 每次训练后是否使用测试集评估模型
    """
    pad_id = self.vocab.get_id(self.vocab.pad_token)
    max_bleu_4 = 0
    # 训练次数循环
    for epoch in range(1, epochs + 1):
        # 准备数据,将其分割为批次
        self.logger.info('Training the model for epoch {}'.format(epoch))
        train_batches = data.gen_mini_batches('train', batch_size, pad_id, shuffle=True)
        # 训练一个周期
        train_loss = self._train_epoch(train_batches, dropout_keep_prob)
        self.logger.info('Average train loss for epoch {} is {}'.format(epoch, train_loss))
        # 验证
        if evaluate:
            self.logger.info('Evaluating the model after epoch {}'.format(epoch))
            if data.dev_set is not None:
                # 准备验证数据
                eval_batches = data.gen_mini_batches('dev', batch_size, pad_id, shuffle=False)
                # 验证
                eval_loss, bleu_rouge = self.evaluate(eval_batches)
                self.logger.info('Dev eval loss {}'.format(eval_loss))
                self.logger.info('Dev eval result: {}'.format(bleu_rouge))

                if bleu_rouge['Bleu-4'] > max_bleu_4:
                    self.save(save_dir, save_prefix)
                    max_bleu_4 = bleu_rouge['Bleu-4']
            else:
                self.logger.warning('No dev set is loaded for evaluation in the dataset!')
        else:
            self.save(save_dir, save_prefix + '_' + str(epoch))

由代码可知,rc_model.train函数对训练周期进行了迭代,并通过调用_train_epoch完成一个周期的训练,另外如果设置了对模型进行校验其将调用evaluate对模型进行验证并计算了校验损失、bleu_rouge分数。下面首先介绍下_train_epoch

rc_model._train_epoch

rc_model._train_epoch输入数据集的批次生成器、dropout保持概率,完成模型一个周期的训练,具体代码如下:

def _train_epoch(self, train_batches, dropout_keep_prob):
    """
    训练模型一个周期
    参数:
        train_batches: 可迭代的训练批次数据
        dropout_keep_prob: 表示dropout保留概率的浮点数
    """
    total_num, total_loss = 0, 0
    log_every_n_batch, n_batch_loss = 50, 0
    # 遍历所有批次
    for bitx, batch in enumerate(train_batches, 1):
        # 构造批次feed_dict
        feed_dict = {self.p: batch['passage_token_ids'],
                     self.q: batch['question_token_ids'],
                     self.p_length: batch['passage_length'],
                     self.q_length: batch['question_length'],
                     self.start_label: batch['start_id'],
                     self.end_label: batch['end_id'],
                     self.dropout_keep_prob: dropout_keep_prob}
        # 训练
        _, loss = self.sess.run([self.train_op, self.loss], feed_dict)
        # 计算损失并打印
        total_loss += loss * len(batch['raw_data'])
        total_num += len(batch['raw_data'])
        n_batch_loss += loss
        if log_every_n_batch > 0 and bitx % log_every_n_batch == 0:
            self.logger.info('Average loss from batch {} to {} is {}'.format(
                bitx - log_every_n_batch + 1, bitx, n_batch_loss / log_every_n_batch))
            n_batch_loss = 0
    return 1.0 * total_loss / total_num

由代码可以看到,rc_model._train_epoch函数构造了feed_dict并调用sess.run对模型进行了训练,最终计算并返回了训练损失。

rc_model.evaluate

rc_model.train中使用了evaluate函数对模型的效果进行验证,其具体代码如下:

def evaluate(self, eval_batches, result_dir=None, result_prefix=None, save_full_info=False):
    """
    在验证批次上对模型效果进行验证,如果设定将会保存结果
    Args:
        eval_batches: 可以迭代的批次数据
        result_dir: 保存预测答案的路径,如果为空则不保存结果
        result_prefix: 保存预测结果文件的前缀,如果为空则不保存
        save_full_info: 如果为真,预测结果将被添加到原始样本并保存
    """
    pred_answers, ref_answers = [], []
    total_loss, total_num = 0, 0
    # 遍历批次
    for b_itx, batch in enumerate(eval_batches):
        # 生成feed_dict
        feed_dict = {self.p: batch['passage_token_ids'],
                     self.q: batch['question_token_ids'],
                     self.p_length: batch['passage_length'],
                     self.q_length: batch['question_length'],
                     self.start_label: batch['start_id'],
                     self.end_label: batch['end_id'],
                     self.dropout_keep_prob: 1.0}
        # 使用模型预测,输出为起始索引、终止概率分布,损失
        start_probs, end_probs, loss = self.sess.run([self.start_probs,
                                                      self.end_probs, self.loss], feed_dict)

        total_loss += loss * len(batch['raw_data'])
        total_num += len(batch['raw_data'])

        padded_p_len = len(batch['passage_token_ids'][0])
        # 遍历样本
        for sample, start_prob, end_prob in zip(batch['raw_data'], start_probs, end_probs):
            #根据样本、概率分布、文档长度计算最佳答案
            best_answer = self.find_best_answer(sample, start_prob, end_prob, padded_p_len)
            # 是否保存全部信息
            if save_full_info:
                sample['pred_answers'] = [best_answer]
                pred_answers.append(sample)
            else:
                pred_answers.append({'question_id': sample['question_id'],
                                     'question_type': sample['question_type'],
                                     'answers': [best_answer],
                                     'entity_answers': [[]],
                                     'yesno_answers': []})
            if 'answers' in sample:
                ref_answers.append({'question_id': sample['question_id'],
                                     'question_type': sample['question_type'],
                                     'answers': sample['answers'],
                                     'entity_answers': [[]],
                                     'yesno_answers': []})
    # 保存结果到文件
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')

        self.logger.info('Saving {} results to {}'.format(result_prefix, result_file))

    # 这里的平均损失对于测试集无效,因为没有真实的起始索引和终止索引
    ave_loss = 1.0 * total_loss / total_num
    # 如果有参考答案计算结果的bleu和rouge分数
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    return ave_loss, bleu_rouge

由代码可见,rc_model.evaluate函数调用模型生成答案起始、终止索引概率分布start_probsend_probs后调用了rc_model.find_best_answer函数将概率分布转化为真实答案,并保存了结果。

注意:run.py中的预测、校验函数predict(run.py)evaluate(run.py)也是调用了rc_model.evaluate函数完成了预测及校验。

rc_model.find_best_answer

rc_model.find_best_answer用于将模型输出的起始索引、终止索引的概率分布,转化为对应文本,其具体代码如下:

def find_best_answer(self, sample, start_prob, end_prob, padded_p_len):
    """
    给定样本的每个位置是起始索引概率和终止索引概率的分布,找到样本的最佳答案。
    由于一个样本中有多个段落,其将调用find_best_answer_for_passage,为每个段落找出最佳答案。
    参数:
        sample:样本
        start_prob:答案起始索引概率分布
        end_prob:答案终止索引概率分布
        padded_p_len:填充后段落长度
    """
    best_p_idx, best_span, best_score = None, None, 0
    # 遍历段落
    for p_idx, passage in enumerate(sample['passages']):
        if p_idx >= self.max_p_num:
            continue
        passage_len = min(self.max_p_len, len(passage['passage_tokens']))
        # 输入每个段落对应的起始、终止索引概率分布,为每个段落寻找最佳答案,同时输出答案分数
        answer_span, score = self.find_best_answer_for_passage(
            start_prob[p_idx * padded_p_len: (p_idx + 1) * padded_p_len],
            end_prob[p_idx * padded_p_len: (p_idx + 1) * padded_p_len],
            passage_len)
        # 取分数最大的答案作为样本的最佳答案
        if score > best_score:
            best_score = score
            best_p_idx = p_idx
            best_span = answer_span
    if best_p_idx is None or best_span is None:
        best_answer = ''
    else:
        best_answer = ''.join(
            sample['passages'][best_p_idx]['passage_tokens'][best_span[0]: best_span[1] + 1])
    return best_answer

rc_model.find_best_answer函数通过调用rc_model.find_best_answer_for_passage为每个段落求解了最佳答案并得到了最佳答案的分数。其取其中分数最高的答案作为问题的最终答案。

rc_model.find_best_answer_for_passage

rc_model.find_best_answer_for_passage可以根据输入的起始索引、终止索引的概率分布为样本中的每个段落找到最佳的答案及答案分数,返回给rc_model.find_best_answer函数,用来判断哪个答案是整个样本的最佳答案,其代码如下:

def find_best_answer_for_passage(self, start_probs, end_probs, passage_len=None):
    """
    输入单个段落对应的起始、终止索引概率分布,从段落中找到具有最大的起始索引概率*终止索引概率的最佳答案
    """
    if passage_len is None:
        passage_len = len(start_probs)
    else:
        passage_len = min(len(start_probs), passage_len)
    best_start, best_end, max_prob = -1, -1, 0
    # 遍历文档作为起始索引
    for start_idx in range(passage_len):
        # 遍历答案长度,小于最大答案长度
        for ans_len in range(self.max_a_len):
            end_idx = start_idx + ans_len
            if end_idx >= passage_len:
                continue
            # 保留start_probs*end_probs最大的范围为最佳答案,其乘积为答案分数
            prob = start_probs[start_idx] * end_probs[end_idx]
            if prob > max_prob:
                best_start = start_idx
                best_end = end_idx
                max_prob = prob
    return (best_start, best_end), max_prob

rc_model.find_best_answer_for_passage遍历整个段落的答案索引概率,从中找到长度小于最大答案长度的start_probs*end_probs最大的片段作为本段落的最佳答案,start_probs*end_probs为该答案的评分。

运行

import

run.py文件中导入训练所需的函数。

import sys
import pickle
from run import *
WARNING:tensorflow:
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

args

创建参数

sys.argv = []
args = parse_args()
print(args)
Namespace(algo='BIDAF', batch_size=32, brc_dir='../data/baidu', dev_files=['../data/demo/devset/search.dev.json'], dropout_keep_prob=1, embed_size=300, epochs=10, evaluate=False, gpu='0', hidden_size=150, learning_rate=0.001, log_path=None, max_a_len=200, max_p_len=500, max_p_num=5, max_q_len=60, model_dir='../data/models/', optim='adam', predict=False, prepare=False, result_dir='../data/results/', summary_dir='../data/summary/', test_files=['../data/demo/testset/search.test.json'], train=False, train_files=['../data/demo/trainset/search.train.json'], vocab_dir='../data/vocab/', weight_decay=0)
# 设置模型训练的环境变量,使用GPU进行训练。
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

train

import logging
logger = logging.getLogger("brc")
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
if args.log_path:
    file_handler = logging.FileHandler(args.log_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
else:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

logger.info('Running with args : {}'.format(args))

打印参数:

2020-03-27 22:05:44,445 - brc - INFO - Running with args : Namespace(algo='BIDAF', batch_size=32, brc_dir='../data/baidu', dev_files=['../data/demo/devset/search.dev.json'], dropout_keep_prob=1, embed_size=300, epochs=10, evaluate=False, gpu='0', hidden_size=150, learning_rate=0.001, log_path=None, max_a_len=200, max_p_len=500, max_p_num=5, max_q_len=60, model_dir='../data/models/', optim='adam', predict=False, prepare=False, result_dir='../data/results/', summary_dir='../data/summary/', test_files=['../data/demo/testset/search.test.json'], train=False, train_files=['../data/demo/trainset/search.train.json'], vocab_dir='../data/vocab/', weight_decay=0)
train(args)

输出:

2020-03-27 22:05:45,096 - brc - INFO - Load data_set and vocab...
2020-03-27 22:05:45,178 - brc - INFO - Train set size: 95 questions.
2020-03-27 22:05:45,299 - brc - INFO - Dev set size: 100 questions.
2020-03-27 22:05:45,300 - brc - INFO - Converting text into ids...
2020-03-27 22:05:45,343 - brc - INFO - Initialize the model...
2020-03-27 22:05:48,926 - brc - INFO - Time to build graph: 3.3801302909851074 s
2020-03-27 22:05:57,149 - brc - INFO - There are 4995603 parameters in the model

2020-03-27 22:05:58,249 - brc - INFO - Training the model...
2020-03-27 22:05:58,250 - brc - INFO - Training the model for epoch 1
2020-03-27 22:06:06,592 - brc - INFO - Average train loss for epoch 1 is 14.488333722164757
2020-03-27 22:06:06,593 - brc - INFO - Evaluating the model after epoch 1
2020-03-27 22:06:11,522 - brc - INFO - Dev eval loss 14.235187187194825
2020-03-27 22:06:11,523 - brc - INFO - Dev eval result: {'Bleu-1': 2.8994602314302886e-24, 'Bleu-2': 2.094927251478452e-24, 'Bleu-3': 1.666268827061724e-24, 'Bleu-4': 1.460637092678813e-24, 'Rouge-L': 0.041639028266235965}
2020-03-27 22:06:12,234 - brc - INFO - Model saved in ../data/models/, with prefix BIDAF.

2020-03-27 22:06:12,236 - brc - INFO - Training the model for epoch 2
2020-03-27 22:06:17,464 - brc - INFO - Average train loss for epoch 2 is 12.861935404727333
2020-03-27 22:06:17,465 - brc - INFO - Evaluating the model after epoch 2
2020-03-27 22:06:21,883 - brc - INFO - Dev eval loss 12.90282627105713
2020-03-27 22:06:21,884 - brc - INFO - Dev eval result: {'Bleu-1': 1.3940577651068293e-21, 'Bleu-2': 1.0127594865890665e-21, 'Bleu-3': 7.2261144877553375e-22, 'Bleu-4': 6.311056727542682e-22, 'Rouge-L': 0.03253047012256154}
2020-03-27 22:06:22,390 - brc - INFO - Model saved in ../data/models/, with prefix BIDAF.

2020-03-27 22:06:22,391 - brc - INFO - Training the model for epoch 3
2020-03-27 22:06:27,590 - brc - INFO - Average train loss for epoch 3 is 11.300762738679584
2020-03-27 22:06:27,591 - brc - INFO - Evaluating the model after epoch 3
2020-03-27 22:06:32,021 - brc - INFO - Dev eval loss 13.704137229919434
2020-03-27 22:06:32,023 - brc - INFO - Dev eval result: {'Bleu-1': 2.338105788459333e-11, 'Bleu-2': 1.8041982946045808e-11, 'Bleu-3': 1.5392164591567576e-11, 'Bleu-4': 1.387615276811719e-11, 'Rouge-L': 0.0429180965585989}
2020-03-27 22:06:32,524 - brc - INFO - Model saved in ../data/models/, with prefix BIDAF.

2020-03-27 22:06:32,525 - brc - INFO - Training the model for epoch 4
2020-03-27 22:06:37,790 - brc - INFO - Average train loss for epoch 4 is 10.787741811651932
2020-03-27 22:06:37,792 - brc - INFO - Evaluating the model after epoch 4
2020-03-27 22:06:42,463 - brc - INFO - Dev eval loss 13.255349464416504
2020-03-27 22:06:42,464 - brc - INFO - Dev eval result: {'Bleu-1': 0.000134217134765812, 'Bleu-2': 0.00010227243606088728, 'Bleu-3': 8.221094680997049e-05, 'Bleu-4': 6.89352421323188e-05, 'Rouge-L': 0.09043521845567885}
2020-03-27 22:06:42,944 - brc - INFO - Model saved in ../data/models/, with prefix BIDAF.

2020-03-27 22:06:42,945 - brc - INFO - Training the model for epoch 5
2020-03-27 22:06:48,167 - brc - INFO - Average train loss for epoch 5 is 10.117159582439221
2020-03-27 22:06:48,168 - brc - INFO - Evaluating the model after epoch 5
2020-03-27 22:06:52,840 - brc - INFO - Dev eval loss 13.418038635253906
2020-03-27 22:06:52,841 - brc - INFO - Dev eval result: {'Bleu-1': 0.03649500805085331, 'Bleu-2': 0.026348497873034213, 'Bleu-3': 0.020451816428277064, 'Bleu-4': 0.016763779286368262, 'Rouge-L': 0.12265401451615979}
2020-03-27 22:06:53,376 - brc - INFO - Model saved in ../data/models/, with prefix BIDAF.

2020-03-27 22:06:53,378 - brc - INFO - Training the model for epoch 6
2020-03-27 22:06:58,576 - brc - INFO - Average train loss for epoch 6 is 9.720805951168662
2020-03-27 22:06:58,578 - brc - INFO - Evaluating the model after epoch 6
2020-03-27 22:07:03,585 - brc - INFO - Dev eval loss 13.420298271179199
2020-03-27 22:07:03,586 - brc - INFO - Dev eval result: {'Bleu-1': 0.17114988944235895, 'Bleu-2': 0.12398018542779535, 'Bleu-3': 0.09726068255081648, 'Bleu-4': 0.08122328728599948, 'Rouge-L': 0.17611418531084605}
2020-03-27 22:07:04,096 - brc - INFO - Model saved in ../data/models/, with prefix BIDAF.

2020-03-27 22:07:04,097 - brc - INFO - Training the model for epoch 7
2020-03-27 22:07:09,266 - brc - INFO - Average train loss for epoch 7 is 9.362699759633918
2020-03-27 22:07:09,267 - brc - INFO - Evaluating the model after epoch 7
2020-03-27 22:07:14,209 - brc - INFO - Dev eval loss 13.428105583190918
2020-03-27 22:07:14,210 - brc - INFO - Dev eval result: {'Bleu-1': 0.20241258517854588, 'Bleu-2': 0.14444878311065115, 'Bleu-3': 0.1121842519617245, 'Bleu-4': 0.0929297803274465, 'Rouge-L': 0.19206336673714766}
2020-03-27 22:07:14,716 - brc - INFO - Model saved in ../data/models/, with prefix BIDAF.

2020-03-27 22:07:14,717 - brc - INFO - Training the model for epoch 8
2020-03-27 22:07:19,957 - brc - INFO - Average train loss for epoch 8 is 9.032782383968955
2020-03-27 22:07:19,959 - brc - INFO - Evaluating the model after epoch 8
2020-03-27 22:07:24,929 - brc - INFO - Dev eval loss 13.415868530273437
2020-03-27 22:07:24,931 - brc - INFO - Dev eval result: {'Bleu-1': 0.17880382697329436, 'Bleu-2': 0.12925351813739897, 'Bleu-3': 0.10108439400963827, 'Bleu-4': 0.08416185599631172, 'Rouge-L': 0.18904312656365377}

2020-03-27 22:07:24,931 - brc - INFO - Training the model for epoch 9
2020-03-27 22:07:30,186 - brc - INFO - Average train loss for epoch 9 is 8.598431963669626
2020-03-27 22:07:30,187 - brc - INFO - Evaluating the model after epoch 9
2020-03-27 22:07:35,089 - brc - INFO - Dev eval loss 13.766514587402344
2020-03-27 22:07:35,090 - brc - INFO - Dev eval result: {'Bleu-1': 0.19132287017014682, 'Bleu-2': 0.14104624846358704, 'Bleu-3': 0.11170956630975126, 'Bleu-4': 0.09381119777647412, 'Rouge-L': 0.19897168820075256}
2020-03-27 22:07:35,627 - brc - INFO - Model saved in ../data/models/, with prefix BIDAF.

2020-03-27 22:07:35,628 - brc - INFO - Training the model for epoch 10
2020-03-27 22:07:40,864 - brc - INFO - Average train loss for epoch 10 is 8.178678597901996
2020-03-27 22:07:40,866 - brc - INFO - Evaluating the model after epoch 10
2020-03-27 22:07:45,761 - brc - INFO - Dev eval loss 14.380211639404298
2020-03-27 22:07:45,762 - brc - INFO - Dev eval result: {'Bleu-1': 0.16786192874911482, 'Bleu-2': 0.1256542560373232, 'Bleu-3': 0.10173408122289541, 'Bleu-4': 0.08729074121381147, 'Rouge-L': 0.19595328034391185}
2020-03-27 22:07:45,763 - brc - INFO - Done with model training!

以上是使用Demo数据进行训练过程中产生的输出(处理了一下,有些警告之类的删了),数据量较小时训练速度还是比较快的,从输出可以看到训练进行了10个epoch,训练损失在逐步变小,而校验损失变化不明显,其他校验参数Bleu-nRouge-L的变化趋势也不是太明显,这是由于示例数据集数据量太小,不能训练出有效的模型参数的原因,如果使用全数据集进行训练会看到比较明显的趋势变化。

参考文献:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值