菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(七)—— 模型训练-数据准备

系列目录:

  1. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(一)——数据
  2. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(二)——
    介绍及分词
  3. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(三)—— 预处理
  4. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(四)—— 段落抽取
  5. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(五)—— 准备数据
  6. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(六)—— 模型构建

未完待续 … …

上一篇文章对模型的结构进行了介绍,本文开始介绍训练中的数据准备,数据经过预处理后,到真正输入模型进行训练还需要进一步的处理。

训练主函数

首先来看一下训练的主函数,主函数train如下:

def train(args):
    """
    训练阅读理解模型
    """
    logger = logging.getLogger("brc")
    logger.info('Load data_set and vocab...')
    # 加载字典
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    # 加载数据
    brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                          args.train_files, args.dev_files)
    logger.info('Converting text into ids...')
    # 将数据转换为数字索引ids
    brc_data.convert_to_ids(vocab)
    logger.info('Initialize the model...')
    # 初始化模型
    rc_model = RCModel(vocab, args)
    logger.info('Training the model...')
    # 训练模型
    rc_model.train(brc_data, args.epochs, args.batch_size, save_dir=args.model_dir,
                   save_prefix=args.algo,
                   dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')

有代码可以看到,训练主函数包含了加载词典、加载数据、将数据转换为索引、构建模型、训练模型几部分,本文重点介绍下其中加载数据部分。

BRCDataset

BRCDataset函数在准备数据部分简单介绍过,回顾一下:

类名 BRCDataset
功能:实现加载使用百度阅读理解数据集的APIs
类属性:
    self.max_p_num:最大段落数量
    self.max_p_len:最大段落长度
    self.max_q_len:最大问题长度
    self.train_set, self.dev_set, self.test_set:训练、验证、测试数据集
类主要方法:
	_load_dataset():加载数据,数据集初始化时会自动调用这个函数加载数据
	_one_mini_batch:生成一个batch的数据
	_dynamic_padding:动态填充
	word_iter:遍历数据集中所有单词
	convert_to_ids:将数据集中的文本(问题、文档)转化为ids
	gen_mini_batches:为特定数据集生成batch数据

下面简单介绍其中数据处理的关键函数,其余的大家可以自行阅读源代码。

_load_dataset

_load_dataset函数是在BRCDataset类初始化时自动运行,加载训练、验证、测试数据集数据,其代码如下:

def _load_dataset(self, data_path, train=False):
    """
    加载数据集
    Args:
        data_path: 需要加载的数据集的路径
    """
    with open(data_path) as fin:
        data_set = []
        for lidx, line in enumerate(fin):
            # 开始处理单个样本
            sample = json.loads(line.strip())

            if train:
                if len(sample['answer_spans']) == 0:
                    continue
                if sample['answer_spans'][0][1] >= self.max_p_len:
                    continue
            # 答案所在的文档,后面在_one_mini_batch函数中用于计算答案范围的偏置
            if 'answer_docs' in sample:
                sample['answer_passages'] = sample['answer_docs']
            # 问题
            sample['question_tokens'] = sample['segmented_question']
            # 文档
            sample['passages'] = []
            # 遍历每个样本中的文档
            for d_idx, doc in enumerate(sample['documents']):
                if train:
                    # 如果是训练集,处理相对简单,只取预处理中计算的每个文档的最相关段落将其作为 
                    #`passage_tokens`与`is_selected`组成的字典插入`passages`
                    most_related_para = doc['most_related_para']
                    sample['passages'].append(
                        {'passage_tokens': doc['segmented_paragraphs'][most_related_para],
                         'is_selected': doc['is_selected']}
                    )
                else:
                    # 如果不是训练集,则遍历每个段落,计算段落与问题的recall值,
                    #并按照recall和段落长度排序(短的在前),取前几个段落作为passage_tokens
                    para_infos = []
                    for para_tokens in doc['segmented_paragraphs']:
                        question_tokens = sample['segmented_question']
                        # 计算段落与问题的recall值
                        common_with_question = Counter(para_tokens) & Counter(question_tokens)
                        correct_preds = sum(common_with_question.values())
                        if correct_preds == 0:
                            recall_wrt_question = 0
                        else:
                            recall_wrt_question = float(correct_preds) / len(question_tokens)
                        para_infos.append((para_tokens, recall_wrt_question, len(para_tokens)))
                    para_infos.sort(key=lambda x: (-x[1], x[2]))
                    fake_passage_tokens = []
                    # 取第一个段落作为passage_tokens
                    for para_info in para_infos[:1]:
                        fake_passage_tokens += para_info[0]
                    sample['passages'].append({'passage_tokens': fake_passage_tokens})
            data_set.append(sample)
    return data_set

由代码可见,_load_dataset函数在加载数据的同时对数据集(尤其是校验集和测试集)进行了进一步处理,为样本添加了answer_passagesquestion_tokenspassages字段,其中passages对于训练集是每个文档中与答案最相关段落的列表,对其他数据集是与问题最相关段落的列表。

gen_mini_batches

gen_mini_batches可以为设定的数据集(train/dev/test)生成数据批次,训练中训练代码会调用这个函数来生成训练数据。

def gen_mini_batches(self, set_name, batch_size, pad_id, shuffle=True):
    """
    为设定的数据集(train/dev/test)生成数据批次
    参数:
        set_name: 数据集名称,使用train/dev/test 指明数据集
        batch_size: 每个批次样本的数量
        pad_id: 填充字符索引
        shuffle: 如果值为真,将数据打乱.
    返回值:
        所有批次的生成器
    """
    if set_name == 'train':
        data = self.train_set
    elif set_name == 'dev':
        data = self.dev_set
    elif set_name == 'test':
        data = self.test_set
    else:
        raise NotImplementedError('No data set named as {}'.format(set_name))
    data_size = len(data)
    indices = np.arange(data_size)
    if shuffle:
        np.random.shuffle(indices)
    for batch_start in np.arange(0, data_size, batch_size):
        batch_indices = indices[batch_start: batch_start + batch_size]
        # 根据索引生成一个样本批次
        yield self._one_mini_batch(data, batch_indices, pad_id)

由代码可见,这个函数主要的功能是选择数据集、打乱数据、确定每个批次样本索引,最终每一个批次数据的生成是调用了_one_mini_batch函数。

_one_mini_batch

_one_mini_batch根据输入的数据和所选索引生成一个数据批次,生成时还根据本批次的最长样本和设置的最大长度对这个批次的样本进行填充。

def _one_mini_batch(self, data, indices, pad_id):
    """
    生成一个批次
    参数:
        data: 所有数据
        indices: 所选样本的索引the indices of the samples to be selected
        pad_id:填充字符索引
    返回值:
        一个数据批次
    """
    batch_data = {'raw_data': [data[i] for i in indices],
                  'question_token_ids': [],
                  'question_length': [],
                  'passage_token_ids': [],
                  'passage_length': [],
                  'start_id': [],
                  'end_id': []}
    # 最大段落数量
    max_passage_num = max([len(sample['passages']) for sample in batch_data['raw_data']])
    max_passage_num = min(self.max_p_num, max_passage_num)
    for sidx, sample in enumerate(batch_data['raw_data']):
        # 遍历1到`max_passage_num`
        for pidx in range(max_passage_num):
            # 如果pidx小于段落数量,即有样本,将样本值赋给batch_data的对应字段
            if pidx < len(sample['passages']):
                batch_data['question_token_ids'].append(sample['question_token_ids'])
                batch_data['question_length'].append(len(sample['question_token_ids']))
                passage_token_ids = sample['passages'][pidx]['passage_token_ids']
                batch_data['passage_token_ids'].append(passage_token_ids)
                batch_data['passage_length'].append(min(len(passage_token_ids), self.max_p_len))
            # 如果没有样本,插入空样本
            else:
                batch_data['question_token_ids'].append([])
                batch_data['question_length'].append(0)
                batch_data['passage_token_ids'].append([])
                batch_data['passage_length'].append(0)
    # 动态填充批次数据,返回样本长度对齐的批次,及填充后的段落、问题长度
    batch_data, padded_p_len, padded_q_len = self._dynamic_padding(batch_data, pad_id)
    for sample in batch_data['raw_data']:
        if 'answer_passages' in sample and len(sample['answer_passages']):
            # 计算答案所在段落偏移,sample['answer_passages'][0]在_load_dataset中创建,是答案所在文档的索引
            gold_passage_offset = padded_p_len * sample['answer_passages'][0]
            # 根据偏移计算答案的起始索引和终止索引
            batch_data['start_id'].append(gold_passage_offset + sample['answer_spans'][0][0])
            batch_data['end_id'].append(gold_passage_offset + sample['answer_spans'][0][1])
        else:
            # 如果没有答案插入0
            batch_data['start_id'].append(0)
            batch_data['end_id'].append(0)
    return batch_data

由代码可以看到,这个函数功能如下:

  1. 根据批次索引列表读取该批次的数据并存入raw_data字段。
  2. 统计文档最大段落数目,如果超过设定值,取设定值。
  3. 将所有文档段落按照最大段落长度进行统一,多的删除,少的补空文档。
  4. 调用_dynamic_padding函数对每一个段落进行填充操作,根据最大段落长度,截取或填充。
  5. 根据答案所在的段落索引及段落长度计算答案索引偏移量,并计算新的答案索引。
  6. 返回batch_data
    所以最终返回的数据是文档中段落数目一致(不超过预设最大段落数目的统一值),段落长度一致(不超过预设最大段落长度的统一值)。

简单调用

import

import sys
import pickle
from run import *
WARNING:tensorflow:
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

args

sys.argv = []
args = parse_args()
print(args)
Namespace(algo='BIDAF', batch_size=32, brc_dir='../data/baidu', dev_files=['../data/demo/devset/search.dev.json'], dropout_keep_prob=1, embed_size=300, epochs=10, evaluate=False, gpu='0', hidden_size=150, learning_rate=0.001, log_path=None, max_a_len=200, max_p_len=500, max_p_num=5, max_q_len=60, model_dir='../data/models/', optim='adam', predict=False, prepare=False, result_dir='../data/results/', summary_dir='../data/summary/', test_files=['../data/demo/testset/search.test.json'], train=False, train_files=['../data/demo/trainset/search.train.json'], vocab_dir='../data/vocab/', weight_decay=0)
# 创建数据集
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
                      args.train_files, args.dev_files)
