Pytorch数据加载模块:Dataset,Sampler和DataLoader总结

官网教程示例:

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

Pytorch加载数据三步走:


  1. Dataset:解析单个样本,把数据映射成(x,y)的形式;

    • map-style:实现__getitem__和__len__接口,随机取数据代价小(大多数情况用map-stype);
    • iterable-style:实现__iter__接口,随机取数据代价大,适合处理流数据(比如文本流数据);
  2. Sampler:提供一种遍历数据集所有元素索引的方式,有默认值;

  3. DataLoader:将当个样本变成训练时需要的batch形式;


1 Dataset

1.1 源码

# 接口
from torch.utils.data import Dataset

# 源码位置
# ../torch/utils/data/dataset.py

# 查看torch安装位置
import torch
print(torch.__file__)

源码

# Dataset抽象类 对外暴露一些接口

# map-style
class Dataset(Generic[T_co]):

    def __getitem__(self, index) -> T_co:
    	# 基类中没有实现 需要自己实现
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
# iter-style
class IterableDataset(Dataset[T_co]):

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

1.2 创建自己的Dataset

定义自己的Dataset,继承Dataset类后,需要(必须)实现三个方法:

  • _init_
  • _len_
  • _getitem_

示例:

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        # 保存图像的根路径
        self.img_dir = img_dir
        # 对数据的处理 数据增强之类的
        self.transform = transform
        # 对标签的处理
        self.target_transform = target_transform

    def __len__(self):
        # 返回一共有多少个数据
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼凑图像的完整路径
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # 读取图像
        image = read_image(img_path)
        # 从csv中读取的信息分割出标签
        label = self.img_labels.iloc[idx, 1]
        # 对数据及标签进行处理
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
    
# csv保存图片名 大概长这样
# tshirt1.jpg, 0
# tshirt2.jpg, 0
# ......
# ankleboot999.jpg, 9

1.3 加载数据集

ann_csv = '../ann.csv'
img_root = '/'

# 实例化
myDataset = CustomImageDataset(ann_csv,img_root)

# 获取该类的属性
print(myDataset.img_dir)

# 获取数据的数量 可以用 但是一般不这么用
print(myDataset.__len__())

# 获取第1个数据的img和label(下标0)
# 可以用 但是一般不这么用
img,lab = myDataset.__getitem__(0)
print(img.shape, lab)

# 一般这么用...
print(len(myDataset))
img,lab = myDataset[0]
print(img.shape, lab)


# 一般不会单独用Dataset
# 扔到DataLoader里 构成batch数据

1.4 Dataset的子类

1.4.1 TensorDataset

如果数据本身已经是tensor形式了

# 数据转为tensor格式
x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)

# 直接用TensorDataset封装即可
train_dataset = TensorDataset(x_train, y_train)

1.4.2 IterableDataset

根据两个数start和end生成数据集;

# 继承IterableDataset
class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        # !核心就是根据range生成的数
        return iter(range(iter_start, iter_end))
# 实例化
# 结果:[3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# 用DataLoader 单线程进行加载
# [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))

# 用DataLoader 多线程进行加载
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))

1.4.3 ConcatDataset

将多个数据集拼接成一个;
用法如下:

# 第一个数据集  len 60000
mnist_data = MNIST('./data', train=True, download=True)

# 第二个数据集 len 50000
cifar10_data = CIFAR100('./data', train=True, download=True)

# 两个数据集拼接 len 110000
concat_data = ConcatDataset([mnist_data, cifar10_data])

1.4.4 ChainDataset

将IterableDataset类的多个数据集拼接成一个数据集;

1.4.5 Subset

将一个数据集划分为子数据集,比如划分训练集和验证集;

# 训练集和验证集的索引
train_indices, val_indices = indices[split:], indices[:split]

# 根据索引随机划分
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

2 Sampler

就是遍历数据集的方式,默认方式有两种:

  • shuffle = True:sampler = RandomSampler(dataset, generator=generator),随机打乱;
  • shuffle = False:sampler = SequentialSampler(dataset),不打乱;
  • 也可以自定义Sampler传入,但是Sampler与shuffle互斥;

