(4)Dataset和Dataloader详解

0 数据处理流程

在Pytorch中使用Dataset和Dataloader对数据进行处理,便于后续将数据输入到模型当中进行计算和分类等操作。首先让我们来看一下数据在这两个类中是如何流转的,这能够帮助我们更好的了解Dataset和Dataloader的作用方式。流程如下图所示:
在这里插入图片描述
在调用Dataloader的__iter__时,会产生一个迭代器DataloaderIter,进而出发了__next__函数在每一次的迭代中循环取出数据,并转入到Dataset中对数据进行处理,也就是使用__getitem____len__函数,进而将处理好的数据打包返回到模型中进行计算。

1 Dataset

torch.utils.data.Dataset是pytorch中预定义好的一个基类。可以直接调用进行使用,也可以通过重写__getitem____len__函数来实现自己的数据集。

1.1__getitem__

用于实现索引操作,它定义了通过索引访问数据集中的样本的行为。当我们使用索引操作符 [] 时,实际上是调用了 __getitem__ 函数。
__getitem__ 函数接收一个索引作为参数,并返回对应索引位置的数据样本。这个索引可以是整数、切片或任何其他适用的索引形式,具体取决于数据集的实现。例如,对于图像数据集,索引可能是对应图像的编号或者切片来指定一组图像。
__getitem__ 函数内部,通常会包含以下步骤:

  1. 根据索引获取数据样本:根据传入的索引,__getitem__函数会从数据集中获取对应的数据样本。这可以是图像、文本、语音等等,根据数据集的特点来决定。
  2. 数据转换(可选):一旦获得数据样本,__getitem__函数可以对其进行预处理或转换,以便将其转换为模型所需的输入形式。例如,对图像进行缩放、标准化,对文本进行分词、编码等。这个过程是根据实际需求来决定的。
  3. 返回数据样本:最后,__getitem__ 函数会返回已处理的数据样本,供调用者使用。通常情况下,返回的样本是一个特征向量或张量。

1.2__len__

在 Dataset 中,__len__ 函数用于返回数据集中样本的数量。它定义了获取数据集长度的行为,主要是为了支持 len() 函数的调用。
当我们使用 len() 函数来获取数据集的长度时,实际上是调用了 __len__ 函数。__len__ 函数没有任何参数,仅返回一个整数,表示数据集中的样本数量。

2 Dataloader

与Dataset相同,Dataloader类也是Pytorch预定义好的一个基类,是Dataset和sampler的组合。
在使用 DataLoader 时,可以传递一些参数来配置数据加载和批处理的方式。下面是常见的参数及其作用:

  1. dataset:指定要加载的数据集。这通常是一个实现了 Dataset 接口的对象。
  2. batch_size:指定每个批次中的样本数量。DataLoader 会按照指定的批次大小从数据集中一次性加载一批样本。
  3. shuffle:指定是否在每个 epoch(数据集遍历一次)之前打乱数据集的顺序。这在训练模型时很常见,以避免模型对样本的顺序产生依赖。
  4. num_workers:指定用于数据加载的子进程数量。通过提供多个子进程,可以并行地加载和预处理数据,从而加快数据加载的速度。
  5. collate_fn:指定如何对批次中的样本进行收集和处理。它接收一个批次的样本列表,并返回一个包含这些样本的批次张量或其他数据结构。这对于数据集中样本具有不同尺寸或形状的情况特别有用。
  6. pin_memory:指定是否将张量数据存储在主机的固定内存中,以加速 GPU 数据的传输。这对于使用 GPU 训练模型时可以提高数据传输效率。
  7. drop_last:指定在数据集样本数量不能被批次大小整除时是否丢弃最后一个不完整的批次。如果设置为 True,则最后一个批次将会被丢弃。

在调用 DataLoader 时,可以根据任务需求适当调整这些参数的值。例如,选择合适的批次大小、是否打乱数据集、是否利用多进程加速数据加载等,以最优化数据处理的性能和效果。

3 手写dataset和dataloader简单实现

import os
import random
import numpy as np

def read_data(file):
    with open(file,encoding="utf-8") as f:
        all_data = f.read().split("\n")
    all_text,all_label = [],[]
    for data in all_data:
        data_s = data.split("\t")
        if len(data_s) != 2:
            continue
        text,label = data_s
        try:
            # label = int(label)
            all_label.append(label)
            all_text.append(text)
        except:
            print("标签报错!")
    assert len(all_text) == len(all_label),"数据和标签长度都不一样!玩个球啊!"
    return all_text,all_label

