关于pytorch里DataLoader的理解

目录

一、python迭代器生成器基础讲解

1.1可迭代对象Iterable

1.2迭代器Iterator

1.3for in 的本质流程

1.4 getitem

1.5 yield 生成器

二、DataLoader的基础实现

三、整体框架的讲解


一、python迭代器生成器基础讲解

1.1可迭代对象Iterable

表示该对象可迭代,并不一定是一个数据类型,如字典,字符串,列表等,它也可以是一个实现了__iter__方法的类。

from collections.abc import Iterable, Iterator

class A(object):
    def __init__(self):
        self.a = [1, 2, 3]

    def __iter__(self):
        # 此处返回啥无所谓
        return self.a

cls_a = A()
#  True
print(isinstance(cls_a, Iterable))

如果对象是Iterable,依然无法用for循环遍历,因为Iterable仅仅是提供了一种抽象规范接口。

1.2迭代器Iterator

如果一个对象是迭代器,那么它肯定是可迭代的,但是如果一个对象是可迭代的,它不一定是迭代器。实现了 __next__ 和 __iter__ 方法的类才能称为迭代器,就可以被 for 遍历了。

class A(object):
    def __init__(self):
        self.index = -1
        self.a = [1, 2, 3]

    # 必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
    # 因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
    def __iter__(self):
        return self

    def __next__(self):
        self.index += 1
        if self.index < len(self.a):
            return self.a[self.index]
        else:
            # 抛异常,for 内部会自动捕获,表示迭代完成
            raise StopIteration("遍历完了")

cls_a = A()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # True
print(isinstance(iter(cls_a), Iterator)) # True

for a in cls_a:
    print(a)
# 打印 1 2 3

1.3for in 的本质流程

for.....in...被python编译器编译后,如下

# 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
cls_a = iter(cls_a)
while True:
    try:
        # 然后调用对象的 __next__ 方法,不断返回元素
        value = next(cls_a)
        print(value)
    # 如果迭代完成,则捕获异常即可
    except StopIteration:
        break

可见,任何一个对象要能被for遍历,必须实现__iter__和__next__两个方法。

list是可迭代对象,但是没next方法,为什么可以实现for循环遍历。list内部的iter方法的内部实现了next方法。

所以得到:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 __iter__ 和 __next__ 方法(对象必然是 Iterator),还是只实现 __iter__(不是 Iterator),但是内部间接返回了具备 __next__ 对象的类,都是可行的

1.4 getitem

上面说过for in本质就是调用__iter__和__next__方法,实际上还有一种更简单的方法,__getitem__方法就可以让对象实现迭代功能。实际上任何一个类,只要实现了__getitem__方法,那么当调用iter(类实例)时候会自动具备__iter__和__next__方法。__getitem__ 实际上是属于 iternext方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。

class A(object):
    def __init__(self):
        self.a = [1, 2, 3]

    def __getitem__(self, item):
        return self.a[item]

cls_a = A()
print(isinstance(cls_a, Iterable))  # False
print(isinstance(cls_a, Iterator))  # False
print(dir(cls_a))  # 仅仅具备 __getitem__ 方法

cls_a = iter(cls_a)
print(dir(cls_a))  # 具备 __iter__ 和 __next__ 方法

print(isinstance(cls_a, Iterable))  # True
print(isinstance(cls_a, Iterator))  # True

# 等价于 for .. in ..
while True:
    try:
        # 然后调用对象的 __next__ 方法,不断返回元素
        value = next(cls_a)
        print(value)
    # 如果迭代完成,则捕获异常即可
    except StopIteration:
        break

# 输出: 1 2 3

如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 __len__ 方法即可。

此时我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。

1.5 yield 生成器

生成器是一个在行为上和迭代器非常类似的对象,两者功能差不多,但生成器更优雅,只需要用关键字yield来返回。作用于函数上叫生成器函数,调用函数返回一个生成器。

def func():
    for a in [1, 2, 3]:
        yield a