2.1 RandomSampler

class RandomSampler(Sampler[int]):
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
                 
        # slef. = ...


    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)

        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                # 核心就是torch.randperm函数
                # 生成0~n-1的随机数列(索引)
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples

2.2 SequentialSampler

SequentialSampler其实什么也没做,不破坏数据集原有的顺序;

class SequentialSampler(Sampler[int]):

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

2.3 自定义Sampler

import random
from torch.utils.data.sampler import Sampler
 
 # 自定义必须先继承Sample类
 # 必须实现__init__,__iter__,__len__方法
class MySampler(Sampler):
    def __init__(self, dataset):
    	# 将数据集均分为两部分
        halfway_point = int(len(dataset)/2)
        self.first_half_indices = list(range(halfway_point))
        self.second_half_indices = list(range(halfway_point, len(dataset)))
        
    def __iter__(self):
    	# 每次从前一半和后一半各返回一个
    	# 假设前一半为 1 2 3 4 5 
    	#    后一半为 6 7 8 9 10
    	# 则依次返回(1,6)(2,7)(3,8)...
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        return iter(self.first_half_indices + self.second_half_indices)
    
    def __len__(self):
		return len(self.first_half_indices) + len(self.second_half_indices)

3 DataLoader

3.1 使用DataLoader

from torch.utils.data import DataLoader

training_data = myDataset(...)
test_data = myDataset(...)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# 一般测试集不打乱 没有意义
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

train_features, train_labels = next(iter(train_dataloader))
# Feature batch shape: torch.Size([64, 1, 28, 28])
print(f"Feature batch shape: {train_features.size()}")
# Labels batch shape: torch.Size([64])
print(f"Labels batch shape: {train_labels.size()}")

# ... 一些处理

3.2 源码及参数

# DataLoader源码位置
# /torch/utils/data/dataloader.py

