I. Overview
- A Field object defines how each kind of data should be processed. Deciding *which* data to process, and which Field to apply to each column, is the job of the Dataset object.
- torchtext's Dataset inherits from PyTorch's Dataset and additionally provides methods for downloading compressed data and extracting it.
- TabularDataset is a built-in subclass of Dataset that makes it easy to read files in csv, json, or tsv format.
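As an aside, the download machinery works roughly as follows. This is a minimal sketch assuming the legacy torchtext 0.x API, where a Dataset subclass declares `urls` and `name` and calls `download()`; the URL and class name below are hypothetical.

from torchtext import data

# Sketch only: built-in datasets such as IMDB define these attributes
# and call download() from their splits() classmethod.
class MyCorpus(data.Dataset):
    urls = ['http://example.com/mycorpus.zip']  # hypothetical archive URL
    name = 'mycorpus'
    dirname = ''

# Downloads the archive into .data/mycorpus, extracts it, and returns the path:
# path = MyCorpus.download(root='.data')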
II. Building a dataset with TabularDataset
from torchtext.data import Field
from torchtext.data import TabularDataset

# Define the Fields
tokenize = lambda x: x.split()
TEXT = Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
LABEL = Field(sequential=False, use_vocab=False)

# Build the Datasets; columns mapped to None are skipped
fields = [("id", None), ("comment_text", TEXT), ("toxic", LABEL)]

# The splits method creates a Dataset for several files at once
train, valid = TabularDataset.splits(
    path='data',
    train='train_one_label.csv',
    validation='valid_one_label.csv',
    format='csv',
    skip_header=True,
    fields=fields)

# A separate fields list for the test file (it has no label column)
test_datafields = [('id', None), ('comment_text', TEXT)]

# Create a single Dataset directly (without splits)
test = TabularDataset(
    path=r'data\test.csv',
    format='csv',
    skip_header=True,
    fields=test_datafields)
print(train.fields)
print(train.examples[0].comment_text)

Output:
{'id': None, 'comment_text': <torchtext.data.field.Field object at 0x0000029BA6D245C0>, 'toxic': <torchtext.data.field.Field object at 0x0000029BA6D246A0>}
['explanation', 'why', 'the', 'edits', 'made', 'under', 'my', 'username', 'hardcore', 'metallica', 'fan', 'were', 'reverted?', 'they', "weren't", 'vandalisms,', 'just', 'closure', 'on', 'some', 'gas', 'after', 'i', 'voted', 'at', 'new', 'york', 'dolls', 'fac.', 'and', 'please', "don't", 'remove', 'the', 'template', 'from', 'the', 'talk', 'page', 'since', "i'm", 'retired', 'now.89.205.38.27']
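Before these datasets can be batched, the TEXT field still needs a vocabulary. A short follow-up to the code above (max_size is just an illustrative choice):

# Build TEXT's vocabulary from the training split only;
# LABEL has use_vocab=False, so it needs no vocabulary.
TEXT.build_vocab(train, max_size=10000)
print(len(TEXT.vocab))                   # vocabulary size (incl. <unk>/<pad>)
print(TEXT.vocab.freqs.most_common(5))   # five most frequent tokens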
III. Custom Dataset
1. The Dataset constructor
First, look at the constructor of the base Dataset class: Dataset(examples, fields, filter_pred=None). The examples argument is a list of torchtext.data.Example objects, and fields is a List(tuple(str, Field)), i.e. [(column name, Field object), (column name, Field object), ...].
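To make this signature concrete, here is a minimal sketch that builds a Dataset by hand from two made-up rows, using Example.fromlist:

from torchtext.data import Dataset, Example, Field

TEXT = Field(sequential=True, tokenize=lambda x: x.split(), lower=True)
LABEL = Field(sequential=False, use_vocab=False)
fields = [('comment_text', TEXT), ('toxic', LABEL)]

# Each Example is built from one row of raw values plus the fields list.
rows = [['You are great', 0], ['Totally awful stuff', 1]]  # made-up rows
examples = [Example.fromlist(row, fields) for row in rows]
dataset = Dataset(examples, fields)
print(dataset.examples[0].comment_text)  # ['you', 'are', 'great']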
2. Shortcomings of TabularDataset
First, it shows no progress bar while loading;
second, it cannot automatically collect metadata (e.g. the average text length);
finally, it provides no caching of the processed dataset.
3. A custom Dataset that fixes these shortcomings
import os

import pandas as pd
import torch
from torchtext.data import Dataset, Example, Field
from tqdm import tqdm

tqdm.pandas(desc='pandas bar')

path = 'data'
train = 'train_one_label.csv'
validation = 'valid_one_label.csv'
test = 'test.csv'

tokenize = lambda x: x.split()
TEXT = Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
LABEL = Field(sequential=False, use_vocab=False)
fields = [("id", None), ("comment_text", TEXT), ("toxic", LABEL)]

class MyDataset(Dataset):
    metadata = {}  # stores per-dataset metadata

    def __init__(self, examples, fields, metadata, **kwargs):
        self.metadata = metadata
        super(MyDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def getInstance(cls, path, filename, fields):
        data = pd.read_csv(os.path.join(path, filename))
        metadata = cls._compute_metadata(data)  # compute the metadata
        # progress_apply shows a tqdm progress bar while building the Examples
        examples = list(data.progress_apply(
            lambda x: Example.fromCSV(list(x), fields), axis=1))
        return cls(examples, fields, metadata)

    @staticmethod
    def _compute_metadata(data):
        """Compute metadata for a DataFrame."""
        metadata = {}
        metadata['columns'] = list(data.columns)
        metadata['avg_length'] = data['comment_text'].apply(lambda x: len(x.split())).mean()
        return metadata

    @staticmethod
    def save_cache(datasets, fields, datafiles, cachefile):
        """Save the datasets (and any built vocabs) to a cache file."""
        examples = [dataset.examples for dataset in datasets]
        metadatas = [dataset.metadata for dataset in datasets]
        # Several columns may share one Field object, so invert the mapping
        # first and keep a single vocab per distinct Field.
        vocabs = {}
        reverse_fields = {}
        for name, field in fields.items():
            reverse_fields[field] = name
        for field, name in reverse_fields.items():
            if field is not None and hasattr(field, 'vocab'):
                vocabs[name] = field.vocab
        data = {'examples': examples, 'metadatas': metadatas,
                'vocabs': vocabs, 'datafiles': datafiles}
        torch.save(data, cachefile)

    @staticmethod
    def load_cache(datafiles, cachefile, fields):
        """Load datasets from the cache file and restore the vocabs."""
        cached_data = torch.load(cachefile)
        datasets = []
        for d in range(len(cached_data['datafiles'])):
            dataset = MyDataset(fields=fields,
                                examples=cached_data['examples'][d],
                                metadata=cached_data['metadatas'][d])
            datasets.append(dataset)
        for name, field in fields:
            if name in cached_data['vocabs']:
                field.vocab = cached_data['vocabs'][name]
        return datasets

    @classmethod
    def splits(cls, path, train=None, validation=None, test=None,
               fields=None, cache='cacheddata.pth'):
        fields_dict = dict(fields)
        datasets = None
        if cache:
            datafiles = [f for f in (train, validation, test) if f is not None]
            datafiles = [os.path.expanduser(os.path.join(path, d)) for d in datafiles]
            cachefile = os.path.join(path, cache)
            if os.path.exists(cachefile):
                print("Found cached data, loading.")
                datasets = cls.load_cache(datafiles, cachefile, fields)
        if not datasets:
            train_data = None if train is None else cls.getInstance(path, train, fields)
            valid_data = None if validation is None else cls.getInstance(path, validation, fields)
            test_data = None if test is None else cls.getInstance(path, test, fields)
            datasets = tuple(d for d in (train_data, valid_data, test_data) if d is not None)
            # Build a vocabulary over all splits for every Field that needs one.
            fields_set = set(fields_dict.values())
            for field in fields_set:
                if field is not None and field.use_vocab:
                    field.build_vocab(*datasets)
            if cache:  # guard the save so cache=None does not raise a NameError
                print("Saving data to cache.")
                cls.save_cache(datasets, fields_dict, datafiles, cachefile)
        return datasets

train_data, valid_data, test_data = MyDataset.splits(path, train, validation, test, fields)
Output:
Found cached data, loading.
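The extra features can then be checked directly; for example:

# Metadata computed in _compute_metadata is attached to each dataset.
print(train_data.metadata)   # e.g. {'columns': [...], 'avg_length': ...}
# The vocab saved in (or restored from) the cache lives on the shared TEXT field.
print(len(TEXT.vocab))
# Deleting data/cacheddata.pth forces the splits to be rebuilt from the CSVs.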