一文读懂Datapipeline

在深度学习的应用中,数据是模型训练和评估的基础。有效地管理和加载数据对于构建高效的神经网络模型至关重要。PyTorch,作为一个广泛使用的深度学习框架,提供了一套强大的工具来帮助开发者处理数据。DatasetDataLoader是PyTorch中用于构建数据管道的核心组件,它们使得数据的加载、预处理和批量化变得更加简单和高效。

Dataset是一个抽象类,它定义了数据集的接口,允许开发者以统一的方式访问数据。通过实现__len__方法,Dataset提供了数据集的大小信息,而__getitem__方法则允许按索引获取数据集中的单个样本。这种设计使得Dataset可以被看作是一个动态数组,可以根据需要获取数据集中的任何部分。

DataLoader则是一个迭代器,它负责将Dataset中的数据以批次(batch)的形式输出。通过DataLoader,我们可以轻松地控制批次的大小,设置随机采样或顺序采样策略,并在必要时使用多进程来加速数据的加载。此外,DataLoader还提供了一个collate_fn参数,允许开发者自定义数据整理函数,以处理复杂的数据类型或特殊的数据需求。

使用DatasetDataLoader,我们可以构建灵活且高效的数据管道,为模型的训练和评估提供强有力的支持。这不仅简化了数据管理的工作,还提高了代码的可读性和可维护性。在接下来的内容中,我们将详细介绍如何使用这两个工具来构建自己的数据管道,以及如何通过自定义函数来满足特定的数据处理需求。

自学笔记,欢迎交流;

0. DataSet和DataLoader介绍

Pytorch 使用 Dataset和DataLoader这两个工具来构建数据管道

DataSet定义了数据集的内容,类似于列表的数据结构,具有确定的长度,能够索引获取长度;

DataLoader定义了按batch加载数据集的方法,是一个可迭代对象,每次迭代输出一个batch的数据;

DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需要的输入形式的方法,并能够使用多进程读取数据;

通常只需要实现dataset的__len__和__getitem__方法,就可以构建自己数据集,并且用默认数据管道进行加载;

复杂数据,需要设计dataloader中的collate_fn方法将一个批次的数据整理成模型需要的输入形式;

DataSet和DataLoader原理

1.1 如何获取一个batch数据 : 总样 —— 抽样 —— 取样值 —— 聚合

(假设数据集的特征和标签分别为张量X和Y,数据集可以表示为(X, Y), batch为M)

  1. 先要获得数据的长度N。假设N=1000
  2. 然后从0—N-1的范围中抽样出M个数,假设M=4,那么拿到的结果就是 indices = [1,4,8,9](4个随机数)
  3. 接着我们去数据集中取这M个数分别对应的下标,这里拿到一个元组列表 samples = [(X[1], Y[1]), (X[4], Y[4]), (X[8], Y[8]), (X[9], Y[9])]
  4. 最后将结果整理成两个张量作为输出,故最后的结果是两个张量 batch = (features, labels)

