Pytorch学习之旅(5)——DataLoader和DataSet

一、DataLoader()类

torch.utils.data.DataLoader():构建可迭代的数据装载器.
class DataLoader(object):
    def __init__(self, 
                 dataset, 
                 batch_size=1, 
                 shuffle=False, 
                 sampler=None,
                 batch_sampler=None, 
                 num_workers=0, collate_fn=None,
                 pin_memory=False, 
                 drop_last=False, 
                 timeout=0,
                 worker_init_fn=None, 
                 multiprocessing_context=None)

常用:

DataLoader( dataset,		 => Dataset类,决定数据从哪儿读取及如何读取.
			batch_size=1,	 => 批大小.
			shuffle=False,	 => 每个epoch是否乱序.
			num_workers=0,	 => 是否多进程读取数据.
			drop_last=False, => 当Epoch/BatchSize不为整数时,是否丢弃最后一批Iteration数据.

补充:

Epoch:全部训练样本。
Iteration:一批样本。
BatchSize:一批样本的大小。

有:Iteration = Epoch/BatchSize + 0/1(取决于drop_last的设定)

二、Dataset()类

torch.utils.data.Dataset():Dataset抽象类,子类必须复写 getitem()函数.
class DataLoader(object):

    def __getitem__(self, index):	# 接受索引,返回样本
        raise NotImplementedError


    def __add__(self, other):
        return ConcatDataset([self, other])

使用举例:

class RMBDataset(Dataset): # 子类RMBDataset继承父类Dataset

    def __init__(self, data_dir, transform=None):	# 复写构造函数,此处传入由transforms.Compose()函数返回的,包含图片变换操作的列表list()。
        self.label_name = {'1': 0, '100': 1}
        self.data_info = self.get_img_info(data_dir)  # data_info为自定义的函数
        self.transform = transform 
        
    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值