pytorch中的数据加载:Dataset与DataLoader

Dataset和DataLoader

在深度学习中,往往需要经过大量的样本对网络参数进行训练,才能得到一个鲁棒性高的模型,而这么大量的样本,就需要通过mini-batch对图片进行迭代输入进网络进行训练,在pytorch中,通常使用Dataset和DataLoader这两个工具来构建数据管道,进行加载数据以及batch的迭代。

Dataset定义了数据集的内容,它是一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。

而DataLoader定义了按batch加载数据集的方法,每次迭代输出一个batch的数据,并且能够使用多进程读取数据。

在绝大部分情况下,用户只需实现Dataset的__len__方法和__getitem__方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。

一、Dataset与DataLoader实现逻辑

1.获取一个batch数据的步骤

1.确定整个数据集的长度n;

拿图片数据集举例子,通常图片数量很多,想要把图片像素矩阵全部加载进内存,再传给网络,这不现实,因为内存会爆炸,因额通常会把数据集进行处理,处理成txt文本的格式,文本中每一行第一个位置是图片的路径,后面跟着图片的label(包括box坐标和类别),通过读取图片路径来读取图片能节省很大的内存空间。
这是数据集的长度n就是这个txt文本的行数,因为每一行代表一张图。

2.然后需要从这n张图中抽样出m张图作为一个batch,m就是batch_size的值。

假定m=8,拿到的就是一个索引值的列表,比如index = [1, 6, 9, 5, 7, 3, 8, 12]

3.接着要通过这个索引列表从数据集中找到这m个索引对应的图片路径和对应的label。

这时拿到的结果是一个元组列表,比如:sample = [(x[1], y[1]), (x[6], y[6]),(x[9], y[9]),(x[5], y[5]),(x[7], y[7]),(x[3], y[3]),(x[8], y[8]),(x[12], y[12])]

4.将取到的图片路径通过opencv读取,再将图片矩阵和label转化为torch需要的张量作为网络的输入。

2.各个步骤的实现方法

1.确定数据集的长度:通过 Dataset的__len__ 方法实现的,在后续例子中会详细见到。
2.抽取一个batch的样本:由 DataLoader的 sampler和 batch_sampler参数指定的。

sampler参数指定单个元素抽样方法,一般无需用户设置,程序默认在DataLoader的参数shuffle=True时采用随机抽样,shuffle=False时采用顺序抽样。

batch_sampler参数将多个抽样的元素整理成一个列表,一般无需用户设置,默认方法在DataLoader的参数drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,在drop_last=False时保留最后一个批次。

3.由batch索引获取数据:由 Dataset的 __getitem__方法实现的。
4.转换为tensor:由DataLoader的参数collate_fn指定。一般情况下也无需用户设置。

二、使用Dataset创建数据集

一般我们都是需要创建自定义的数据集,通过继承 torch.utils.data.dataset 创建自定义数据集。

我们这里使用的是经典的猫狗分类的数据集,也是后续手动复现,用于学习pytorch代码结构的demo,直接上代码了。

from torch.utils.data.dataset import Dataset	#通过继承torch.utils.data.dataset创建自定义数据集

class Data_loader(Dataset):	#定义dataset的类
    def __init__(self, train_lines, transform):	#初始化
        super(Data_loader, self).__init__()	#初始化继承Data_loader类
        self.train_lines = train_lines	#传进来的txt文本的所有数据
        self.transform = transform	#定义数据随机增强方法

    def __len__(self):	#__len__方法返回数据集长度
        return len(self.train_lines)	

    def __getitem__(self, idx):	# __getitem__方法通过index获取数据
        img_path = self.train_lines[idx].split(' ')[0]	#得到图片路径
        label = self.train_lines[idx].split(' ')[1]	#得到图片label
        #print(img_path)
        img = cv2.imread(img_path)
        sample = {'image_path':img_path, 'label':label}	#将图片与label组成字典保存为一个sample
        #print(sample['image_path'])

        if self.transform:	#随机增强
            sample = self.transform(sample)
        return sample
        #img = Rescale(output_size=224, img_path=self.train_lines[idx])

使用DataLoader加载数据集

DataLoader函数里可以通过参数batch_size设置batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,能够使用num_workers参数设置多进程读取数据。

DataLoader(
    dataset,  #上一步创建好的dataset
    batch_size=1,  #batch的大小
    shuffle=False,  #是否随机读取数据,一般为设为True
    sampler=None,	
    batch_sampler=None,	
    num_workers=0,  #多进程参数
    collate_fn=None,  #加载数据的格式
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
)

一般,我们尝试用的参数即为上述代码中注释的5个参数:dataset、batch_size、shuffle、num_workers、collate_fn,其他参数的含义如下:
sampler: 样本采样函数,一般无需设置。
batch_sampler: 批次采样函数,一般无需设置。
pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
timeout: 加载一个数据批次的最长等待时间,一般无需设置。
worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。

这里使用DataLoader加载上面处理的猫狗数据集,代码如下:

from data_lodar import Data_loader
from torchvision import transforms
from data_augmentation import Rescale,RandomCrop,ToTensor
from torch.utils.data import DataLoader

train_dataset = Data_loader(lines, #lines即为从txt中读入的数据
                            transform=transforms.Compose([ #设置数据增强方式,使用transforms.Compose进行连接
                                Rescale(256), #随机缩放,函数需要自己写,详细看之后的demo
                                RandomCrop(), #随机剪裁
                                ToTensor() #转化为tensor
                            ]))
dataloader = DataLoader(train_dataset, batch_size=batch_size,
                        shuffle=True)

更多技术欢迎加入交流:320297153

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值