解决BERT官方模型script脚本下载GLUE data失败的问题
1. 运行脚本下载GLUE数据集(建议科学上网)
BERT的提出对于NLP领域是一项里程碑。再一次证明了利用语言模型的预训练然后再针对具体下游任务微调是NLP研究近些年来的优秀做法。
使用GitHub上BERT官方仓库提供的脚本下载模型所需的GLUE data数据集时,可能会失败。原因可能是该脚本开发时使用的Python版本较低,在新版本下导包会出现问题。下面给出适用于Python 3及以上版本的GLUE data数据集下载代码👇
import io
import os
import sys
import shutil
import argparse
import tempfile
import urllib
import urllib.request
import zipfile
# Names of the GLUE tasks this script can fetch. Each task is listed exactly
# once so that "--tasks all" never processes the same task twice.
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "QNLI", "RTE", "WNLI", "diagnostic"]

# Download location for every task in TASKS.
# - Most entries are zip archives that are extracted into the data directory.
# - "diagnostic" is a single TSV file.
# - "MRPC" points at the list of development-set ID pairs that format_mrpc()
#   uses to split the MRPC training data into train/dev.
TASK2PATH = {
    "CoLA": 'https://dl.fbaipublicfiles.com/glue/data/CoLA.zip',
    "SST": 'https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
    # QQP and STS each map to their own archive (the two URLs used to be swapped).
    "QQP": 'https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip',
    "STS": 'https://dl.fbaipublicfiles.com/glue/data/STS-B.zip',
    "MNLI": 'https://dl.fbaipublicfiles.com/glue/data/MNLI.zip',
    "QNLI": 'https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip',
    "RTE": 'https://dl.fbaipublicfiles.com/glue/data/RTE.zip',
    "WNLI": 'https://dl.fbaipublicfiles.com/glue/data/WNLI.zip',
    "diagnostic": 'https://dl.fbaipublicfiles.com/glue/data/AX.tsv',
    "MRPC": "https://raw.githubusercontent.com/MegEngine/Models/master/official/nlp/bert/glue_data/MRPC/dev_ids.tsv",
}