cls_g = func()
print(isinstance(cls_g, Iterator))  # True
print(dir(cls_g))  # 自动具备 __iter__ 和 __next__ 方法

for a in cls_g:
    print(a)

# 输出: 1 2 3

# 一种更简单的写法是用 ()
cls_g = (i for i in [1,2,3])

使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。

总结:在迭代对象基础上,如果实现了 __next__ 方法则是迭代器对象,该对象在调用 next() 的时             候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。

           对于采用语法糖 __getitem__ 实现的迭代器对象,其本身实例既不是可迭代对象,更不是               迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成             迭代器。

          生成器是一种特殊迭代器,但是不需要像迭代器一样实现__iter____next__方法,只需要            使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入            yield 关键字实现。

          对于在类的 __iter__ 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对              象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后            就会自动变成迭代器。

二、DataLoader的基础实现

首先介绍5个基本的对象:

Dataset提供整个数据集的随机访问功能,每次访问都返回单个对象,例如一个对象和一个target。

Sampler提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引。常用的子类是SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引

BatchSampler内部调用Sampler实列,输出指定batch_size个索引,然后将索引作用于Dataset上从而输出batch_size个数据对象,例如batch_size个数据和索引。

Collate_fn用于将batch_size个数据对象在batch维度进行聚合,生成(batch,.....)格式的数据输出。如果待聚合对象是numpy,则自动转化为tensor,此时就可以输入到网络中了。

迭代一次伪代码如下(非迭代器版本)

class DataLoader(object):
    def __init__(self):
        #假设数据长度为100,batch_size是4
        self.dataset=[[img0,target0],[img1,target1],.....[img99,target99]]
        self.sampler=[0,1,2,.....,99]
        self.batch_size=4
        self.index=0

    def collate_fn(self,data):
        #在batch维度聚合数据
        batch_img=torch.Stack(data[0],0)
        batch_target=torch.stack(data[1],0)
        return batch_img,batch_target

    def __next__(self):
        i=0
        batch_index=[]
        while i<self.batch_size:
            #内部会调用sampler对象获取单个索引
            batch_index.append(self.sampler[self.index])
            self.index+=1
            i+=1
        #得到batch_size个索引之后,调用dataset对象
        data=[self.dataset[idx] for idx in batch_index]
        #调用collate_fn 在batch维度进行拼接输出
        batch_data=self.collate_fn(data)
        return batch_data

    def __iter__(self):
        return self
# torch.stack()是指将列表里面的张量进行扩维拼接
# data=[[torch.Tensor([1]),torch.Tensor([1])],[torch.Tensor([1]),torch.Tensor([1])]]
# print(torch.stack(data[0],0),torch.stack(data[1],0))
# data=[torch.Tensor([1,2,3]),torch.Tensor([4,5,6])]
# print(torch.stack(data))

以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。

首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器。

Dataset通过__getitem__方法使其可迭代

Sample对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 __iter__ 内部返回迭代器,RandomSampler 在 __iter__ 内部通过 yield 关键字返回迭代器

Batchsampler也是在__iter__内部通过yield关键字返回迭代器

DataLoader通过__iter__和__next__直接实现迭代器

除了DataLoader本身是迭代器外,其余对象本身都不是迭代器,但可以for in迭代

由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。

需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。

三、整体框架的讲解

核心运行逻辑:

def __next__(self):
    #返回batch个索引
    index=next(self.batch_sampler)
    #利用索引去取数据
    data=[self.dataset[idx] for idx in index]
    #batch维度聚合
    data=self.collate_fn(data)
    return data

整体流程:

1.self.batch_sampler=iter(batch_sampler)。在DataLoader的类初始化,需要得到BatchSampler的迭代器对象。

2.index=next(self.batch_sampler)。对于每次迭代,DataLoader对象首先会调用BatchSampler的迭代器进行下一次迭代,具体是调用BatchSampler对象的__iter__方法

