1 任务介绍和前期准备
任务的背景如下
- 本次任务使用了MRQA-shared-task中的train和dev数据,其中包含了常见的QA数据库,例如SQuAD,NewsQA,SearchQA,HotpotQA等。
- 预训练模型是huawei-noah/TinyBERT_General_6L_768D
- 训练数据集是
HotpotQA
。 - 运行环境 Google Colab (Pro)详细性能配置可以见本文章
- Model的运行代码
1.1 下载第三方库
安装Transformer和Huggingface
!pip install transformers
!pip install datasets
!pip install huggingface_hub
所使用的第三方类
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AdamW, get_scheduler
from datasets import load_dataset, Dataset, DatasetDict, load_metric
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from matplotlib import pyplot as plt
import pandas as pd
import gzip
import json
import numpy as np
import os
加载与预训练的模型和tokenizer,此处使用的args
是一个包含训练参数的字典,这里的配置也是得到本次任务最佳模型的训练参数:
args={
"DATASET_PATHS":[{
"TRAIN":"datasets/train/HotpotQA.jsonl.gz",
"IN_DOMAIN_DEV":"datasets/in_domain_dev/HotpotQA.jsonl.gz",
"OUT_DOMAIN_DEV":"datasets/out_domain_dev/HotpotQA.jsonl.gz",
}],
'MODEL':'huawei-noah/TinyBERT_General_6L_768D',
'EPOCHS': 5,
'VAL_BATCH_SIZE':16,
'TRAINING_BATCH_SIZE':16,
'LEARNING_RATE':2e-5,
'MAX_SIZE':256,
}
args['DEVICE'] = torch.device('cuda')
model = AutoModelForQuestionAnswering.from_pretrained(args.get('MODEL')).to(args.get('device'))
tokenizer=AutoTokenizer.from_pretrained(args.get('MODEL'))
1.2 从sharetask中下载数据
在这次的sharetask1中,作者准备好了自动下载所有训练数据据的脚本, 我们只需要将该仓库克隆到Colab中,然后再 运行该脚本 即可:
!git clone https://github.com/mrqa/MRQA-Shared-Task-2019.git
!bash MRQA-Shared-Task-2019/download_train.sh 'datasets/train'
!bash MRQA-Shared-Task-2019/download_in_domain_dev.sh 'datasets/in_domain_dev'
!bash MRQA-Shared-Task-2019/download_in_domain_dev.sh 'datasets/out_domain_dev'
1.3 加载原始数据
下载完成后,文件树如下,可以观察到各个数据集是jsonl格式的文档的gz
格式压缩。
先将gz文件用gzip
打开,然后用json.load
读取每个文件:
def read(self,file_path):
rawdata = []
with gzip.open(file_path, 'rb') as myzip:
for example in myzip:
context = json.loads(example)
if 'header' in context:
continue
rawdata.append(context)
return rawdata
读取后的结果是一个字典的list,每个字典的结构包括 dict_keys(['id', 'context', 'qas', 'context_tokens'])
。 读取完原始数据之后,由于本次是QA任务,所以只需要以下三个key的内容2:
context
相关文本question
根据context提出的的问题,其属于qas
的子结构answers
其属于qas
的子结构。包括text
:答案的文本,answer_start
: 答案在context中的位置
我写了一个Reader
类,其功能包含了以上描述的读取原始数据,提取所需的key,以及将其封装成datasets
类,最后dataset的数据格式如下:
Reader
以及本文完整的代码我会放入Notebook中,上传到Github。
小规模训练中,我从trainset中随机选择5000条数据作为训练数据集,从devset中随机500条作为validation
SEED=123
train_text_dataset<