torch.dataset的构建

1.数据转换类(TranslateData)

import random
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import sys
sys.path.append('.')
from vocabField import VocabField
# translate_data在dataset中使用
# collate_fn在dataloader中使用
class TranslateData():
    def __init__(self,pad = 0):
        self.pad = pad
    def collate_fn(self,batch):
        src = list(map(lambda x:x['src'],batch))
        tgt = list(map(lambda x:x['tgt'],batch))
        src_len = list(map(lambda x:x['src_len'],batch))
        tgt_len = list(map(lambda x:x['tgt_len'],batch))
        src = torch.transpose(pad_sequence(src,padding_value = self.pad),0,1)
        tgt = torch.transpose(pad_sequence(tgt,padding_value = self.pad),0,1)
        src_len = torch.stack(src_len)
        tgt_len = torch.stack(tgt_len)
        return {'src':src,'tgt':tgt,'src_len':src_len,'tgt_len':tgt_len}
    def translate_data(self,subs,obj):
        import re
        import unicodedata
        def unicodeToAscii(s):
            return ''.join(
                c for c in unicodedata.normalize('NFD',s) if unicodedata.category(c) != 'Mn'
            )
        def normalizeString(s):
            s = unicodeToAscii(s.lower().strip())
            s = re.sub(r'(.!?)',r'\1',s)
            s = re.sub(r'[^a-zA-Z.!?]+',r' ',s)
            return s
        src,tgt = subs
        src = normalizeString(src).split(' ')
        tgt = normalizeString(tgt).split(' ')
        tgt = [obj.tgt_vocab.sos_token] + tgt + [obj.tgt_vocab.eos_token]
        if len(src) > obj.max_src_length or len(tgt) > obj.max_tgt_length:
            return None
        src_length,tgt_length = len(src),len(tgt)
        src_ids = [obj.src_vocab.word2idx[w] for w in src]
        tgt_ids = [obj.tgt_vocab.word2idx[w] for w in tgt]
        return {
            'src':torch.LongTensor(src_ids),
            'tgt':torch.LongTensor(tgt_ids),
            'src_len':torch.LongTensor([src_length]),
            'tgt_len':torch.LongTensor([tgt_length])}

2.Dataset类(DialogDataset)

class DialogDataset(Dataset):
    def __init__(self,data_fp,transform_fuc,src_vocab,tgt_vocab,max_src_length,max_tgt_length):
        self.datasets = []
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        
        loaded = 0
        data_monitor = 0
        with open(data_fp,'r') as f:
            for line in tqdm(f,desc = 'Load Data:'):
                subs = line.strip().split('\t')
                loaded += 1
                if not data_monitor:
                    data_monitor = len(subs)
                else:
                    assert data_monitor == len(subs)
                item = transform_fuc(subs,self)
                if item:
                    self.datasets.append(item)
        print(f"{loaded} paris loaded. {len(self.datasets)} are valid. Rate {1.0 * len(self.datasets)/loaded:.4f}")
    def __len__(self):
        return len(self.datasets)
    def __getitem__(self,idx):
        return self.datasets[idx]

3.测试

