BERT微调数据集
自然语言推断任务:
主要研究 假设
(hypothesis)是否可以从前提
(premise)中推断出来, 其中两者都是文本序列。 换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:
- 蕴涵(entailment):假设可以从前提中推断出来。
- 矛盾(contradiction):假设的否定可以从前提中推断出来。
- 中性(neutral):所有其他情况。
斯坦福自然语言推断(SNLI)数据集
由500000多个带标签的英语句子对
组成的集合
import os
import re
import torch
from torch import nn
from d2l import torch as d2l
#@save
d2l.DATA_HUB['SNLI'] = (
'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
'9fcde07509c7e87ec61c640c1b2753d9041758e4')
data_dir = "D:\environment\data\data\snli_1.0"
读取数据集
#@save
def read_snli(data_dir, is_train):
"""将SNLI数据集解析为前提、假设和标签"""
def extract_text(s):
# 删除我们不会使用的信息
s = re.sub('\\(', '', s)
s = re.sub('\\)', '', s)
# 用一个空格替换两个或多个连续的空格
s = re.sub('\\s{2,}', ' ', s)
return s.strip()
label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
if is_train else 'snli_1.0_test.txt')
with open(file_name, encoding = 'utf-8') as f:
rows = [row.split('\t') for row in f.readlines()[1:]]
premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
hypotheses = [extract_text(row[2]) for row in rows if row[0] \
in label_set]
labels = [label_set[row[0]] for row in rows if row[0] in label_set]
return premises, hypotheses, labels
train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
print('前提:', x0)
print('假设:', x1)
print('标签:', y)
前提: A person on a horse jumps over a broken down airplane .
假设: A person is training his horse for a competition .
标签: 2
前提: A person on a horse jumps over a broken down airplane .
假设: A person is at a diner , ordering an omelette .
标签: 1
前提: A person on a horse jumps over a broken down airplane .
假设: A person is outdoors , on a horse .
标签: 0
统计三个关系的数量
test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
print([[row for row in data[2]].count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]
上面训练集,下面测试集,挺平均的
data[2]是标签,统计标签数量就行
#@save
class SNLIDataset(torch.utils.data.Dataset):
"""用于加载SNLI数据集的自定义数据集"""
def __init__(self, dataset, num_steps, vocab=None):
self.num_steps = num_steps
all_premise_tokens = d2l.tokenize(dataset[0])
all_hypothesis_tokens = d2l.tokenize(dataset[1])
if vocab is None:
self.vocab = d2l.Vocab(all_premise_tokens + \
all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])
else:
self.vocab = vocab
self.premises = self._pad(all_premise_tokens)
self.hypotheses = self._pad(all_hypothesis_tokens)
self.labels = torch.tensor(dataset[2])
print('read ' + str(len(self.premises)) + ' examples')
def _pad(self, lines):
return torch.tensor([d2l.truncate_pad(
self.vocab[line], self.num_steps, self.vocab['<pad>'])
for line in lines])
def __getitem__(self, idx):
return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]
def __len__(self):
return len(self.premises)
vocab需要与BERT预训练时的vocab保持一致,不然他不认识呀,所以下载与训练模型的时候一般都是下载模型和vocab
整理一下
#@save
def load_data_snli(batch_size, num_steps=50):
"""下载SNLI数据集并返回数据迭代器和词表"""
num_workers = d2l.get_dataloader_workers()
data_dir = "D:\environment\data\data\snli_1.0"
train_data = read_snli(data_dir, True)
test_data = read_snli(data_dir, False)
train_set = SNLIDataset(train_data, num_steps)
test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size,
shuffle=True)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
shuffle=False)
return train_iter, test_iter, train_set.vocab
看看大小
train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples
18678
for X, Y in train_iter:
print(X[0].shape)
print(X[1].shape)
print(Y.shape)
break
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])
batchsize = 128
一个句子长度为50
最后说一下这一节的踩坑:
-
和之前一节的数据加载一样
torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers = num_workers )
num_workers = num_workers,开多线程读取数据会报错
-
下载解压数据集时报错
OSError: [Errno 22] Invalid argument: '..\\data\\snli_1.0\\Icon\r
报错的原因:SNLI数据集的压缩文件"snli_1.0.zip"里面有两个路径为“snli_1.0\Icon\r”和“’__MACOSX/snli_1.0/._Icon\r’”的文件,导致无法解析此路径进而导致整个文件无法解压。
解决方法:手动解压之后把data_dir赋值为数据集解压后的路径
data_dir = d2l.download_extract('SNLI')
改成
data_dir = "D:\environment\data\data\snli_1.0"