其中 features = torch.stack([[X][1], X[4], X[8], X[9]]) ; labels = torch.stack([Y[1], Y[4], Y[8], Y[9]]

1.2 DataSet和DataLoader的功能分工

步骤1获取长度是DataSet中的__len__方法获取的

步骤2从0—N-1的范围中抽样M个数是DataLoader的sampler和batch_sampler参数指定的;

sampler 参数指定抽样方法,一般不需要设置,程序默认在DataLoader的参数shuffle=True时采用随机抽样,shuffle=False时采用顺序抽样;

batch_sampler参数将多个抽样的元素整理成一个列表,一般无需用户设置,默认方法在DataLoader的参数drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,在drop_last=False时保留最后一个批次。

步骤3的核心逻辑就是根据下标取值,由DataSet中的__getitem__实现;

步骤4是由DataLoader的参数collate_fn指定,一般情况下无需设置;

一般使用方法如下

import torch   
from torch.utils.data import TensorDataset,Dataset,DataLoader  
from torch.utils.data import RandomSampler,BatchSampler   
  
  
ds = TensorDataset(torch.randn(1000,3),  
                   torch.randint(low=0,high=2,size=(1000,)).float())  # 步骤1 
dl = DataLoader(ds,batch_size=4,drop_last = False)  # 步骤2 3 4
features,labels = next(iter(dl))  # 取数据
print("features = ",features )  # 打印查看
print("labels = ",labels )

DataLoader内部调用方式步骤拆解

# step1: 确定数据集长度 (Dataset的 __len__ 方法实现)  
ds = TensorDataset(torch.randn(1000,3),  
                   torch.randint(low=0,high=2,size=(1000,)).float())  
print("n = ", len(ds)) # len(ds)等价于 ds.__len__()  
  
# step2: 确定抽样indices (DataLoader中的 Sampler和BatchSampler实现)  
sampler = RandomSampler(data_source = ds)  
batch_sampler = BatchSampler(sampler = sampler,   
                             batch_size = 4, drop_last = False)  
for idxs in batch_sampler:  
    indices = idxs  
    break   
print("indices = ",indices)  
  
# step3: 取出一批样本batch (Dataset的 __getitem__ 方法实现)  
batch = [ds[i] for i in  indices]  #  ds[i] 等价于 ds.__getitem__(i)  
print("batch = ", batch)  
  
# step4: 整理成features和labels (DataLoader 的 collate_fn 方法实现)  
def collate_fn(batch):  # 聚合函数
    features = torch.stack([sample[0] for sample in batch])  
    labels = torch.stack([sample[1] for sample in batch])  
    return features,labels   
  
features,labels = collate_fn(batch)  
print("features = ",features)  
print("labels = ",labels) 

1.3 Dataset和DataLoader的核心源码

import torch   
class Dataset(object):  
    def __init__(self):  
        pass  
      
    def __len__(self):  # 获取长度信息
        raise NotImplementedError  
          
    def __getitem__(self,index):   # 获取索引信息
        raise NotImplementedError  
          
  
class DataLoader(object):  
    def __init__(self,dataset,batch_size,collate_fn = None,shuffle = True,drop_last = False):  
        self.dataset = dataset  
        self.sampler =torch.utils.data.RandomSampler if shuffle else \  
           torch.utils.data.SequentialSampler  # 抽样方式 随机/顺序
        self.batch_sampler = torch.utils.data.BatchSampler  
        self.sample_iter = self.batch_sampler(  
            self.sampler(self.dataset),  # 索引取值
            batch_size = batch_size,drop_last = drop_last)  # 要不要最后一组batch
        self.collate_fn = collate_fn if collate_fn is not None else \  # 聚合函数
            torch.utils.data._utils.collate.default_collate  
          
    def __next__(self):  
        indices = next(iter(self.sample_iter))  
        batch = self.collate_fn([self.dataset[i] for i in indices])  
        return batch  
      
    def __iter__(self):  
        return self 

测试

class ToyDataset(Dataset):  
    def __init__(self,X,Y):  
        self.X = X  
        self.Y = Y   
    def __len__(self):  
        return len(self.X)  
    def __getitem__(self,index):  
        return self.X[index],self.Y[index]  
      
X,Y = torch.randn(1000,3),torch.randint(low=0,high=2,size=(1000,)).float()  
ds = ToyDataset(X,Y)  
  
dl = DataLoader(ds,batch_size=4,drop_last = False)  
features,labels = next(iter(dl))   
print("features = ",features )  
print("labels = ",labels )

2. 使用DataSet创建数据集

Dataset创建数据集常用的方法有:

  • 使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。
  • 使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。
  • 继承 torch.utils.data.Dataset 创建自定义数据集。

此外,还可以通过

  • torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。
  • 调用Dataset的加法运算符(+)将多个数据集合并成一个数据集。

2.1 根据Tensor创建数据集

import numpy as np   
import torch   
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split   

# 根据Tensor创建数据集  
  
from sklearn import datasets   
iris = datasets.load_iris()  # 导入鸢尾花数据 
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))  # 转成tensor数据
  
# 分割成训练集和预测集  
n_train = int(len(ds_iris)*0.8)  
n_val = len(ds_iris) - n_train  
ds_train,ds_val = random_split(ds_iris,[n_train,n_val])  
  
print(type(ds_iris))  
print(type(ds_train)) 

# 使用DataLoader加载数据集  
# 最简单的给两个参数,一个数据集,一个batch_size大小
dl_train,dl_val = DataLoader(ds_train,batch_size = 8),DataLoader(ds_val,batch_size = 8)  
  
for features,labels in dl_train:  
    print(features,labels)  
    break  

# 演示加法运算符(`+`)的合并作用  
  
ds_data = ds_train + ds_val  
  
print('len(ds_train) = ',len(ds_train))  
print('len(ds_valid) = ',len(ds_val))  
print('len(ds_train+ds_valid) = ',len(ds_data))  
  
print(type(ds_data))  

2.2 根据图片目录创建图片数据集

import numpy as np   
import torch   
from torch.utils.data import DataLoader  
from torchvision import transforms,datasets   

#演示一些常用的图片增强操作  

from PIL import Image  
img = Image.open('./data/cat.jpeg')  
img  

# 随机数值翻转  
transforms.RandomVerticalFlip()(img)  