train_path = '../../data/fra2eng/fra_eng.dev'
dev_path = '../../data/fra2eng/fra_eng.dev'
src_vocab_file = '../../data/fra2eng/src_vocab_file'
tgt_vocab_file = '../../data/fra2eng/tgt_vocab_file'
src_vocab_size = 40000
tgt_vocab_size = 40000
max_src_length = 50
max_tgt_length = 50
batch_size = 20
src_vocab_list = VocabField.load_vocab(src_vocab_file)
tgt_vocab_list = VocabField.load_vocab(tgt_vocab_file)
src_vocab = VocabField(src_vocab_list,vocab_size = src_vocab_size)
tgt_vocab = VocabField(tgt_vocab_list,vocab_size = tgt_vocab_size)
pad_id = tgt_vocab.word2idx[tgt_vocab.pad_token]
trans_data = TranslateData()
train_set = DialogDataset(
    train_path,
    trans_data.translate_data,
    src_vocab,
    tgt_vocab,
    max_src_length = max_src_length,
    max_tgt_length = max_tgt_length
)
trainloader = DataLoader(
    train_set,
    batch_size = 20,
    shuffle = False,
    drop_last = True,
    collate_fn = trans_data.collate_fn
)
#list(trainloader)[0]
dev_set = DialogDataset(
    dev_path,
    trans_data.translate_data,
    src_vocab,
    tgt_vocab,
    max_src_length = max_src_length,
    max_tgt_length = max_tgt_length
)
dev_loader = DataLoader(
    dev_set,
    batch_size = 15,
    shuffle = False,
    collate_fn = trans_data.collate_fn
)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 答:torch.utils.data.datasetPyTorch中用于加载和预处理数据的一种常用数据集类型,它可以帮助您定义数据和标签的转换方式,以及如何以批量方式加载数据。它还可以自动以多核方式加载数据,并可以分片以支持分布式训练。 ### 回答2: torch.utils.data.datasetPyTorch中的一个重要模块,用于处理和加载数据集。它提供了一种简单的方式来组织和操作数据集,并配合使用torch.utils.data.DataLoader来实现数据的批量加载,可方便地用于训练深度学习模型。 torch.utils.data.dataset有两个主要的子类:Dataset和IterableDatasetDataset是所有torch.utils.data.dataset类的基类,它表示一个数据集,提供了获取数据的接口。我们可以通过继承Dataset类并实现其抽象方法__getitem__()和__len__()来自定义自己的数据集。 __getitem__()方法的作用是根据给定的索引返回一个样本,我们可以通过索引或切片符号对数据集进行访问。__len__()方法返回数据集的总样本数,方便后续使用。通过这两个方法,我们可以很方便地获取到数据集中的样本。 IterableDataset则是实现了__iter__()方法的数据集,可以按需生成样本数据,而不需要事先明确知道样本数量。这在一些需要动态生成的数据集中非常有用,比如使用迭代方法读取大型数据集。 除了这两个主要的子类,torch.utils.data.dataset还提供了一些常用的数据集,如TensorDataset、Subset等,它们提供了更加灵活和方便的数据加载方式。 总结来说,torch.utils.data.dataset模块为我们提供了一种高效、灵活的方式来组织和加载数据集,为深度学习的训练和评估提供了基础支持。通过自定义数据集和使用现有的数据集类,我们可以轻松地准备和加载训练数据,方便地进行模型训练。 ### 回答3: torch.utils.data.datasetPyTorch提供的一个用于处理数据集的工具类,主要用于将数据集加载到PyTorch中进行训练和测试。 torch.utils.data.dataset是一个抽象类,我们可以通过继承这个类来创建自定义的数据集类。我们需要重写两个方法:__len__()和__getitem__()。__len__()方法返回数据集中样本的数量,而__getitem__()方法通过索引可以返回对应位置的样本。 通过实现自定义的数据集类,我们可以将数据集加载到PyTorch中。然后,我们可以使用torch.utils.data.DataLoader来实现数据集的批量加载和并行处理。DataLoader是一个迭代器,可以按批次加载数据,并提供多个线程进行数据预处理。 torch.utils.data.dataset还提供了一些常用的数据集类,例如torchvision.datasets.ImageFolder用于处理图像数据集,能够根据文件夹的层级结构自动将图像和对应的标签进行加载。torchvision.datasets.CIFAR10和torchvision.datasets.CIFAR100用于加载CIFAR-10和CIFAR-100数据集。 使用torch.utils.data.dataset的好处是它提供了灵活的接口,可以方便地加载和处理各种类型的数据集。同时,它能够与其他PyTorch组件无缝配合,如模型和优化器,从而构建完整的训练流程。这样,我们可以更方便地进行数据处理和模型训练,提高开发效率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值