Original link: Text classification with the torchtext library — PyTorch Tutorials 1.11.0+cu102 documentation
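(1) Import the dataset. The download often fails, so someone has shared a copy on Baidu Netdisk: https://pan.baidu.com/s/1Rz_XoaTZWSRiHGOwkACosQ, extraction code: j0no
After downloading, put it in the directory where Jupyter Notebook is currently open; the path only needs to point to the AG_NEWS.data folder.
(Recent versions seem to require root='path'; otherwise an error is raised.)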
import torch
from torchtext.datasets import AG_NEWS
path = r'E:\Notebook\自然语言处\Text_classification_with_the_torchtext_library\AG_NEWS.data'
train_iter = iter(AG_NEWS(root=path, split='train'))
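To see what kind of items the iterator yields without downloading anything, here is a minimal sketch that parses two made-up rows in the AG_NEWS CSV layout (class index, title, description). The sample rows and the helper name `iter_samples` are hypothetical, and joining title and description with a space is an assumption about how torchtext builds the text field, so check it against your installed version.

```python
import csv
import io

# Two made-up rows in the AG_NEWS layout: "class index","title","description"
SAMPLE_CSV = (
    '"3","Wall St. rallies","Stocks closed higher on Friday."\n'
    '"4","New chip announced","The processor targets laptops."\n'
)

def iter_samples(csv_text):
    """Yield (label, text) pairs, one per CSV row (hypothetical helper)."""
    for row in csv.reader(io.StringIO(csv_text)):
        label = int(row[0])        # class index, 1 to 4 in AG_NEWS
        text = " ".join(row[1:])   # assumed: title and description joined by a space
        yield label, text

samples = list(iter_samples(SAMPLE_CSV))
print(samples[0])
```

Each item is a (label, text) pair, which is the shape the later steps rely on when they unpack `for _, text in data_iter`.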
(2) Build the vocabulary
from torchtext.data.utils import get_tokenizer  # tokenizer factory
from torchtext.vocab import build_vocab_from_iterator  # builds a vocabulary from an iterator
tokenizer = get_tokenizer('basic_english')  # create a tokenizer for basic English text
train_iter = AG_NEWS(root=path, split='train')  # load the training split as an iterator
def yield_tokens(data_iter):
    for _, text in data_iter:  # each item is a (label, text) pair
        yield tokenizer(text)  # tokenize the text; yield makes this function a generator
# Register <unk> as a special token for unrecognized words
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
# Make the index of <unk> the default, so any out-of-vocabulary word maps to it
vocab.set_default_index(vocab['<unk>'])
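To make the idea behind the vocabulary concrete, here is a minimal pure-Python sketch of the same technique: count tokens, put the special tokens first, then assign indices by descending frequency, and fall back to the `<unk>` index for unknown words. The names `build_toy_vocab` and `lookup` are hypothetical, and the tie-breaking order is an assumption; this is not torchtext's actual implementation.

```python
from collections import Counter

def build_toy_vocab(token_lists, specials=("<unk>",)):
    """Map tokens to indices: specials first, then tokens by descending frequency."""
    counter = Counter()
    for tokens in token_lists:
        counter.update(tokens)
    # Sort by frequency (high to low), breaking ties alphabetically (assumed order)
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    itos = list(specials) + [tok for tok, _ in ordered if tok not in specials]
    return {tok: idx for idx, tok in enumerate(itos)}

def lookup(stoi, tokens, default_token="<unk>"):
    """Return indices, using the default token's index for unknown words."""
    default_index = stoi[default_token]
    return [stoi.get(tok, default_index) for tok in tokens]

corpus = [["the", "cat", "sat"], ["the", "dog", "sat"], ["the", "end"]]
toy_vocab = build_toy_vocab(corpus)
indices = lookup(toy_vocab, ["the", "cat", "flew"])
print(indices)  # [1, 3, 0]
```

The word "flew" never appears in the toy corpus, so it maps to index 0, the index of `<unk>`. This is the role `set_default_index` plays for the real vocabulary.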
(3) Get the label and text of each sample
text_pipeline = lambda x: vocab(tokenizer(x))  # convert a text into its list of vocabulary indices
label_pipeline = lambda x: int(x) - 1  # convert a label to a 0-based class index
# Demo
text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5297]
label_pipeline('10')
>>> 9
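The two pipelines can also be tried without torchtext. The sketch below is a toy version under stated assumptions: `toy_stoi` is a made-up vocabulary, and `toy_tokenizer` only lowercases and splits on whitespace, which is much simpler than the 'basic_english' tokenizer. All names prefixed with `toy_` are hypothetical.

```python
# Hypothetical toy vocabulary; index 0 is reserved for unknown words
toy_stoi = {"<unk>": 0, "here": 1, "is": 2, "an": 3, "example": 4}

def toy_tokenizer(text):
    """Lowercase and split on whitespace (a stand-in for 'basic_english')."""
    return text.lower().split()

def toy_text_pipeline(text):
    """Convert raw text into a list of vocabulary indices."""
    unk_index = toy_stoi["<unk>"]
    return [toy_stoi.get(tok, unk_index) for tok in toy_tokenizer(text)]

def toy_label_pipeline(label):
    """AG_NEWS labels run from 1 to 4; subtract 1 to get 0-based class ids."""
    return int(label) - 1

text_ids = toy_text_pipeline("Here is an unseen example")
label_id = toy_label_pipeline("4")
print(text_ids)  # [1, 2, 3, 0, 4]
print(label_id)  # 3
```

Subtracting 1 matters because AG_NEWS labels start at 1, while classification losses such as cross-entropy in PyTorch expect class indices starting at 0.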