Huggingface入门篇 II (QA)

1 任务介绍和前期准备

任务的背景如下

1.1 下载第三方库

安装Transformer和Huggingface

!pip install transformers
!pip install datasets
!pip install huggingface_hub

所使用的第三方类

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AdamW, get_scheduler
from datasets import load_dataset, Dataset, DatasetDict, load_metric
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from matplotlib import pyplot as plt
import pandas as pd
import gzip
import json
import numpy as np
import os 

加载与预训练的模型和tokenizer,此处使用的args是一个包含训练参数的字典,这里的配置也是得到本次任务最佳模型的训练参数:

args={
   
    "DATASET_PATHS":[{
   
      "TRAIN":"datasets/train/HotpotQA.jsonl.gz",
      "IN_DOMAIN_DEV":"datasets/in_domain_dev/HotpotQA.jsonl.gz",
      "OUT_DOMAIN_DEV":"datasets/out_domain_dev/HotpotQA.jsonl.gz",
    }],
    'MODEL':'huawei-noah/TinyBERT_General_6L_768D',
    'EPOCHS': 5,
    'VAL_BATCH_SIZE':16,
    'TRAINING_BATCH_SIZE':16,
    'LEARNING_RATE':2e-5,
    'MAX_SIZE':256,
}
args['DEVICE'] = torch.device('cuda')

model = AutoModelForQuestionAnswering.from_pretrained(args.get('MODEL')).to(args.get('device'))
tokenizer=AutoTokenizer.from_pretrained(args.get('MODEL'))

1.2 从sharetask中下载数据

在这次的sharetask1中,作者准备好了自动下载所有训练数据据的脚本, 我们只需要将该仓库克隆到Colab中,然后再 运行该脚本 即可:

!git clone https://github.com/mrqa/MRQA-Shared-Task-2019.git
!bash MRQA-Shared-Task-2019/download_train.sh 'datasets/train'
!bash MRQA-Shared-Task-2019/download_in_domain_dev.sh 'datasets/in_domain_dev'
!bash MRQA-Shared-Task-2019/download_in_domain_dev.sh 'datasets/out_domain_dev'

1.3 加载原始数据

下载完成后,文件树如下,可以观察到各个数据集是jsonl格式的文档的gz格式压缩。
在这里插入图片描述
先将gz文件用gzip打开,然后用json.load读取每个文件:

  def read(self,file_path):
    rawdata = []
    with gzip.open(file_path, 'rb') as myzip:
        for example in myzip:
            context = json.loads(example)
            if 'header' in context:
                continue
            rawdata.append(context)
    return rawdata

读取后的结果是一个字典的list,每个字典的结构包括 dict_keys(['id', 'context', 'qas', 'context_tokens'])。 读取完原始数据之后,由于本次是QA任务,所以只需要以下三个key的内容2

  • context 相关文本
  • question 根据context提出的的问题,其属于qas的子结构
  • answers 其属于qas的子结构。包括 text:答案的文本, answer_start: 答案在context中的位置

我写了一个Reader类,其功能包含了以上描述的读取原始数据,提取所需的key,以及将其封装成datasets类,最后dataset的数据格式如下:
在这里插入图片描述
Reader以及本文完整的代码我会放入Notebook中,上传到Github。

小规模训练中,我从trainset中随机选择5000条数据作为训练数据集,从devset中随机500条作为validation

SEED=123
train_text_dataset<
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值