# 参数们
# dataset: Dataset实例对象
# batch_size:批量大小 默认为1
# shuffle:每周期后是否对数据进行打乱
# sampler:遍历数据集的方式 有默认值 和shuffle互斥
# batch_sampler:同上 和shuffle sampler drop_last batch_size互斥
# num_workers:默认为0 加载数据(batch)的进程数目
# num_workers的经验设置值是自己电脑/服务器的CPU核心数
# 0意味着所有的数据都会被load进主进程
# collate_fn: 对batch数据再处理
# pin_memory: 锁页内存 数据放到GPU上
# drop_last: 非整数batch时 最后一个batch丢掉
# timeout: 如果是正数,表明等待从worker进程中收集一个batch等待的时间
# 若超出设定的时间还没有收集到,那就不收集这个内容了
class DataLoader(Generic[T_co]):
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Union[Sampler, Iterable]
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False
   

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
 		
        # 一堆成员变量设置
		# self. = ...
	
	# sampler的设置
    if sampler is None:  
            if self._dataset_kind == _DatasetKind.Iterable:
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    # 原理:通过torch.randperm实现 打乱
                    sampler = RandomSampler(dataset, generator=generator)  
                else:
                    # 原理:iter(range()) 有序
                    sampler = SequentialSampler(dataset)  
                   
	# 在__iter__调用
	# 复写基类方法 实现iter函数
    # 可以调用为iter(train_dataloader)
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            # 获取下一个索引 根据索引获得并返回数据
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            # 多进程进行处理
            return _MultiProcessingDataLoaderIter(self)

        
  
	# 变成迭代器
    def __iter__(self) -> '_BaseDataLoaderIter':
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()


	
    # 在_BaseDataLoaderIter类调用
    # 其实就是复写基类方法,实现next函数
    # next(iter(train_dataloader))
    @property
    def _index_sampler(self):
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler
        

    # 返回有多少batch
    def __len__(self) -> int:
        # ...

    # 对num_workers设定合理性进行检查
    def check_worker_number_rationality(self):
        # ...
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: PyTorch是一个开源的机器学习库,内置丰富的函数和工具包用于数据加载数据预处理、模型构建、训练和评估。数据集是机器学习模型的重要组成部分,PyTorch提供了多种方法用于加载数据集,包括内置的函数和可定制的方法,让用户可以根据自己的需求和数据集特性来选择适合的方式。 内置函数 PyTorch提供了内置的函数用于加载常见的数据集,如MNIST、CIFAR-10、ImageNet等。这些函数通常包括下载数据集、转换为Tensor格式、划分为训练集和测试集等步骤,使用简单方便,适合快速上手使用。 可定制方法 如果内置函数不能满足需求,PyTorch也提供了许多可定制的方法。最常用的是DatasetDataLoader类。Dataset类是抽象类,用户需要继承这个类并实现getitem和len方法来定义自己的数据集。DataLoader类用于生成一个迭代器,用户可以设置批量大小、并行加载、随机采样等参数。 除此之外,PyTorch还提供了其它一些用于数据集处理的工具,如transforms模块Sampler类、collate_fn函数等,可以用于数据增强、数据集分块和数据集拼接等场景。 总结 PyTorch提供了内置函数和可定制方法用于加载数据集,用户可以根据自己的需求和数据集特性来选择适合的方式。使用内置函数可以快速上手,使用可定制方法可以更加灵活和高效。对于多样化的数据集,PyTorch还提供了多个处理工具,可以用于数据增强、数据集分块和数据集拼接等场景。 ### 回答2: PyTorch是一种基于Python的开源机器学习库,它可以用于构建各种机器学习模型。在PyTorch中,数据加载是一个非常重要的部分,因为机器学习模型需要大量的数据来进行训练。 在PyTorch中,数据加载可以通过DataLoader类来实现。DataLoader是一个Python迭代器,它可以加载大量的数据集,并将其分成小批量进行训练。这样可以避免一次性将整个数据加载到内存中,从而节省内存空间。 首先,我们需要将数据加载到内存或磁盘中,并将其转换为PyTorch数据集类的对象。PyTorch提供了两种数据集类:Dataset和IterableDataset。其中,Dataset类是一种基于索引的数据集类,它可以通过索引来访问数据集中的每个数据样本;而IterableDataset是一种基于迭代器的数据集类,它可以像Python中的迭代器一样使用。 然后,我们可以使用DataLoader类来加载数据集。DataLoader类有很多参数,包括batch_size(表示每个小批量包含的样本数)、shuffle(表示是否随机打乱数据集顺序)、num_workers(表示使用多少个工作线程来加载数据集)等。 在使用DataLoader加载数据集时,我们可以通过for循环来迭代数据集中的每个小批量,并将其传递给机器学习模型进行训练。 总之,PyTorch数据加载是非常灵活和易于使用的。通过使用DataLoader类和PyTorch提供的数据集类,我们可以轻松地加载和处理大量的数据集,并将其用于训练各种机器学习模型。 ### 回答3: Pytorch是一个使用Python作为开发语言的深度学习框架,提供了非常强大的数据加载和预处理工具。在Pytorch中,数据加载主要通过两个类来实现,分别是DatasetDataLoaderDataset类负责加载和处理数据集,而DataLoader类则负责将处理后的数据安装指定的batch_size分批加载到内存中,避免了内存不足的问题。 Dataset类是一个抽象类,需要根据具体的数据集来实现其中的方法。一般而言,Dataset类中需要实现__len__()方法和__getitem__()方法,分别用来获取数据集的长度和获取具体的数据样本。 DataLoader类则负责对数据集进行batch处理,这样可以充分利用系统的存储和计算资源,提高了模型的训练速度。在使用DataLoader时,需要指定batch_size、shuffle和num_workers等参数,其中num_workers可以指定使用多少个进程来装载数据,从而进一步提高了数据装载的效率。 在数据加载过程中,可以使用Pytorch提供的transforms模块来进行数据的预处理,如改变图像尺寸、随机翻转、归一化等操作,从而增加模型的泛化能力和准确性。 总之,Pytorch数据加载和预处理方面提供了非常强大的工具,只需要实现一些简单的代码,就能轻松地完成数据加载和预处理,从而为模型的训练和测试提供了坚实的基础。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值