3.而BatchSampler对象的__iter__方法实际上是需要依靠Sampler对象进行迭代输出索引,Sampler对象也是一个迭代器,当迭代batch_size次后就可以得到batch_size个数据索引。

4.data=[self.dataset[idx] for idx in index]。有了batch个索引就可以通过不断调用dataset的__getitem__方法返回数据对象,此时data就包含了batch个对象。

5.data=self.collate_fn(data)。将batch个对象输入给聚合函数,在第0个维度也就是batch维度进行聚合,得到类似(batch,....)的对象。

6.重复上面的操作,就可以不断输出一个一个的batch数据

class Dataset(object):
    #只要实现了__getitem__方法就可以变成迭代器
    def __getitem__(self,index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
class Sampler(object):
    def __init__(self,data_source):
        pass
    def __iter__(self):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
#一般出现raise NotImplementedError这个错误,就是子类没有重写父类中的成员函数,然后子类对象调用此函数会报这个错误

class SequentialSampler(sampler):
    def __init__(self,data_source):
        super(SequentialSampler,self).__init__(data_source)
        self.data_source=data_source
    def __iter__(self):
        #返回迭代器,不然无法for  in
        return iter(range(len(self.data_source))
    def __len__(self):
        return len(self.data_source)

class BatchSampler(Sampler):
    def __init__(self,sampler,batch_size,drop_last):
        self.sampler=sampler
        self.batch_size=batch_size
        self.dorp_last=drop_last

    def __iter__(self):
        batch=[]
        for idx in self.sampler:
            batch.append(idx)
            #如果得到了batch个索引,则可以通过yield关键字生成生成器返回,得到迭代器对象
            if len(batch)==self.batch_size:
                yield batch
                batch=[]
        if len(batch)>0 and not self.drop_last:
            yield batch
    def __len__(self):
        if self.drop_last:
            #如果最后的索引数不等于一个batch,抛弃
            return len(self.sampler)//self.batch_size
        else:
            return (len(self.sampler)+self.batch_size-1)//self.batch_size
class DataLoader(object):
    def __init__(self,dataset,batch_size=1,shuffle=False,sample=None,batch_sampler=None,
                    collate_fn=None,drop_last=False):
        self.dataset=dataset
        #因为这两个功能是冲突的
        if sampler is not None and shuffle:
            raise ValueError('sampler option is ..')
        if batch_sampler is not None:
            # 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler
            # 和 drop_last 四个参数就不能传入
            # 因为这4个参数功能和 batch_sampler 功能冲突了
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset)
            else:
                sampler = SequentialSampler(dataset)
        # 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类
        if batch_sampler is None:
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = iter(batch_sampler)
        
        if collate_fn is None:
            collate_fn = default_collate
        self.collate_fn = collate_fn

    #核心代码
    def __next__(self):
        index=next(self.batch_sampler)
        data=[self.dataset[idx] for idx in index]
        data=self.collate_fn(data)
        return data
    #返回自身,因为自身实现了next
    def __iter__(self):
        return self
        
def default_collate(batch):
    elem=batch[0]
    elem_type=type(elem)
    if isinstance(elem,torch.Tensor):
        return torch.stack(batch,0)
    elif elem_type.__module__=='numpy':
        return default_collate([torch.as_tensor(b) for b in batch])
    else:
        raise NotImplementedError

完整调用例子

class Simplev1Dataset(Dataset):
    def __init__(self):
        #伪造数据
        self.imgs=np.arange(0,16).reshape(8,2)

    def __getitem__(self,index):
        return self.imgs[index]

    def __len__(self):
        return self.imgs.shape[0]

from simplev1_dataset import Simplev1Dataset
simple_dataset=Simplev1Dataset()
dataloader=DataLoader(simple_dataset,batch_size=2,collate_fn=default_collate)
for data in dataloader:
    print(data)

四、Reference

https://zhuanlan.zhihu.com/p/340465632

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CVplayer111

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值