# Raw MRPC train/test files, downloaded by format_mrpc() when no local copy
# is supplied.
MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
def download_and_extract(task, data_dir):
    """Download the zip archive for ``task`` and extract it into ``data_dir``.

    The archive is saved as ``<task>.zip`` in the current working directory
    and removed again once extraction finishes or fails.

    Args:
        task: Task name; must be a key of ``TASK2PATH``.
        data_dir: Directory the archive contents are extracted into.

    Raises:
        KeyError: If ``task`` has no entry in ``TASK2PATH``.
        urllib.error.URLError: If the download fails.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    print("Downloading and extracting %s..." % task)
    if task == "MNLI":
        print(
            "\tNote (12/10/20): This script no longer downloads SNLI. You will need to manually download and format the data to use SNLI.")
    # Resolve the URL before touching the filesystem so an unknown task fails
    # without side effects.
    url = TASK2PATH[task]
    data_file = "%s.zip" % task
    try:
        urllib.request.urlretrieve(url, data_file)
        with zipfile.ZipFile(data_file) as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        # Do not leave a partial or corrupt archive behind when the download
        # or the extraction raises.
        if os.path.exists(data_file):
            os.remove(data_file)
    print("\tCompleted!")
def format_mrpc(data_dir, path_to_data):
    """Create the MRPC ``train.tsv``, ``dev.tsv`` and ``test.tsv`` files.

    The files are written to ``<data_dir>/MRPC``. The raw
    ``msr_paraphrase_train.txt`` / ``msr_paraphrase_test.txt`` files are read
    from ``path_to_data`` when it is non-empty; otherwise they are downloaded
    from ``MRPC_TRAIN`` / ``MRPC_TEST`` into the MRPC directory.

    ``test.tsv`` contains the test rows without their first column, prefixed
    with a running index. The training file is split into ``train.tsv`` and
    ``dev.tsv`` using the ID pairs downloaded from ``TASK2PATH["MRPC"]``.

    Args:
        data_dir: Root directory under which the ``MRPC`` folder is created.
        path_to_data: Directory holding the raw MRPC text files, or an empty
            string to download them.

    Returns:
        None. On a download error a message is printed and the function
        returns early; ``test.tsv`` may already have been written by then.

    Raises:
        AssertionError: If the raw train or test file is missing.
    """
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    # makedirs also creates data_dir when this function is called directly.
    os.makedirs(mrpc_dir, exist_ok=True)

    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        try:
            urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
            urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
        except urllib.error.HTTPError:
            print("Error downloading MRPC")
            return
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file

    # Test split: drop the first column and add a running index instead.
    with io.open(mrpc_test_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf-8') as test_fh:
        data_fh.readline()  # skip the header of the raw file
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            _, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))

    dev_ids_file = os.path.join(mrpc_dir, "dev_ids.tsv")
    try:
        urllib.request.urlretrieve(TASK2PATH["MRPC"], dev_ids_file)
    # A tuple is required here: "except KeyError or HTTPError" evaluates the
    # "or" first and would only ever catch KeyError.
    except (KeyError, urllib.error.HTTPError):
        print("\tError downloading standard development IDs for MRPC. You will need to manually split your data.")
        return

    # Tuples in a set give constant-time membership tests for the split below.
    dev_ids = set()
    with io.open(dev_ids_file, encoding='utf-8') as ids_fh:
        for row in ids_fh:
            dev_ids.add(tuple(row.strip().split('\t')))

    # Train/dev split: rows whose ID pair is listed in dev_ids go to dev.tsv.
    with io.open(mrpc_train_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf-8') as train_fh, \
            io.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf-8') as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            out_fh = dev_fh if (id1, id2) in dev_ids else train_fh
            out_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
    print("\tCompleted!")
def download_diagnostic(data_dir):
    """Fetch the diagnostic TSV into ``<data_dir>/diagnostic/diagnostic.tsv``.

    The ``diagnostic`` sub-directory is created when it does not exist yet.

    Args:
        data_dir: Existing root directory for the downloaded data.
    """
    print("Downloading and extracting diagnostic...")
    diagnostic_dir = os.path.join(data_dir, "diagnostic")
    if not os.path.isdir(diagnostic_dir):
        os.mkdir(diagnostic_dir)
    target_path = os.path.join(diagnostic_dir, "diagnostic.tsv")
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], target_path)
    print("\tCompleted!")
def get_tasks(task_names):
    """Turn a comma-separated task string into a list of task names.

    Whitespace around each name is ignored and repeated names are collapsed
    (first occurrence wins), so no task is processed twice.

    Args:
        task_names: Comma-separated task names, e.g. ``"CoLA,SST"``. If any
            of the names is ``"all"``, every task in ``TASKS`` is selected.

    Returns:
        A new list of task names; modifying it does not affect ``TASKS``.

    Raises:
        AssertionError: If a name is not in ``TASKS`` (and "all" was not
            requested).
    """
    requested = [name.strip() for name in task_names.split(',')]
    if "all" in requested:
        # Return a de-duplicated copy rather than the module-level list.
        return list(dict.fromkeys(TASKS))
    for task_name in requested:
        if task_name not in TASKS:
            # Raised explicitly so validation still happens under "python -O";
            # AssertionError is kept for backward compatibility.
            raise AssertionError("Task %s not found!" % task_name)
    return list(dict.fromkeys(requested))
def main(arguments):
    """Parse the command line and download the requested GLUE tasks.

    Args:
        arguments: List of command-line argument strings, without the
            program name (e.g. ``sys.argv[1:]``).

    Returns:
        None.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc',
                        help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    # makedirs (instead of mkdir) so nested paths such as "a/b/glue" work too.
    os.makedirs(args.data_dir, exist_ok=True)

    for task in get_tasks(args.tasks):
        if task == 'MRPC':
            # MRPC needs its own train/dev/test formatting step.
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            # The diagnostic set is a single TSV, not a zip archive.
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)
# Script entry point: run main() on the command-line arguments and use its
# return value as the process exit status.
if __name__ == '__main__':
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
2. 查看数据集
下载好后检查当前路径下,发现多了一个glue_data文件夹(即脚本中--data_dir参数的默认值),点进去后,数据集内容如下图所示:
一般来说,现在本地都有GPU,如果没有GPU的话,在自行训练网络时,建议使用MRPC数据集,因为这个数据集比较小,训练起来相对容易,所需时间更短,使用CPU即可完成。
MRPC打开后格式如下:👇
该任务是判断两个字符串的语义是否相等。前面的1或者0表示相等与否,1表示相等,0表示不相等。随后给出的两个数字分别表示第一个句子的ID号,和第二个句子的ID号。随后的两个字符串分别是句子1和句子2。
附:
GLUE由纽约大学、华盛顿大学、Google联合推出,涵盖不同的NLP任务类型,截至2020年1月,其中包括11个子任务数据集,已成为衡量NLP研究发展的重要标准。