利用Torchtext处理文本数据
参考资料:官方文档: pytorch.
torchtext
torchtext主要包含三部分,分别是Field、Dataset和Iterator三部分,每部分的作用如下:
- Field:配置对数据的预处理信息,比如指定分词方法、是否是序列、指定起始字符和结束字符等等。
- Dataset:用于加载数据集。
- Iterator:输出的迭代器,来把数据切分成batch_size来提供给模型作为输入。
接下来具体介绍方法应用,以及分别实战操作处理翻译和分类的数据。
Field
首先定义field,相当于声明式处理数据,来规定数据的结构应该是如何。一些重要的参数如下:
- sequential:是否把数据表示成序列,如果是False, 不能使用分词 默认值: True.
- init_token: 每一条数据的起始字符 默认值: None.
- eos_token: 每条数据的结尾字符 默认值: None.
- tokenize: 分词函数. 默认值: str.split.
- lower: 是否把数据转化为小写 默认值: False.
classification
首先,对于中文的序列数据,我们自定义一个中文的分词器,内在方法采用jieba分词
def word_cut(text):
return [word for word in jieba.cut(text) if word.strip()] #采用的jieba分词
接着预定义数据处理方法:
LABEL = data.Field(sequential=False, unk_token=None)
TEXT = data.Field(sequential=True, tokenize=word_cut)
在这里因为是分类问题,所以定义两个字段标签和文本,其中sequential定义是否是序列数据。注意在这里,需要定义unk_token为None,否则的话在最后的标签集中会多一个未知标签。
translation
同样的先定义分词方法,对于中文依旧采用jieba分词,而对于英文采用通用的spacy工具包的方法
def tokenize_en(text):
return [tok.text for tok in spacy_en.tokenizer(text)]
def word_cut(text):
text = regex.sub(' ', text)
return [word for word in jieba.cut(text) if word.strip()]
在翻译数据的字段预处理定义中,需要设定起始标识与末尾标识init_token和eos_token。
SRC = Field(
tokenize=tokenize_en,
init_token='<sos>',
eos_token='<eos>',
lower=True)
TRG = Field(
tokenize=word_cut,
init_token='<sos>',
eos_token='<eos>',
lower=True)
Dataset
torchtext含有许多自定义的数据可以直接下载应用,而对于表格数据数据采用TabularDataset,对于翻译数据的加载采用TranslationDataset方法。
classification
train, val = data.TabularDataset.splits(
path=args.path, format='tsv', skip_header=True,
train='train.tsv', validation='valid.tsv',
fields=[
('label', label_field),
('text', text_field)
]
)
- path:数据文件的路径参数
- format:文件的格式,tsv文件:表格形式,一行中的数据以空格间隔,csv文件:表格形式,一行中的数据以逗号间隔
- skip_header:是否忽略第一行,当表格数据第一行是标题时,可以设置为True,默认值是False
- train、validat、test:训练集、验证集和测试集对应的文件名称
- fields:给表格中的数据字段定义预处理,在这里第一列是标签所以定义label预处理,而第二列是文本所以定义text预处理。注意:这里必须按照表格里的列顺序定义数据!
translation
train, val, test = TranslationDataset.splits(path=r'./data',
exts=('.en', '.zh'),
fields=(src_field, trg_field))
首先对于翻译的数据准备应当是将成对的翻译数据分别存为两个文本文件,一个是源文本,一个是目标文本,然后它们的行数应当是一致的。最后将它们的文件后缀改为语言的简称。
- path:所有数据的父目录
- exts:规定源文本和目标文本
- field:对于文本的预处理定义
BucketIterator
在生成数据后,将生成vocab词表
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)
Bucketlterator的功能是将文本长度相似的数据作为一个batch,通过迭代器生成batch数据来喂给模型作为输入。
classification
train_iter, val_iter = BucketIterator.splits(
(train_dataset, val_dataset),
batch_sizes=(args.batch_size, len(val_dataset)),
sort_key=lambda x: len(x.text),
device=-1,
**kwargs)
- batch_size:设定迭代数据大小
- sort_key:设定排序函数
- device:表示数据在何种设备运行,-1指代CPU
translation
train_iter, val_iter, test_iter = BucketIterator.splits(
(train_data, val_data, test_data),
batch_size=args.batch_size,
device=args.device,
)
参数设定如上
通过以上步骤就得到了我们可以以直接输入到模型里的数据