代码
import os
import re
import torch
from torch import nn
from d2l import torch as d2l
#@save
d2l.DATA_HUB['SNLI'] = (
'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
'9fcde07509c7e87ec61c640c1b2753d9041758e4')
data_dir = d2l.download_extract('SNLI')
#@save
def read_snli(data_dir, is_train):
"""将SNLI数据集解析为前提、假设和标签"""
def extract_text(s):
# 删除我们不会使用的信息
s = re.sub('\\(', '', s)
s = re.sub('\\)', '', s)
# 用一个空格替换两个或多个连续的空格
s = re.sub('\\s{2,}', ' ', s)
return s.strip()
label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
if is_train else 'snli_1.0_test.txt')
with open(file_name, 'r') as f:
rows = [row.split('\t') for row in f.readlines()[1:]]
premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
hypotheses = [extract_text(row[2]) for row in rows if row[0] \
in label_set]
labels = [label_set[row[0]] for row in rows if row[0] in label_set]
return premises, hypotheses, labels
train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
print('前提:', x0)
print('假设:', x1)
print('标签:', y)
test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
print([[row for row in data[2]].count(i) for i in range(3)])
#@save
class SNLIDataset(torch.utils.data.Dataset):
"""用于加载SNLI数据集的自定义数据集"""
def __init__(self, dataset, num_steps, vocab=None):
self.num_steps = num_steps
all_premise_tokens = d2l.tokenize(dataset[0])
all_hypothesis_tokens = d2l.tokenize(dataset[1])
if vocab is None:
self.vocab = d2l.Vocab(all_premise_tokens + \
all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])
else:
self.vocab = vocab
self.premises = self._pad(all_premise_tokens)
self.hypotheses = self._pad(all_hypothesis_tokens)
self.labels = torch.tensor(dataset[2])
print('read ' + str(len(self.premises)) + ' examples')
def _pad(self, lines):
return torch.tensor([d2l.truncate_pad(
self.vocab[line], self.num_steps, self.vocab['<pad>'])
for line in lines])
def __getitem__(self, idx):
return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]
def __len__(self):
return len(self.premises)
#@save
def load_data_snli(batch_size, num_steps=50):
"""下载SNLI数据集并返回数据迭代器和词表"""
num_workers = d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_data = read_snli(data_dir, True)
test_data = read_snli(data_dir, False)
train_set = SNLIDataset(train_data, num_steps)
test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size,
shuffle=True,
num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
shuffle=False,
num_workers=num_workers)
return train_iter, test_iter, train_set.vocab
train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
for X, Y in train_iter:
print(X[0].shape)
print(X[1].shape)
print(Y.shape)
break
解析
这段代码是用于加载和处理Stanford Natural Language Inference (SNLI) 数据集的一个Python脚本,使用PyTorch框架。以下是对代码进行逐段的中文解释:
1. 导入所需的模块和库。
import os
import re
import torch
from torch import nn
from d2l import torch as d2l
2. 定义数据集的链接和哈希值,利用d2l的`DATA_HUB`工具下载和提取数据。
d2l.DATA_HUB['SNLI'] = (
'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
'9fcde07509c7e87ec61c640c1b2753d9041758e4')
data_dir = d2l.download_extract('SNLI')
3. 实现`read_snli`函数,用于读取SNLI数据集,并将其解析为前提(premises)、假设(hypotheses)以及对应的标签(labels)。
def read_snli(data_dir, is_train):
...
4. 加载训练数据集,并显示前三个例子的前提、假设和标签。
train_data = read_snli(data_dir, is_train=True)
...
5. 加载测试数据集,并统计每个标签(蕴含、矛盾、中立)的数量。
test_data = read_snli(data_dir, is_train=False)
...
6. 定义`SNLIDataset`类,该类为PyTorch的自定义数据集类,用于处理和填充SNLI数据集。
class SNLIDataset(torch.utils.data.Dataset):
...
7. 定义`load_data_snli`函数,它使用上面定义的`SNLIDataset`类来创建训练和测试数据迭代器,同时返回构建的词汇表(vocab)。
def load_data_snli(batch_size, num_steps=50):
...
8. 创建训练和测试迭代器,并获取词汇表的长度。
train_iter, test_iter, vocab = load_data_snli(128, 50)
9. 遍历训练迭代器中的第一个batch,打印出前提、假设的Tensor形状以及标签的Tensor形状。
for X, Y in train_iter:
...
break
值得注意的是,这个代码的部分注释标签是`#@save`,这是d2l包中用于标记某个代码块或函数将会被保存并在后续使用中重新加载的一个特殊注释。
这段代码与自然语言推断(Natural Language Inference,NLI)任务相关,它是自然语言处理(Natural Language Processing,NLP)的一个子领域。NLI 的目标是确定一对句子的关系,即前提(premise)和假设(hypothesis)之间是否存在蕴含、矛盾或者无关的关系。这段代码特别处理了 Stanford Natural Language Inference (SNLI) 数据集,该数据集是用于训练和评估 NLI 模型的常用数据集。
现在,我将逐步解析这段代码并以中文解释其功能:
1. 首先,代码通过从网址下载并解压缩 SNLI 数据集。
2. 定义了 read_snli 函数,该函数的目的是从下载的文件中读取数据。数据会被分为三部分:前提、假设和相关的标签(蕴含、矛盾或中立)。
3. 接下来,定义了 SNLIDataset 类,一个用于处理 SNLI 数据的 PyTorch Dataset。该类会对文本数据进行分词,构建一个词汇表,并将文本转换成适用于机器学习模型的格式。文本数据会被截断或补齐到一定长度。
4. 而后,定义了 load_data_snli 函数,它调用 read_snli 函数和 SNLIDataset 类来准备数据迭代器,提供了一种方式来批量获取数据。
5. 最后,代码展示了如何使用这些函数和类来读取数据和词汇表,以及如何从 train_iter 中获取一个批量的数据。
在代码中会看到:
- train_data 和 test_data 是从 SNLI 数据集中分别读取的训练和测试数据。
- 其中每一个 data 是三元组,包含前提、假设和标签的列表。
- 使用 zip 函数来打印前三个数据点的前提、假设和标签,以及标签在整个数据集的分布。
- 创建了 SNLIDataset 实例,用于训练和测试数据集,并通过 DataLoader 实现了批量数据迭代。
- vocab 是从训练集构建的词汇表,使用 len(vocab) 可以获取词汇表的大小。
- 最后,可以通过 train_iter 迭代器获取每批数据,并打印出前提、假设和标签的张量尺寸。
需要注意的是,运行这段代码前需要确保具备访问 Stanford 的网站权限并安装了相关的 Python 包,比如 PyTorch 和 d2l 包。d2l 包是用于深度学习的一个辅助库,提供了一些常用的函数和类。