由浅入深:终于搞懂了 Python 和 PyTorch迭代器(iterator)、Dataset 和 DataLoader

一、从一个简单的列表引入迭代器

关于 python 迭代器的基本介绍和使用可以看我之前写的博客: Python迭代器的创建和使用:iter()和next()方法,迭代器长度的获取

将一个列表转换为迭代器(用 iter 方法),并逐个元素打印。

print(list(range(5)))    # [0, 1, 2, 3, 4]
myiter = iter(list(range(5)))    # 用iter方法将list转换为iterator
for x in myiter:
    print(x)
    
# 输出:
0
1
2
3
4

二、用类作为迭代器

现在我们把事情搞复杂一点,用类来产生一个相同的列表,并逐个元素打印。

要想让一个类作为迭代器,就要用到 __ iter __() 方法,python中实现了__iter __() 方法的对象是可迭代的,也就是一个迭代器。(对象就是类的一个实例)

__ iter __() 函数是python的魔术方法,这个函数的要求是返回值必须是一个迭代器。该方法使得类成为一个迭代器。

现在我们用类作为迭代器实现 0-4 整型迭代:

class MyClass:
    def __init__(self, num):
        self.num = num

    # 得到相应的列表
    def numlist(self):
        return list(range(self.num))
	
	# 得到迭代器
    def __iter__(self):
        return iter(self.numlist())   # 用iter方法将list转换为iterator

myiter = MyClass(num=5)   # 实例化一个对象myiter,myiter是一个迭代器
for x in myiter:
    print(x)

# 输出:
0
1
2
3
4

三、将类的功能变得更复杂

如果我们想以列表形式每次输出5个数,输出范围是0-20,比如这样:

