【Pytorch】【torchtext(三)】Dataset详解

一、概述

  1. 定义Field对象是明确如何处理不同类型的数据。而具体处理哪里的数据集,对不同的列采用不同的Field进行处理则是由Dataset对象来完成的。
  2. torchtext的Dataset对象继承自pytorch的Dataset对象,该对象提供了下载压缩数据并解压这些数据的方法。
  3. TabularDataset是torchtext内置的Dataset子类,其能够很方便的读取csv、json或tsv格式的文件。

二、使用TabluarDataset构建数据集

from torchtext.data import Field
from torchtext.data import TabularDataset
# 定义Field
tokenize = lambda x: x.split()
TEXT = Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
LABEL = Field(sequential=False, use_vocab=False)

# 构建Dataset
fields = [("id", None),("comment_text",TEXT),("toxic",LABEL)]
# 使用splits方法可以为多个数据集直接创建Dataset
train, valid = TabularDataset.splits(
    path='data',
    train='train_one_label.csv',
    validation='valid_one_label.csv',
    format='csv',
    skip_header=True,
    fields=fields)

test_datafields = [('id', None),('comment_text', TEXT)]

# 直接创建Dataset(不使用splits)
test = TabularDataset(
    path=r'data\test.csv',
    format='csv',
    skip_header=True,
    fields=test_datafields
)
print(train.fields)
print(train.examples[0].comment_text)
{'id': None, 'comment_text': <torchtext.data.field.Field object at 0x0000029BA6D245C0>, 'toxic': <torchtext.data.field.Field object at 0x0000029BA6D246A0>}
['explanation', 'why', 'the', 'edits', 'made', 'under', 'my', 'username', 'hardcore', 'metallica', 'fan', 'were', 'reverted?', 'they', "weren't", 'vandalisms,', 'just', 'closure', 'on', 'some', 'gas', 'after', 'i', 'voted', 'at', 'new', 'york', 'dolls', 'fac.', 'and', 'please', "don't", 'remove', 'the', 'template', 'from', 'the', 'talk', 'page', 'since', "i'm", 'retired', 'now.89.205.38.27']

三、自定义Dataset

1. Dataset的初始化方法

先了看一下原始Dataset的初始化方法Dataset(examples, fields, filter_pred=None),其中参数examples是一个包含对象torchtext.data.Example的列表,参数fieldsList(tuple(str, Field)),即[(列名,Field对象),(列名,Field对象),...]

2. TabularDataset的缺点

首先,没有进度条来显示进度;

其次,不能自动统计元数据(例如文本平均长度等);

最后,不能提供数据集的缓存功能;

3. 自定义Dataset解决TabularDataset的缺点

import os
import torch
import pandas as pd
from torchtext.data import Dataset,Example,Field
from tqdm import tqdm
tqdm.pandas(desc='pandas bar')

