解决BERT官方模型script脚本下载GLUE data失败的问题

解决BERT官方模型script脚本下载GLUE data失败的问题

1. 运行脚本下载GLUE数据集(建议科学上网)

BERT的提出对于NLP领域是一项里程碑。再一次证明了利用语言模型的预训练然后再针对具体下游任务微调是NLP研究近些年来的优秀做法。
在github上下载BERT官方模型需要下载的BLUE data数据集时,可能失败,因为可能由于模型开发时python版本较低,导致导包出现问题,下面给出python3以上版本下载GLUE data数据集的代码👇

import io
import os
import sys
import shutil
import argparse
import tempfile
import urllib
import urllib.request
import zipfile

TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "QNLI", "RTE", "WNLI", "diagnostic", "MRPC"]
TASK2PATH = {"CoLA": 'https://dl.fbaipublicfiles.com/glue/data/CoLA.zip',
             "SST": 'https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
             "QQP": 'https://dl.fbaipublicfiles.com/glue/data/STS-B.zip',
             "STS": 'https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip',
             "MNLI": 'https://dl.fbaipublicfiles.com/glue/data/MNLI.zip',
             "QNLI": 'https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip',
             "RTE": 'https://dl.fbaipublicfiles.com/glue/data/RTE.zip',
             "WNLI": 'https://dl.fbaipublicfiles.com/glue/data/WNLI.zip',
             "diagnostic": 'https://dl.fbaipublicfiles.com/glue/data/AX.tsv',
             "MRPC": "https://raw.githubusercontent.com/MegEngine/Models/master/official/nlp/bert/glue_data/MRPC/dev_ids.tsv"}

MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'

def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    if task == "MNLI":
        print(
            "\tNote (12/10/20): This script no longer downloads SNLI. You will need to manually download and format the data to use SNLI.")
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")

def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        try:
            mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
            mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
            urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
            urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
        except urllib.error.HTTPError:
            print("Error downloading MRPC")
            return
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file

    with io.open(mrpc_test_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf-8') as test_fh:
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))

    try:
        urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
    except KeyError or urllib.error.HTTPError:
        print("\tError downloading standard development IDs for MRPC. You will need to manually split your data.")
        return

    dev_ids = []
    with io.open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding='utf-8') as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    with io.open(mrpc_train_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf-8') as train_fh, \
            io.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf-8') as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    print("\tCompleted!")

def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
    print("\tCompleted!")
    return

def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks

def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc',
                        help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
        os.mkdir(args.data_dir)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)

if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))

2. 查看数据集

下载好后检查当前路径下,发现多了一个GLUE文件夹,点进去后,数据集内容如下图所示:
在这里插入图片描述
一般来说,现在本地都有GPU,如果没有GPU的话,在自行训练网络时,建议使用MRPC数据集,因为这个数据集比较小,训练起来相对容易,所需时间更短,使用CPU即可完成。
MRPC打开后格式如下:👇
在这里插入图片描述
该任务是判断两个字符串的语义是否相等。前面的1或者0表示相等与否,1表示相等,0表示不相等。随后给出的两个数字分别表示第一个句子的ID号,和第二个句子的ID号。随后的两个字符串分别是句子1和句子2。

附:

GLUE由纽约大学, 华盛顿大学, Google联合推出, 涵盖不同NLP任务类型, 截止至2020年1月其中包括11个子任务数据集, 成为衡量NLP研究发展的重要标准。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天使Di María

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值