I. Overview
1. Main components of torchtext
torchtext's main components are Field, Dataset, and Iterator.
1.1 Field
A Field is an object that defines how a column of raw data is processed; the processing steps are specified through its constructor arguments, and Fields are used to produce Example objects. Here is an example of defining a Field:
TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
1.2 Dataset
Inherits from PyTorch's Dataset and represents a dataset. A torchtext Dataset can be viewed as a collection of Example instances.
1.3 Iterator
An Iterator is what torchtext hands to the model: it provides common data-handling operations such as shuffling and sorting, and the batch size can be changed dynamically.
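The three components form a pipeline: a Field tokenizes and numericalizes text, a Dataset holds the resulting Examples, and an Iterator batches them for the model. A minimal pure-Python sketch of that flow (the function names and vocabulary here are illustrative, not torchtext's actual API):

```python
# Illustrative sketch of the Field -> Dataset -> Iterator pipeline.
# Names and the tiny vocabulary are made up for clarity.

def tokenize(text):
    return text.lower().split()                   # Field: preprocessing

def numericalize(tokens, stoi):
    return [stoi.get(t, 0) for t in tokens]       # Field: tokens -> ids (<unk> = 0)

examples = ["Hello torchtext", "Hello world"]     # Dataset: raw examples
stoi = {"<unk>": 0, "<pad>": 1, "hello": 2, "torchtext": 3, "world": 4}

batch = [numericalize(tokenize(x), stoi) for x in examples]  # Iterator: one batch
print(batch)  # [[2, 3], [2, 4]]
```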
II. Quick Start
import pandas as pd
import torch
from torchtext import data
from torchtext.vocab import Vectors
from torchtext.data import TabularDataset, Dataset, BucketIterator, Iterator
from torch.nn import init
from tqdm import tqdm
2.1 Inspect the data format
df1 = pd.read_csv('./data/train_one_label.csv').head()
df2 = pd.read_csv('./data/test.csv').head()
display(df1)
display(df2)
| | id | comment_text | toxic |
|---|---|---|---|
| 0 | 0000997932d777bf | Explanation\nWhy the edits made under my usern... | 0 |
| 1 | 000103f0d9cfb60f | D'aww! He matches this background colour I'm s... | 0 |
| 2 | 000113f07ec002fd | Hey man, I'm really not trying to edit war. It... | 0 |
| 3 | 0001b41b1c6bb37e | "\nMore\nI can't make any real suggestions on ... | 0 |
| 4 | 0001d958c54c6e35 | You, sir, are my hero. Any chance you remember... | 0 |
| | id | comment_text |
|---|---|---|
| 0 | 00001cee341fdb12 | Yo bitch Ja Rule is more succesful then you'll... |
| 1 | 0000247867823ef7 | == From RfC == \n\n The title is fine as it is... |
| 2 | 00013b17ad220c46 | " \n\n == Sources == \n\n * Zawe Ashton on Lap... |
| 3 | 00017563c3f7919a | :If you have a look back at the source, the in... |
| 4 | 00017695ad8997eb | I don't anonymously edit articles at all. |
2.2 Define the Fields
tokenize = lambda x: x.split()  # tokenize specifies how to split a sentence into tokens
# Define two Fields: one for the comment text and one for the label
TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
LABEL = data.Field(sequential=False, use_vocab=False)
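With sequential=True and fix_length=200, TEXT will later truncate or right-pad every tokenized comment to exactly 200 tokens when numericalizing ('<pad>' is the Field's default pad token). A rough sketch of that behavior, using a helper name of our own, not torchtext's:

```python
def pad_to_fixed(tokens, fix_length, pad_token='<pad>'):
    """Truncate or right-pad a token list to exactly fix_length items."""
    return tokens[:fix_length] + [pad_token] * max(0, fix_length - len(tokens))

tokenize = lambda x: x.split()
print(pad_to_fixed(tokenize("so bad"), 5))  # ['so', 'bad', '<pad>', '<pad>', '<pad>']
```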
2.3 Build the Dataset
fields = [("id", None), ("comment_text", TEXT), ("toxic", LABEL)]  # column names paired with their Field objects
# TabularDataset: reads data from csv, tsv, or json files and builds a dataset
train, valid = TabularDataset.splits(
    path='data',
    train='train_one_label.csv',
    validation='valid_one_label.csv',
    format='csv',
    skip_header=True,
    fields=fields)
test_datafields = [("id", None), ("comment_text", TEXT)]
test = TabularDataset(
    path='data/test.csv',
    format='csv',
    skip_header=True,
    fields=test_datafields
)
print(type(train))
<class 'torchtext.data.dataset.TabularDataset'>
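Pairing a column with None (as done for id above) tells TabularDataset to drop that column entirely: each csv row becomes an Example carrying only the named columns. A rough stdlib-only sketch of that mapping (the in-memory csv and placeholder field objects are ours, not torchtext code):

```python
import csv
import io

# (column name, Field) pairs; None means the column is discarded
fields = [("id", None), ("comment_text", "TEXT"), ("toxic", "LABEL")]

data = io.StringIO('id,comment_text,toxic\n0000997932d777bf,"Explanation...",0\n')
reader = csv.reader(data)
next(reader)  # skip_header=True

examples = []
for row in reader:
    # keep only the columns whose Field is not None
    examples.append({name: val
                     for (name, field), val in zip(fields, row)
                     if field is not None})

print(examples[0])  # {'comment_text': 'Explanation...', 'toxic': '0'}
```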
Build the vocabulary
TEXT.build_vocab(train, valid, test, vectors='glove.6B.100d')
print(TEXT.vocab.stoi['<pad>'])
print(TEXT.vocab.stoi['<unk>'])
print(TEXT.vocab.itos[0])
print(TEXT.vocab.freqs.most_common(5))
print(vars(train.examples[0]))
1
0
<unk>
[('the', 226), ('to', 137), ('a', 90), ('is', 84), ('you', 82)]
{'comment_text': ['explanation', 'why', 'the', 'edits', 'made', 'under', 'my', 'username', 'hardcore', 'metallica', 'fan', 'were', 'reverted?', 'they', "weren't", 'vandalisms,', 'just', 'closure', 'on', 'some', 'gas', 'after', 'i', 'voted', 'at', 'new', 'york', 'dolls', 'fac.', 'and', 'please', "don't", 'remove', 'the', 'template', 'from', 'the', 'talk', 'page', 'since', "i'm", 'retired', 'now.89.205.38.27'], 'toxic': '0'}
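The output above shows why stoi['&lt;unk&gt;'] is 0 and stoi['&lt;pad&gt;'] is 1: by default build_vocab places the special tokens first (the unk token, then the pad token), followed by corpus tokens in descending frequency. A sketch of that ordering with collections.Counter (the toy corpus is ours):

```python
from collections import Counter

corpus = "the cat sat on the mat the cat".split()
counter = Counter(corpus)

specials = ['<unk>', '<pad>']  # specials come first, in this order
itos = specials + [tok for tok, _ in counter.most_common()]
stoi = {tok: idx for idx, tok in enumerate(itos)}

print(stoi['<unk>'], stoi['<pad>'])  # 0 1
print(itos[:4])  # ['<unk>', '<pad>', 'the', 'cat']
```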
2.4 Create the iterators
train_iter, valid_iter = BucketIterator.splits(
    (train, valid),
    batch_sizes=(8, 8),
    device="cpu",
    sort_key=lambda x: len(x.comment_text),
    sort_within_batch=False,
    repeat=False
)
test_iter = Iterator(test, batch_size=8, device="cpu", sort=False, sort_within_batch=False, repeat=False)
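The sort_key is what makes BucketIterator useful: it groups examples of similar length into the same batch so each batch needs less padding. A simplified illustration of that bucketing idea (sort globally by length, then chunk; real BucketIterator shuffles buckets too):

```python
# Simplified length-bucketing: sort by token count, then chunk into batches.
texts = ["a b c d e", "a", "a b", "a b c", "a b c d", "a b c d e f"]
tokenized = [t.split() for t in texts]

by_len = sorted(tokenized, key=len)  # plays the role of sort_key
batch_size = 3
batches = [by_len[i:i + batch_size] for i in range(0, len(by_len), batch_size)]

# Each batch only needs padding up to its own longest example.
print([max(len(x) for x in b) for b in batches])  # [3, 6]
```

Without bucketing, a batch mixing the 1-token and 6-token examples would pad everything to length 6.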
Iterate over the batches
for idx, batch in enumerate(train_iter):
    print(batch)
    print(batch.__dict__.keys())
    text, label = batch.comment_text, batch.toxic
    print(text.shape, label.shape)
[torchtext.data.batch.Batch of size 1]
[.comment_text]:[torch.LongTensor of size 200x1]
[.toxic]:[torch.LongTensor of size 1]
dict_keys(['batch_size', 'dataset', 'fields', 'input_fields', 'target_fields', 'comment_text', 'toxic'])
torch.Size([200, 1]) torch.Size([1])
[torchtext.data.batch.Batch of size 8]
[.comment_text]:[torch.LongTensor of size 200x8]
[.toxic]:[torch.LongTensor of size 8]
dict_keys(['batch_size', 'dataset', 'fields', 'input_fields', 'target_fields', 'comment_text', 'toxic'])
torch.Size([200, 8]) torch.Size([8])
[torchtext.data.batch.Batch of size 8]
[.comment_text]:[torch.LongTensor of size 200x8]
[.toxic]:[torch.LongTensor of size 8]
dict_keys(['batch_size', 'dataset', 'fields', 'input_fields', 'target_fields', 'comment_text', 'toxic'])
torch.Size([200, 8]) torch.Size([8])
[torchtext.data.batch.Batch of size 8]
[.comment_text]:[torch.LongTensor of size 200x8]
[.toxic]:[torch.LongTensor of size 8]
dict_keys(['batch_size', 'dataset', 'fields', 'input_fields', 'target_fields', 'comment_text', 'toxic'])
torch.Size([200, 8]) torch.Size([8])