path='data'
train='train_one_label.csv'
validation='valid_one_label.csv'
test = 'test.csv'
tokenize = lambda x: x.split()
TEXT = Field(sequential=True, tokenize=tokenize, lower=True, fix_length=200)
LABEL = Field(sequential=False, use_vocab=False)
fields = [("id", None),("comment_text",TEXT),("toxic",LABEL)]
class MyDataset(Dataset):
    metadata = {} # 用来存储元数据
    
    def __init__(self, examples, fields, metadata, **kwargs):
        self.metadata = metadata
        super(MyDataset, self).__init__(examples, fields, **kwargs)
        
    @classmethod    
    def getInstance(cls, path, filename, fields):
        data = pd.read_csv(os.path.join(path, filename))
        metadata = cls._compute_metadata(data) # 计算元数据
        examples = list(data.progress_apply(lambda x:Example.fromCSV(list(x), fields),axis=1)) # 添加进度条
        return cls(examples,fields, metadata)
        
    @staticmethod    
    def _compute_metadata(data):
        """计算元数据"""
        metadata = {}
        metadata['columns'] = list(data.columns)
        metadata['avg_length'] = data['comment_text'].apply(lambda x: len(x.split())).mean()
        return metadata
        
    @staticmethod
    def save_cache(datasets, fields, datafiles, cachefile):
        """保存缓存数据"""
        examples = [dataset.examples for dataset in datasets]
        metadatas = [dataset.metadata for dataset in datasets]
        vocabs = {}
        reverse_fields = {}
        for name, field in fields.items():
            reverse_fields[field] = name
            
        for field, name in reverse_fields.items():
            if field is not None and hasattr(field, 'vocab'):
                vocabs[name] = field.vocab
                
        data = {'examples':examples, 'metadatas':metadatas, 'vocabs':vocabs, 'datafiles':datafiles}
        torch.save(data, cachefile)
    
    @staticmethod
    def load_cache(datafiles,cachefile, fields):
        """加载缓存数据"""
        cached_data = torch.load(cachefile)
        datasets = []
        for d in range(len(cached_data['datafiles'])):
            dataset = MyDataset(fields=fields,
                                examples=cached_data['examples'][d],
                                metadata=cached_data['metadatas'][d])
            datasets.append(dataset)
            
        for name, field in fields:
            if name in cached_data['vocabs']:
                field.vocab = cached_data['vocabs'][name]
        
        return datasets
            
             
    @classmethod
    def splits(cls,path,train=None,validation=None,test=None,fields=None,cache='cacheddata.pth'):
        fields_dict = dict(fields)
        datasets = None
        
        if cache:
            datafiles = list(f for f in (train, validation, test) if f is not None)
            datafiles = [os.path.expanduser(os.path.join(path, d)) for d in datafiles]
            cachefile = os.path.join(path, cache)
            if os.path.exists(cachefile):
                print("存在缓存数据,正在读取.")
                datasets = MyDataset.load_cache(datafiles, cachefile, fields)
        
        if not datasets:
            train_data = None if train is None else cls.getInstance(path, train, fields)
            valid_data = None if validation is None else cls.getInstance(path, validation, fields)
            test_data = None if test is None else cls.getInstance(path, test, fields)
            datasets = tuple(d for d in (train_data, valid_data, test_data) if d is not None)
            
            fields_set = set(fields_dict.values())
            for field in fields_set:
                if field is not None and field.use_vocab:
                    field.build_vocab(*datasets)
            print("正在保存缓存数据.")
            MyDataset.save_cache(datasets, fields_dict, datafiles, cachefile)
            
        return datasets
train_data, valid_data, test_data = MyDataset.splits(path, train, validation, test, fields)
存在缓存数据,正在读取.
Pytorch中遍历dataset可以使用torch.utils.data.DataLoader这个类。在初始化DataLoader时,一般常用的参数有dataset、batch_size、shuffle和num_workers等。其中dataset就是我们构建的自定义dataset类。在使用时,可以直接使用for循环来遍历dataloader对象,并且可以通过迭代器的方式输出每个batch的数据。具体实现如下: ```python import torch from torch.utils.data import DataLoader # 创建自定义dataset对象 dataset = MyDataset() # 创建dataloader对象,并指定batch_size和是否进行数据打乱 dataloader = DataLoader(dataset, batch_size=32, shuffle=True) # 遍历dataloader对象 for batch_data in dataloader: # 处理每个batch的数据 inputs, labels = batch_data # 进行模型的训练或预测等操作 ... ``` 在遍历dataloader时,实际上是从dataset中取出数据,只是在取数据的规则上进行了一些修改,比如可以进行数据的打乱操作。因此,在遍历dataloader时,会调用自己定义的dataset类中的__getitem__()方法来获取数据。通过这种方式,我们可以方便地对数据进行mini-batch的训练。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [使用Pytorch中的Dataset类构建数据集的方法及其底层逻辑](https://blog.csdn.net/rowevine/article/details/123631144)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* [对pytorch中的dataset和dataloader的一些理解](https://blog.csdn.net/weixin_45700881/article/details/128351086)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BQW_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值