#随机旋转  
transforms.RandomRotation(45)(img)  

# 定义图片增强操作  
  
transform_train = transforms.Compose([  
   transforms.RandomHorizontalFlip(), #随机水平翻转  
   transforms.RandomVerticalFlip(), #随机垂直翻转  
   transforms.RandomRotation(45),  #随机在45度角度内旋转  
   transforms.ToTensor() #转换成张量  
  ]  
)   
  
transform_valid = transforms.Compose([  # 评估数据不做增强
    transforms.ToTensor()  
  ]  
)  

# 根据图片目录创建数据集  
  
def transform_label(x):  
    return torch.tensor([x]).float()  
  
ds_train = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/train/",  
            transform = transform_train,target_transform= transform_label)  
ds_val = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/test/",  
                              transform = transform_valid,  
                              target_transform= transform_label)  
  
  
print(ds_train.class_to_idx)  
  
# 使用DataLoader加载数据集  
  
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)  
dl_val = DataLoader(ds_val,batch_size = 50,shuffle = True)  
  
  
for features,labels in dl_train:  
    print(features.shape)  
    print(labels.shape)  
    break  

2.3 创建自定义数据集

下面我们通过另外一种方式,即继承 torch.utils.data.Dataset 创建自定义数据集的方式来对 cifar2构建 数据管道。

from pathlib import Path   
from PIL import Image   

# 自定义数据集,继承Dataset  
class Cifar2Dataset(Dataset):  
     # 传入文件目录与数据增强方式
    def __init__(self,imgs_dir,img_transform):  
        # 这里是图像的目录,目录中是图片及其编号
        self.files = list(Path(imgs_dir).rglob("*.jpg"))  
        self.transform = img_transform  
          
    def __len__(self,): 
        # 获取长度 
        return len(self.files)  
      
    def __getitem__(self,i):  
        # 逐个操作,根据文件的编号获取相对应的数据与标签信息
        file_i = str(self.files[i])  
        img = Image.open(file_i)  
        tensor = self.transform(img)  
        label = torch.tensor([1.0]) if  "1_automobile" in file_i else torch.tensor([0.0])  
        return tensor,label   
      
      
train_dir = "./eat_pytorch_datasets/cifar2/train/"  
test_dir = "./eat_pytorch_datasets/cifar2/test/"  

# 定义图片增强  
transform_train = transforms.Compose([  
   transforms.RandomHorizontalFlip(), #随机水平翻转  
   transforms.RandomVerticalFlip(), #随机垂直翻转  
   transforms.RandomRotation(45),  #随机在45度角度内旋转  
   transforms.ToTensor() #转换成张量  
  ]  
)   
  
transform_val = transforms.Compose([  
    transforms.ToTensor()  
  ]  
)  

ds_train = Cifar2Dataset(train_dir,transform_train)  
ds_val = Cifar2Dataset(test_dir,transform_val)  
  
  
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)  
dl_val = DataLoader(ds_val,batch_size = 50,shuffle = True)  
  
  
for features,labels in dl_train:  
    print(features.shape)  
    print(labels.shape)  
    break  

3. 使用DataLoader加载数据集

DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

DataLoader的函数签名如下。

DataLoader(  
    dataset,  
    batch_size=1,  
    shuffle=False,  
    sampler=None,  
    batch_sampler=None,  
    num_workers=0,  
    collate_fn=None,  
    pin_memory=False,  
    drop_last=False,  
    timeout=0,  
    worker_init_fn=None,  
    multiprocessing_context=None,  
)  

一般情况下,我们仅仅会配置 dataset, batch_size, shuffle, num_workers,pin_memory, drop_last这六个参数,

有时候对于一些复杂结构的数据集,还需要自定义collate_fn函数,其他参数一般使用默认值即可。

DataLoader除了可以加载我们前面讲的 torch.utils.data.Dataset 外,还能够加载另外一种数据集 torch.utils.data.IterableDataset。

和Dataset数据集相当于一种列表结构不同,IterableDataset相当于一种迭代器结构。它更加复杂,一般较少使用。

  • dataset : 数据集
  • batch_size: 批次大小
  • shuffle: 是否乱序
  • sampler: 样本采样函数,一般无需设置。
  • batch_sampler: 批次采样函数,一般无需设置。
  • num_workers: 使用多进程读取数据,设置的进程数。
  • collate_fn: 整理一个批次数据的函数。
  • pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
  • drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
  • timeout: 加载一个数据批次的最长等待时间,一般无需设置。
  • worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。

参考文章 实践教程|源码级理解Pytorch中的Dataset和DataLoader - CV技术指南(公众号) - 博客园

  • 47
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值