# 打开词典
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
    vocab = pickle.load(fin)
    
# 将样本文本转化为索引ids,并添加到数据集
brc_data.convert_to_ids(vocab)

# 准备参数,生成一个大小为4的批次
import numpy as np
data = brc_data.train_set
data_size = len(data)
indices = np.arange(data_size)
pad_id = vocab.get_id(vocab.pad_token)
batch_start = 0
batch_size = 4
batch_indices = indices[batch_start: batch_start + batch_size]
batch = brc_data._one_mini_batch(data, batch_indices,pad_id)
batch.keys()
dict_keys(['raw_data', 'question_token_ids', 'question_length', 'passage_token_ids', 'passage_length', 'start_id', 'end_id'])

由输出可见batch包含了以下字段:

  • raw_data:原始数据
  • question_token_ids:问题符号索引
  • question_length:问题长度列表
  • passage_token_ids:文档符号索引
  • passage_length:文档长度列表
  • start_id:答案起始索引
  • end_id:答案终止索引

其具体值如下:

print(batch['question_token_ids'])
print(np.shape(batch['question_token_ids']))
print(batch['question_length'])
print(np.shape(batch['passage_token_ids']))
print(batch['passage_length'])
print(batch['start_id'])
print(batch['end_id'])
[[2, 3, 4, 5, 6], [2, 3, 4, 5, 6], [2, 3, 4, 5, 6], [2, 3, 4, 5, 6], [2, 3, 4, 5, 6], [158, 31, 159, 26, 160], [158, 31, 159, 26, 160], [158, 31, 159, 26, 160], [158, 31, 159, 26, 160], [158, 31, 159, 26, 160], [437, 26, 438, 439, 440], [437, 26, 438, 439, 440], [437, 26, 438, 439, 440], [437, 26, 438, 439, 440], [437, 26, 438, 439, 440], [619, 1, 0, 0, 0], [619, 1, 0, 0, 0], [619, 1, 0, 0, 0], [619, 1, 0, 0, 0], [619, 1, 0, 0, 0]]
(20, 5)
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2, 2, 2]
(20, 443)
[96, 147, 17, 51, 114, 31, 226, 12, 51, 443, 29, 204, 82, 279, 57, 404, 328, 118, 133, 27]
[1772, 4, 5, 445]
[1882, 11, 28, 601]

有输出可见,对于一个样本数量为4的批次,问题数据维度为(20, 5),文档数据维度为(20, 443),所以问题与文档的数量都是20,答案数量为4。这是因为加载数据时,代码根据预先设定的最大文档数量5,将每个样本的文档数量填充(空文档)为5个,同时将每个问题复制了5次,因此每个样本对应5个文档及问题。

另外可以看到一个批次中,所有问题与文档都被填充成相同的长度,长度大小取该批次所有文档(问题)的最大长度与预先设定的文档(文本)最大长度中较小的值。

参考文献:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值