class Dataset():
    def __init__(self,all_text,all_label,batch_size,word_2_index,label_2_index):
        self.all_text = all_text
        self.all_label = all_label
        self.batch_size = batch_size
        self.word_2_index = word_2_index
        self.label_2_index = label_2_index
    #迭代器每一个epoch会触发一次
    def __iter__(self):
        dataloader = Dataloader(self)
        return dataloader
    #使用下标取数据时会触发这个函数
    def __getitem__(self,index):
        text = self.all_text[index][:max_len]
        text_idx = [self.word_2_index[w] for w in text]
        text_idx = text_idx + [0]*(max_len-len(text_idx))
        label = self.all_label[index]
        label_idx = self.label_2_index[label]
        return text_idx, label_idx

class Dataloader():
    def __init__(self,dataset):
        self.dataset = dataset
        self.cursor = 0

        #打乱数据索引
        self.random_idx = [i for i in range(len(self.dataset.all_text))]
        random.shuffle(self.random_idx)
    #循环取数据
    #在每一个epoch中只要取出一次就触发一次
    def __next__(self):
        if self.cursor >= len(self.dataset.all_text):
            raise StopIteration
        #数据长度为11,取到最后一条的时候,+self.cursor+self.dataset.batch_size,超过原本的数据长度,需要再判断,不能超过数据的长度
        #进行到这一步出现下标,触发了dataset中的__getitem__函数
        batch_data = [self.dataset[i] for i in range(self.cursor,min(self.cursor+self.dataset.batch_size,len(self.dataset.all_text)))]
        if batch_data:
            text_idx, label_idx = zip(*batch_data)
            self.cursor += len(batch_data)
            return np.array(text_idx), np.array(label_idx)
        else:
            raise StopIteration 
#对文字用数据表示
def build_word_2_index(all_text):
    word_2_index = {'PAD':0}
    for text in all_text:
        for w in text:
            word_2_index[w] = word_2_index.get(w, len(word_2_index))
    return word_2_index

#对标签用数字表示
def build_label_2_index(all_lablel):
    return {k:i for i,k in enumerate(set(all_label),start=0) }

class Model():
    def __init__(self):
        pass
    def forward(self,x):
        return x
    def __call__(self,x):
        return self.forward(x)
if __name__ == "__main__":

    all_text,all_label = read_data(os.path.join("dataset-dataloder\data","train0.txt"))
    
    epoch = 1
    batch_size = 2
    max_len = 10

    word_2_index = build_word_2_index(all_text)
    label_2_index = build_label_2_index(all_label)
    #运行这一步时,只运行了__init__这个函数,对类中的数据进行了初始化
    train_dataset = Dataset(all_text,all_label,batch_size,word_2_index,label_2_index)  
    #运行到每一个循环的时候,触发迭代器__iter__,转入到dataloader中对数据进行操作,也只运行__init__函数里面的内容  
    for e in range(epoch):
        #进入迭代器所触发的函数__next__中取数据
        for batch_data in train_dataset:
            batch_text_idx, batch_label_idx = batch_data
            Model.forward(batch_text_idx)
            #print(batch_text_idx)
            #print(batch_label_idx)

4 调包简单实现

以下是一个使用Python中的torch库,简单实现一个自定义的DatasetDataLoader的示例:

import torch
from torch.utils.data import Dataset, DataLoader
# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)    
    def __getitem__(self, index):
        return self.data[index]
# 创建数据集实例
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 创建数据加载器实例
batch_size = 2
shuffle = True
num_workers = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
# 使用数据加载器迭代数据集
for batch in dataloader:
    print(batch)

在这个示例中,首先定义了一个名为MyDataset的自定义数据集类,它继承了torch.utils.data.Dataset类并重写了__len____getitem__方法。__len__方法返回数据集的样本数量,__getitem__方法用于获取指定索引的样本。
然后,我们创建了一个名为dataset的数据集实例,将我们的数据列表传递给它。
接下来,我们使用DataLoader类创建了一个名为dataloader的数据加载器实例,指定了批次大小、是否打乱数据集、以及加载数据的子进程数量。
最后,我们使用dataloader迭代数据集,每次返回一个批次的样本。在这个示例中,我们简单地打印了每个批次的内容。

  • 20
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值