到目前为止,问答模型的代码已经全部写完并可以跑通啦,所以此篇以及后续的n篇里,内容会集中在复盘+细节修正上。
服务器选择
paddlepaddle和paddlenlp已经在自己的电脑上配置完成,但由于个人PC算力有限,在AI Studio上使用模拟我的8G CPU无GPU的环境进行模型训练时,无论如何调节batchsize,内存始终溢出。为了模型训练的效果,最终还是选择了在AI Studio上对模型进行训练。
同样,后续使用模型进行预测时,在数据预处理过程以及后续预测过程都出现了内存溢出的情况,因此拟用AI Studio作为服务器进行后续的预测。
数据集加载
PaddleNLP已经内置SQuAD,CMRC等中英文阅读理解数据集,使用paddlenlp.datasets.load_dataset()API即可一键加载。参考实例加载的是DuReaderRobust中文阅读理解数据集。由于DuReaderRobust数据集采用SQuAD数据格式,InputFeature使用滑动窗口的方法生成,即一个example可能对应多个InputFeature。
答案抽取任务即根据输入的问题和文章,预测答案在文章中的起始位置和结束位置。
由于文章加问题的文本长度可能大于max_seq_length,答案出现的位置有可能出现在文章最后,所以不能简单的对文章进行截断。
那么对于过长的文章,则采用滑动窗口将文章分成多段,分别与问题组合。再用对应的tokenizer转化为模型可接受的feature。doc_stride参数就是每次滑动的距离。滑动窗口生成InputFeature的过程如下图:
对代码进行修改以适配我们的任务
参考示例中通过在线下载DuReaderRobust中文阅读理解数据集对模型进行训练。但该数据获取方法无法适配我们的任务。
我们的问答模型需要更灵活的对query和context进行读取,因此首先要将数据集读取的方式由在线下载改为本地
数据处理
原始数据格式:
{
'documents': [{
'is_selected': True or Flase,
'title': ' String',
'paragraphs':[ ' String', ' String', ' String'
]
},
{
'is_selected': True or Flase,
'title': ' String',
'paragraphs':[ ' String', ' String', ' String'
]
}
],
'answers': [' String', ' String', ' String'
],
'question’: ‘String',
'question_type': DESCRIPTION or ENTITY or YESNO,
'fact_or_opinion': FACT or OPINION,
'question_id': 191572
}
在这一步时出现了很多bug,绕了很多弯路
因为既想采用dureader_robust提供的特征提取方法,又需要使用本地的数据进行训练和预测,因此要在第一步时就将数据处理成与下载的数据相同的格式。
从json文件读取数据
def read(data_path):
"""This function returns the examples in the raw (text) form."""
key = 0
with open(data_path, encoding="utf-8") as f:
durobust = json.load(f)
for article in durobust["data"]:
title = article.get("title", "")
for paragraph in article["paragraphs"]:
context = paragraph[
"context"] # do not strip leading blank spaces GH-2585
for qa in paragraph["qas"]:
answer_starts = [
answer["answer_start"]
for answer in qa.get("answers", '')
]
answers = [
answer["text"] for answer in qa.get("answers", '')
]
# Features currently used are "context", "question", and "answers".
# Others are extracted here for the ease of future expansions.
yield key, {
"id": qa["id"],
"title": title,
"context": context,
"question": qa["question"],
"answers": answers,
"answer_starts": answer_starts,
}
key += 1