[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[10, 11, 12, 13, 14]
[15, 16, 17, 18, 19]

那么第二步中的类就实现不了这样的功能,下面我们将这个类的功能进行扩展,使它变得更复杂。

__ iter__() 方法每次只能返回list中的一个元素,实现不了返回一组元素的功能。所以我们要手动分割数组来完成该功能,这里引入了一个新的方法:__ next __()

__ next __() 方法的功能是返回迭代器的下一个元素

我们有了 __ next __ () 之后,就不需要在 __ iter __ () 中返回一个列表迭代器了,因为这个功能一般由__ next __() 完成。现在引入一个概念:

把一个类作为一个迭代器使用需要在类中实现两个方法 iter() 与 next()
__ iter __() 的返回值是 self

也就是说这种情况下,__ iter__() 一般写成下面这种格式,复杂的功能实现交给 __ next __() 完成。

 def __iter__(self):
        return self
        # return self: 表示返回一个类的对象实例,也可以理解为返回自己。这个对象可以被链式调用。

return self 可以理解为 我返回我自己,相当于这个类自己在递归,不断迭代自己(自嗨),那不就是一个迭代器了吗?

现在我们来实现每次输出5个数的功能:

class MyClass:
    def __init__(self, num, step, start=0):
        self.num = num
        self.step = step
        self.start = start

    # 用于产生列表
    def numlist(self):
        return list(range(self.num))

    def __iter__(self):
        return self

    def __next__(self):
        numlist = self.numlist()
        if self.start < len(numlist):
            numsplit = numlist[self.start:(self.start + self.step)]
            self.start += self.step
            return numsplit
        else:
            raise StopIteration


myiter = MyClass(num=20, step=5)
for x in myiter:
    print(x)

# 输出:
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[10, 11, 12, 13, 14]
[15, 16, 17, 18, 19]

现在我们就用自己定义的类实现了迭代器的功能,这有助于我们理解和完成更复杂的功能,毕竟在一个大的项目里,用迭代器处理数据基本都需要用类来实现。

四、pytorch 数据处理的迭代器:Dataset 和 DataLoader

掌握了python迭代器的基本使用之后,我们再来看看更复杂的pytorch的迭代器。pytorch进行数据处理必然离不开 Dataset 和 DataLoader, Dataset 用于产生迭代器, DataLoader加载迭代器产生可以用enumerate 迭代控制的 target和 label。

from torch.utils.data import Dataset,DataLoader

这里我们用MNIST数据集的测试集来讲解Dataset 和 DataLoader的使用。

代码中涉及到MNIST数据集的处理请参考博客 MNIST手写数字数据集读取方法

4.1 用自己定义的 MnistDataset 类,不继承 torch 的 Dataset

首先我们自己写一个MnistDataset 类用于数据集处理和加载,不继承 torch 的 Dataset 。这里用到了一个新的方法 __ getitem __(self, index) ,其中index表示索引(即下标)。

__ getitem __() 的作用是让类拥有迭代功能,它与 __ iter __() 的不同之处在于: __ iter __() 的返回值必须是迭代器,而 __ getitem __() 的返回值没有限制。

只要类中有 __ getitem __() 方法,这个类的对象就是迭代器。

import numpy as np
import struct

class MnistDataset:
    def __init__(self, images_file, labels_file):
        self.images_file = images_file
        self.labels_file = labels_file

    # 将所有图片以numpy格式存放在列表中
    def load_imags(self, file):
        bin_data = open(file, 'rb').read()  # 读取二进制数据

        # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
        offset = 0
        fmt_header = '>iiii'
        _, image_num, image_rows, image_cols = struct.unpack_from(fmt_header, bin_data, offset)

        # 解析数据集
        image_size = image_rows * image_cols
        offset += struct.calcsize(fmt_header)
        fmt_image = '>' + str(image_size) + 'B'
        images = np.empty((image_num, image_rows, image_cols))

        for i in range(image_num):
            images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((image_rows, image_cols))
            offset += struct.calcsize(fmt_image)
        return images

    # 将所有数字label存放在列表中
    def load_labels(self, file):
        bin_data = open(file, 'rb').read()  # 读取二进制数据

        # 解析文件头信息,依次为魔数和标签数
        offset = 0
        fmt_header = '>ii'
        _, image_num = struct.unpack_from(fmt_header, bin_data, offset)

        # 解析数据集
        offset += struct.calcsize(fmt_header)
        fmt_image = '>B'
        labels = np.empty(image_num)
        for i in range(image_num):
            labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
            offset += struct.calcsize(fmt_image)
        return labels

    def __getitem__(self, index):
        images = self.load_imags(self.images_file)
        labels = self.load_labels(self.labels_file)
        return images[index], labels[index]


if __name__ == '__main__':
    images_file = 'MNIST_data/t10k-images.idx3-ubyte'
    labels_file = 'MNIST_data/t10k-labels.idx1-ubyte'
    dataset = MnistDataset(images_file, labels_file)

    for id, (image,label) in enumerate(dataset):
        print(label)
        
# 输出:
7.0
2.0
1.0
0.0
4.0
1.0
4.0
9.0
5.0
9.0
0.0
6.0
9.0
0.0
1.0
...

这里我们直接对 dataset 进行迭代,可以发现每一次迭代都会输出一个 label,并且这个label是数字,而不是list或者tensor格式。

现在我们用 DataLoader 对dataset进行加载,改变代码如下:

from torch.utils.data import DataLoader
if __name__ == '__main__':
    images_file = 'MNIST_data/t10k-images.idx3-ubyte'
    labels_file = 'MNIST_data/t10k-labels.idx1-ubyte'
    dataset = MnistDataset(images_file, labels_file)
    dataloader = DataLoader(dataset, batch_size=4)
    for id, (image, label) in enumerate(dataloader):
        print(label)

# 输出:
tensor([7., 2., 1., 0.], dtype=torch.float64)
tensor([4., 1., 4., 9.], dtype=torch.float64)
tensor([5., 9., 0., 6.], dtype=torch.float64)
tensor([9., 0., 1., 5.], dtype=torch.float64)
tensor([9., 7., 3., 4.], dtype=torch.float64)
...

可以发现,DataLoader 可以正常进行加载,并且可以设置batch_size的大小,输出的label是 tensor 格式。

所以用进行数据时并不一定需要继承 torch 的 Dataset ,自己写一个相同功能的类也可以。

那么问题来了,既然 dataset 本身就可以迭代,为什么还需要 DataLoader 呢?答案当然是用DataLoader 可以设置 batch_size、shuffle 等设置,实现更灵活的数据集加载方式。

4.2 继承 torch 的 Dataset 类进行数据处理

我们只需要对4.1的代码稍作修改就可以继承Dataset 类了:

import numpy as np
import struct
from torch.utils.data import Dataset, DataLoader

class MnistDataset(Dataset):    # 改动的地方
    def __init__(self, images_file, labels_file):
        super(Dataset).__init__()     # 改动的地方
        self.images_file = images_file
        self.labels_file = labels_file

    # 将所有图片以numpy格式存放在列表中
    def load_imags(self, file):
        bin_data = open(file, 'rb').read()  # 读取二进制数据

        # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
        offset = 0
        fmt_header = '>iiii'
        _, image_num, image_rows, image_cols = struct.unpack_from(fmt_header, bin_data, offset)

        # 解析数据集
        image_size = image_rows * image_cols
        offset += struct.calcsize(fmt_header)
        fmt_image = '>' + str(image_size) + 'B'
        images = np.empty((image_num, image_rows, image_cols))

        for i in range(image_num):
            images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((image_rows, image_cols))
            offset += struct.calcsize(fmt_image)
        return images

    # 将所有数字label存放在列表中
    def load_labels(self, file):
        bin_data = open(file, 'rb').read()  # 读取二进制数据

        # 解析文件头信息,依次为魔数和标签数
        offset = 0
        fmt_header = '>ii'
        _, image_num = struct.unpack_from(fmt_header, bin_data, offset)

        # 解析数据集
        offset += struct.calcsize(fmt_header)
        fmt_image = '>B'
        labels = np.empty(image_num)
        for i in range(image_num):
            labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
            offset += struct.calcsize(fmt_image)
        return labels

    def __getitem__(self, index):
        images = self.load_imags(self.images_file)
        labels = self.load_labels(self.labels_file)
        return images[index], labels[index]


if __name__ == '__main__':
    images_file = 'MNIST_data/t10k-images.idx3-ubyte'
    labels_file = 'MNIST_data/t10k-labels.idx1-ubyte'
    dataset = MnistDataset(images_file, labels_file)
    dataloader = DataLoader(dataset, batch_size=4)
    for id, (image, label) in enumerate(dataloader):
        print(label)

# 输出:
Traceback (most recent call last):
  File "F:\miniconda3\lib\site-packages\torch\utils\data\sampler.py", line 66, in __iter__
    return iter(range(len(self.data_source)))
TypeError: object of type 'MnistDataset' has no len()

Process finished with exit code 1

发现程序报错了,说MnistDataset类没有 len() 方法。这里我们来看一下__ len__() 方法,它的作用是返回容器中元素的个数,这里就是指返回 MNIST 数据集中图片的数量。

为什么一定需要__ len__() 方法呢?4.1中不继承Dataset 时候没有写__ len__() 方法不是一样可以加载吗?这就是pytorch的严谨之处了,没有这个方法,程序就不知道有多少数据量,用 enumerate 迭代时怎么知道到哪里停止呢?

我们再看看 pytorch 官网对 Dataset 的解释:
在这里插入图片描述
可以看到,子类必须重写__getitem__(),可以选择性覆盖__len__(),许多 Sampler 实现和 DataLoader的默认选项期望它返回数据集的大小。

所以继承了 Dataset,用 DataLoader 加载时,必须要有__len__() 方法。

那我们现在给 MnistDataset(Dataset) 类增加__len__() 方法:

import numpy as np
import struct
from torch.utils.data import Dataset, DataLoader

class MnistDataset(Dataset):
    def __init__(self, images_file, labels_file):
        super(Dataset).__init__()
        self.images_file = images_file
        self.labels_file = labels_file

    # 将所有图片以numpy格式存放在列表中
    def load_imags(self, file):
        bin_data = open(file, 'rb').read()  # 读取二进制数据

        # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
        offset = 0
        fmt_header = '>iiii'
        _, image_num, image_rows, image_cols = struct.unpack_from(fmt_header, bin_data, offset)

        # 解析数据集
        image_size = image_rows * image_cols
        offset += struct.calcsize(fmt_header)
        fmt_image = '>' + str(image_size) + 'B'
        images = np.empty((image_num, image_rows, image_cols))

        for i in range(image_num):
            images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((image_rows, image_cols))
            offset += struct.calcsize(fmt_image)
        return images

    # 将所有数字label存放在列表中
    def load_labels(self, file):
        bin_data = open(file, 'rb').read()  # 读取二进制数据

        # 解析文件头信息,依次为魔数和标签数
        offset = 0
        fmt_header = '>ii'
        _, image_num = struct.unpack_from(fmt_header, bin_data, offset)

        # 解析数据集
        offset += struct.calcsize(fmt_header)
        fmt_image = '>B'
        labels = np.empty(image_num)
        for i in range(image_num):
            labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
            offset += struct.calcsize(fmt_image)
        return labels

    def __getitem__(self, index):
        images = self.load_imags(self.images_file)
        labels = self.load_labels(self.labels_file)
        return images[index], labels[index]

    def __len__(self):
        images = self.load_imags(self.images_file)
        return len(images)


if __name__ == '__main__':
    images_file = 'MNIST_data/t10k-images.idx3-ubyte'
    labels_file = 'MNIST_data/t10k-labels.idx1-ubyte'
    dataset = MnistDataset(images_file, labels_file)
    dataloader = DataLoader(dataset, batch_size=4)
    for id, (image, label) in enumerate(dataloader):
        print(label)

# 输出:
tensor([7., 2., 1., 0.], dtype=torch.float64)
tensor([4., 1., 4., 9.], dtype=torch.float64)
tensor([5., 9., 0., 6.], dtype=torch.float64)
tensor([9., 0., 1., 5.], dtype=torch.float64)
tensor([9., 7., 3., 4.], dtype=torch.float64)
...

现在就可以正常运行了。如果设置 shuffle=True,也没有任何问题:

if __name__ == '__main__':
    images_file = 'MNIST_data/t10k-images.idx3-ubyte'
    labels_file = 'MNIST_data/t10k-labels.idx1-ubyte'
    dataset = MnistDataset(images_file, labels_file)
    dataloader = DataLoader(dataset, batch_size=4,shuffle=True)
    for id, (image, label) in enumerate(dataloader):
        print(label)

# 输出:
tensor([8., 5., 4., 3.], dtype=torch.float64)
tensor([2., 2., 5., 2.], dtype=torch.float64)
tensor([4., 8., 1., 8.], dtype=torch.float64)
tensor([6., 7., 5., 9.], dtype=torch.float64)
tensor([3., 7., 3., 8.], dtype=torch.float64)

以上就是基本的迭代器使用方法,对于迭代器我还有很多不理解的地方,所以这篇博客也会不断完善。

  • 19
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ctrl A_ctrl C_ctrl V

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值