Pytorch数据读取(Dataset, DataLoader, DataLoaderIter)

 

Pytorch的数据读取主要包含三个类:

  1. Dataset
  2. DataLoader
  3. DataLoaderIter

这三者大致是一个依次封装的关系: 1.被装进2., 2.被装进3.

一. torch.utils.data.Dataset

是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:

  1. __getitem__()
  2. __len__()

第一个最为重要, 即每次怎么读数据. 以图片为例:


 
 
  1. def __getitem__(self, index):
  2. img_path, label = self.data[index].img_path, self.data[index].label
  3. img = Image.open(img_path)
  4. return img, label

值得一提的是, pytorch还提供了很多常用的transform, 在torchvision.transforms 里面, 本文中不多介绍, 常用的有Resize , RandomCrop , Normalize , ToTensor (这个极为重要, 可以把一个PIL或numpy图片转为torch.Tensor, 但是好像对numpy数组的转换比较受限, 所以这里建议在__getitem__()里面用PIL来读图片, 而不是用skimage.io).

 第二个比较简单, 就是返回整个数据集的长度:


 
 
  1. def __len__(self):
  2. return len( self.data)

二. torch.utils.data.DataLoader

类定义为:


 
 
  1. class torch.utils.data.DataLoader(
  2. dataset,
  3. batch_size= 1,
  4. shuffle= False,
  5. sampler=None,
  6. batch_sampler=None,
  7. num_workers= 0,
  8. collate_fn=< function default_collate>,
  9. pin_memory= False,
  10. drop_last= False
  11. )

可以看到, 主要参数有这么几个:

  1. dataset : 即上面自定义的dataset.
  2. collate_fn: 这个函数用来打包batch
  3. num_worker: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据

这个类其实就是下面将要讲的DataLoaderIter的一个框架, 一共干了两件事:

  1. 定义了一堆成员变量, 到时候赋给DataLoaderIter,
  2. 然后有一个__iter__() 函数, 把自己 "装进" DataLoaderIter 里面.

 
 
  1. def __iter__(self):
  2. return DataLoaderIter( self)

三. torch.utils.data.dataloader.DataLoaderIter

上面提到, DataLoader就是DataLoaderIter的一个框架, 用来传给DataLoaderIter 一堆参数, 并把自己装进DataLoaderIter 里。其实到这里就可以满足大多数训练的需求了, 比如


 
 
  1. class CustomDataset(Dataset):
  2. # 自定义自己的dataset
  3. dataset = CustomDataset()
  4. dataloader = Dataloader(dataset, ...)
  5. for data in dataloader:
  6. # training...

在for 循环里, 总共有三点操作:

  1. 调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter
  2. 反复调用DataLoaderIter 的__next__()来得到batch, 具体操作就是, 多次调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch. 中间还会涉及到shuffle , 以及sample 的方法等.
  3. 当数据读完后, __next__()抛出一个StopIteration异常, for循环结束, dataloader 失效.

四. 又一层封装

其实上面三个类已经可以搞定了, 仅供参考


 
 
  1. class DataProvider:
  2. def __init__(self, batch_size, is_cuda):
  3. self.batch_size = batch_size
  4. self.dataset = Dataset_triple( self.batch_size,
  5. transform _=transforms.Compose(
  6. [transforms.Scale([ 224, 224]),
  7. transforms.ToTensor(),
  8. transforms.Normalize(mean=[ 0. 485, 0. 456, 0. 406],
  9. std=[ 0. 229, 0. 224, 0. 225])]),
  10. )
  11. self.is_cuda = is_cuda # 是否将batch放到gpu上
  12. self.dataiter = None
  13. self.iteration = 0 # 当前epoch的batch数
  14. self.epoch = 0 # 统计训练了多少个epoch
  15. def build(self):
  16. dataloader = DataLoader( self.dataset, batch_size= self.batch_size, shuffle=True, num_workers= 0, drop_last=True)
  17. self.dataiter = DataLoaderIter(dataloader)
  18. def next(self):
  19. if self.dataiter is None:
  20. self.build()
  21. try:
  22. batch = self.dataiter. next()
  23. self.iteration += 1
  24. if self. is_cuda:
  25. batch = [batch[ 0].cuda(), batch[ 1].cuda(), batch[ 2].cuda()]
  26. return batch
  27. except StopIteration: # 一个epoch结束后reload
  28. self.epoch += 1
  29. self.build()
  30. self.iteration = 1 # reset and return the 1st batch
  31. batch = self.dataiter. next()
  32. if self. is_cuda:
  33. batch = [batch[ 0].cuda(), batch[ 1].cuda(), batch[ 2].cuda()]
  34. return batch

感谢以下链接提供的参考:

https://zhuanlan.zhihu.com/p/30934236

https://blog.csdn.net/u014380165/article/details/79167753

转载:https://blog.csdn.net/weixin_39739616/article/details/83824944
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值