【Pytorch】简析数据批量化处理类Dataset和DataLoader

在深度学习中,在将原始数据进行清理、规范化和编码后,就需要将数据进行序列化和批量化,而Pytorch提供这两项功能的类分别为DatasetDataLoader

1. Dataset类

Dataset类是将数据进行序列化封装的类,我们在为每个具体问题定制合适的Dataset子类时,仅需要继承该父类,同时覆写__init____getitem____len__三个魔鬼方法即可:

  • __init__:类的初始化,在用来用于设置文件路径、导入文件和定义必要的变量。
  • __getitem__:提供一个切片的方法,可以根据输入的index,获取对应的一个数据。
  • __len__:用于统计数据样本的总量。
2. DataLoader类

在用Dataset类对数据封装完后进行训练和测试时,还需要对数据进行批量化处理,以供每个min-batch的数据。该类一般无需改写,直接加载对应的Dataset类,并设置相应的参数即可生成一个包含min-batch数据的可迭代对象。

MyDataLoader = DataLoader(dataset=MyDataset, batch_size=512, shuffle=True, num_workers=4)

如上面的例子,DataLoader的四个主要参数定义了数据批量化的主要属性,具体包括:

  • dataset: Dataset子类,即序列化好的数据。
  • batch_size: min-batch的尺寸。
  • shuffle: 在每个epoch取样前,是否先打乱数据顺序。
  • num_workers:所用的子进程数,默认为0,即仅用主进程。

除此之外,还有两个参数可能会用到:

  • sampler: Sample子类,定义了数据进行采样的方式。之前的shuffle=True其实也提供了一种采样方法,所以当设置sampler参数时,必须设置shuffle=False
  • collate_fn: 用于对Dataset中采样得到的每个mini-batch数据进行后处理,从而提供更好的模型输入数据,其取值为一个外部定义的可调用函数。也就是说,设置该值后,真正迭代输出的值是经过该函数处理后的返回值。该函数的具体使用可参照博文
3. 简单示例
import torch
from torch.utils.data import Dataset, DataLoader

A = torch.randn(128, 3)
C = torch.randn(128, 1)

# 1. 用Dataset封装数据集,仅做示范,实际可直接用TensorDataset封装
class MyDataset(Dataset):
    def __init__(self, x, y):
        assert x.size(0)==y.size(0)
        self.x, self.y = x, y
        
    def __getitem__(self, idx):
        return (self.x[idx], self.y[idx])
    
    def __len__(self):
        return self.x.size(0)

# 2. 用DataLoader定义数据批量迭代器
MyDataLoader = DataLoader(dataset=dataset, shuffle=True, batch_size=4)   

for data_iter in MyDataLoader:
	# 进行训练或预测
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值