TorchText
最近开始使用PyTorch进行NLP神经网络模型的搭建,发现了torchtext这一文本处理神器,可以方便的对文本进行预处理,例如截断补长、构建词表等。但是因为nlp的热度远不如cv,对于torchtext介绍的相关博客数量也远不如torchvision。在使用过程中主要参考了A Comprehensive Introduction to Torchtext和Language modeling tutorial in torchtext这两篇博客和torchtext官方文档,对于torchtext的基本用法有了大致的了解。在我的另一篇博客:PyTorch在NLP任务中使用预训练词向量中也涉及了一点torchtext的用法。在以上两篇博客的基础上,本文对torchtext的使用做一个概括性的总结,更复杂高级的用法仍然推荐大家阅读官方文档。本文所涉及的完整可运行代码见:https://github.com/atnlp/torchtext-summary
torchtext概述
torchtext对数据的处理可以概括为Field,Dataset和迭代器这三部分。
Field对象
Field对象指定要如何处理某个字段.
Dataset
Dataset定义数据源信息.
迭代器
迭代器返回模型所需要的处理后的数据.迭代器主要分为Iterator, BucketIerator, BPTTIterator三种。
- Iterator:标准迭代器
- BucketIerator:相比于标准迭代器,会将类似长度的样本当做一批来处理,因为在文本处理中经常会需要将每一批样本长度补齐为当前批中最长序列的长度,因此当样本长度差别较大时,使用BucketIerator可以带来填充效率的提高。除此之外,我们还可以在Field中通过fix_length参数来对样本进行截断补齐操作。
- BPTTIterator: 基于BPTT(基于时间的反向传播算法)的迭代器,一般用于语言模型中。
使用Dataset类
实验数据集仍然使用A Comprehensive Introduction to Torchtext中使用的小批量数据集,为了简化代码,只保留了toxic这一个标签列。
- 查看数据集
- 导入torchtext相关包
from torchtext import data
from torchtext.vocab import Vectors
from torch.nn import init
from tqdm import tqdm
- 构建Field对象
tokenize = lambda x: x.split()
# fix_length指定了每条文本的长度,截断补长
TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
LABEL = data.Field(sequential=False, use_vocab=False)
- 使用torchtext内置的Dataset构建数据集
torchtext预置的Dataset类的API如下,我们必须至少传入examples和fields这两个参数。examples为由torchtext中的Example对象构造的列表,Example为对数据集中一条数据的抽象。fields可简单理解为每一列数据和Field对象的绑定关系,在下面的代码中将分别用train_examples和test_examples来构建训练集和测试集的examples对象,train_fields和test_fields数据集的fields对象。
class torchtext.data.Dataset(examples, fields, filter_pred=None)
# 读取数据
train_data = pd.read_csv('data/train_one_label.csv')
valid_data = pd.read_csv('data/valid_one_label.csv')
test_data = pd.read_csv("data/test.csv")
TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True)
LABEL = data.Field(sequential=False, use_vocab=False)
# get_dataset构造并返回Dataset所需的examples和fields
def get_dataset(csv_data, text_field, label_field, test=False):
# id数据对训练在训练过程中没用,使用None指定其对应的field
fields = [("id", None), # we won't be needing the id, so we pass in None as the field
("comment_text", text_field), ("toxic", label_field)]
examples = []
if test:
# 如果为测试集,则不加载label
for text in tqdm(csv_data['comment_text']):
examples.append(data.Example.fromlist([None, text, None], fields))
else:
for text, label in tqdm(zip(csv_data['comment_text'], csv_data[