一文读懂Dataset, DataLoader及collate_fn, Sampler等参数

数据预处理DataLoader及各参数详解

pytorch关于数据处理的功能模块均在torch.utils.data 中,pytorch输入数据PipeLine一般遵循一个“三步走”的策略,操作顺序是这样的:

继承Dataset类,自定义数据处理类。必须重载实现len()、getitem()这两个方法。
其中__len__返回数据集样本的数量,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。在实现自定义类时,一般需要对图像数据做增强处理,和标签处理,__getitem__返回图像和对应label,图像增强的方法可以使用pytorch自带的torchvision.transforms内模块,也可以使用自定义或者其他第三方增强库。

② 导入 DataLoader类,传入参数(上面自定义类的对象) 创建一个DataLoader对象。

③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练

dataset = MyDataset()           # 第一步:构造Dataset对象
dataloader = DataLoader(dataset)# 第二步:通过DataLoader来构造迭代对象

num_epoches = 100
for epoch in range(num_epoches):# 第三步:逐步迭代数据
    for img, label in dataloader:
        # 训练代码

pytorch内部默认的数据处理类有如下:


class Dataset(object):

class IterableDataset(Dataset):

class TensorDataset(Dataset): #  封装成tensor的数据集,每一个样本都通过索引张量来获得。

class ConcatDataset(Dataset): #  连接不同的数据集以构成更大的新数据集

class Subset(Dataset):  # 获取指定一个索引序列对应的子数据集

class ChainDataset(IterableDataset):

一般能用到的是ConcatDataset, Subset,其他不常用。

可迭代对象的创建方式

  • 方法一:在python中凡是具有__iter__的方法的类,都是可迭代的类。可迭代类创建的对象实现了__iter__方法,因此就是可迭代对象。

    from collections import Iterable, Iterator
    
    class Student(object):
        def __init__(self, score):
            self.score = score
    
        def __iter__(self):
            return iter(self.score)  # return 返回的是一个迭代器,
    
    test = Student([80, 90, 95])
    print(isinstance(test,  Iterable))
    print(isinstance(test,  Iterator))
    for i in test:  # test可迭代对象,但不是迭代器,所以不能next(test)
        print(i)
    for i in test:  # 重复遍历试试看,是否有结果
        print(i)
    print("============")
    test=iter(test)  # 对可迭代对象使用内建函数iter(), 使之成为为迭代器,此时可以next(test)
    print(isinstance(test,  Iterable))
    print(isinstance(test,  Iterator))
    for i in test:
        print(i)
    for i in test:  # 对迭代器重复遍历试试看,有结果过没,
        print(i)  # 没有结果
    """
    True
    False
    80
    90
    95
    80
    90
    95
    ============
    True
    True
    80
    90
    95
    """
    

    从本代码可看出Student类创建的对象是可迭代对象,但不是迭代器(因为没有实现__next__方法),且可以实现重复遍历,而迭代器是无法重复遍历的!

  • 方法二:用list、tuple等容器创建的对象,也都是可迭代对象。如:test=[1,2,3], test就是可迭代对象

迭代器的创建方式:迭代器对象必须同时实现__iter__和__next__方法才是迭代器

  • 方法一:自定义类实现__iter__和__next__方法,对于迭代器来说,__iter__ 返回的是它自身 self__next__ 则是返回迭代器中的下一个值。

    class Student(object):
        def __init__(self, score):
            self.score = score
    
        def __iter__(self):
            return self  # 对于迭代器来说,__iter__ 返回的是它自身self,也就是返回迭代器。
    
        def __next__(self):
            if self.score < 100:
                self.score += 1
                
  • 11
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值