2.PyTorch的Dataset和DataLoader

#! https://zhuanlan.zhihu.com/p/543481537

2.PyTorch的Dataset和DataLoader

本次笔记记录如何构建数据集,如何搭建minibatch,以及DataLoader源码剖析.

Datasets & DataLoader

  1. torch.utils.data.Dataset:
    处理单个训练样本,从磁盘读取训练数据(特征、标签),预处理变成{x,y}训练对。
  2. torch.utils.data.DataLoader:
    得到单个训练样本后,变成minibatch的形式(将数据打乱,保存在gpu中等)
#使用Torchvision导入内置数据集
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

构建自己的数据集

继承抽象类Dataset,并自己实现以下三个函数:

  1. __init__
    初始化操作(文件路径、transform······)
  2. __len__
    在DataLoader中需要用到Dataset的长度,在Init中读取数据,直接用len()返回即可
  3. __getitem__
    对Dataset索引返回单个训练样本,返回x,y训练对(有标签时)
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])#每张图片路径
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)#预处理
        if self.target_transform:
            label = self.target_transform(label)#预处理label
        return image, label

这种dataset叫做map-style datasets
读流式数据可以用iterable-style dataset

使用DataLoaders为训练准备数据

使用Minibatch的形式训练数据,在每个epoch都reshuffle数据从而降低模型过拟合的可能性,还通过python的multiprocessing多进程加载数据,使得读数据的过程不影响gpu训练,实现低延迟。(有的数据集读取的过程贼慢,比方说TextVQA这种,很影响GPU训练的效率)
DataLoader返回的是一个列表。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
#test不需要梯度下降,只要前向传递,不需要shuffle
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

展示一张图片与标签,

下面第一行代码的next,iter方法后面解释

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()#灰度图片要变成三通道
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

DataLoader深入剖析

官网的教程过于简单,下面深入剖析DataLoader,
包括sampler,collate_fn处理等,还有一些排序方法(比如让一个minibatch中样本长度差不多,这样在进行padding等操作的时候就会简单很多)。
主要解析三部分代码:

  1. torch/utils/data/dataloader.py
  2. torch/utils/data/sampler.py
  3. torch/utils/data/_utils/collate.py

首先看dataloader

'''
dataset: Dataset[T_co]:传入一个实例化的dataset对象
shuffle:每个epoch后打乱数据,一般在训练集上用
sampler/batch_sampler:自定义采样方式,显然与shuffle冲突
num_workers:进程数量,多进程读数据,0代表只用一个主进程
pin_memory:把tensor保存在gpu上,不用重复保存,对于效率影响有待考究
drop_last:样本数量不是batch_size整数倍时,把最后一部分数据丢掉
collate_fn:自定义聚集函数,对小批次数据进行再次处理,输入一个Batch,输出一个batch
timeout:是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
'''
 def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')
        #设置成员变量
        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context
        #后续代码太长,这里不一一展示,后面不算很难

后续代码太长,这里不一一展示,init函数主要做三件事:

  1. 构建sampler
  2. 构建batchsampler
  3. 构建collate_fn

这里说几点重要的,方便放矢地看源码:

1.构建sampler

1.1.shuffle内部通过RandomSampler的具体实现,重点看__iter__.

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
        else:
            generator = self.generator
        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            yield from torch.randperm(n, generator=generator).tolist()

1.2.shuffle = None通过SequentialSampler实现.

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

2.我们不使用batchsampler而仅仅是设定batch_size,内部其实还是通过BatchSampler实现的.

   if batch_size is not None and batch_sampler is None:
        # auto_collation without custom batch_sampler
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)

这里插入说一下python的迭代器和生成器,方便下面的理解

'''
迭代是Python最强大的功能之一,是访问集合元素的一种方式。

迭代器是一个可以记住遍历的位置的对象。

迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。

迭代器有两个基本的方法:iter() 和 next()。

字符串,列表或元组对象都可用于创建迭代器:
'''
>>> list=[1,2,3,4]
>>> it = iter(list)    # 创建迭代器对象
>>> print (next(it))   # 输出迭代器的下一个元素
1
>>> print (next(it))
2
#迭代器对象可以使用常规for语句进行遍历:
list=[1,2,3,4]
it = iter(list)    # 创建迭代器对象
for x in it:
    print (x, end=" ")

'''
在 Python 中,使用了 yield 的函数被称为生成器(generator)。

跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。

在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。

调用一个生成器函数,返回的是一个迭代器对象。

以下实例使用 yield 实现斐波那契数列:
'''
import sys
 
def fibonacci(n): # 生成器函数 - 斐波那契
    a, b, counter = 0, 1, 0
    while True:
        if (counter > n): 
            return
        yield a
        a, b = b, a + b
        counter += 1
f = fibonacci(10) # f 是一个迭代器,由生成器返回生成
 
while True:
    try:
        print (next(f), end=" ")
    except StopIteration:
        sys.exit()

回归BatchSampler实现:

def __iter__(self) -> Iterator[List[int]]:
    batch = []
    for idx in self.sampler:
        batch.append(idx)
            f len(batch) == self.batch_size:
            yield batch #生成器,每次next返回一个batch
            batch = []
    if len(batch) > 0 and not self.drop_last:#drop_last的原理
        yield batch

3.collate_fn自定义聚集函数

if collate_fn is None:
    if self._auto_collation: #batch_sample is not None ->true
        collate_fn = _utils.collate.default_collate #输入输出都是batch,做了一些数据类型的转换
    else:
        collate_fn = _utils.collate.default_convert

self.collate_fn = collate_fn

4.迭代器iter()可以把dataloader变成一个迭代器
iter -> _get_iterator()
_get_iterator() -> _SingleProcessDataLoaderIter() or _MultiProcessingDataLoaderIter()

_SingleProcessDataLoaderIter()继承自BaseDataLoaderIter.

我们发现_SingleProcessDataLoaderIter()_next_data函数没有在子类中被调用,猜测可以在父类BaseDataLoaderIter找到调用.

经查看,父类__next__中确实调用了_next_data,这就可以解释了代码的内部原理。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值