PyTorch notes: Dataset and DataLoader

Compiled from a Bilibili video, the official tutorial, and the API documentation.

  • A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.
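A minimal usage sketch wrapping the custom dataset above in a DataLoader; labels.csv and img/ are placeholder paths, not taken from the original text:

from torch.utils.data import DataLoader

# hypothetical annotation file and image directory
training_data = CustomImageDataset("labels.csv", "img/")
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

# each iteration yields one minibatch assembled by the default collate_fn;
# this only works out of the box if every image tensor has the same shape
train_features, train_labels = next(iter(train_dataloader))
print(train_features.shape)   # e.g. torch.Size([64, C, H, W])
print(train_labels.shape)     # torch.Size([64])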

Introduction to the DataLoader, with excerpts from the source code:

  • collate_fn operates on a minibatch, whereas a Dataset's transform processes one sample at a time (see the sketch after this list)
  • collate_fn receives a single argument, batch: a list of batch_size items, each of which is one return value of __getitem__
  • Ordinary Dataset classes are map-style datasets; an iterable-style dataset is exhausted once it has been iterated over
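A sketch of a custom collate_fn, assuming the (image, label) items produced by the CustomImageDataset above (training_data is the instance from the earlier sketch). It just reproduces what default_collate would do here, but makes the "batch is a list of __getitem__ results" point explicit:

import torch
from torch.utils.data import DataLoader

def my_collate(batch):
    # batch is a list of batch_size tuples, each one returned by __getitem__
    images = torch.stack([image for image, _ in batch])    # merge sample tensors along a new batch dim
    labels = torch.tensor([label for _, label in batch])   # gather labels into one tensor
    return images, labels

loader = DataLoader(training_data, batch_size=32, collate_fn=my_collate)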

sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be any Iterable with __len__
implemented. If specified, :attr:shuffle must not be specified.

batch_sampler (Sampler or Iterable, optional): like :attr:sampler, but
returns a batch of indices at a time. Mutually exclusive with
:attr:batch_size, :attr:shuffle, :attr:sampler,
and :attr:drop_last.
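As an illustration of a non-default sampler (a toy sketch, not from the original post): WeightedRandomSampler draws indices according to per-sample weights, and shuffle must then be left unset:

from torch.utils.data import DataLoader, WeightedRandomSampler

# toy weights: one entry per sample index (in practice len(weights) == len(dataset));
# higher weight = the index is drawn more often
weights = [0.1, 0.1, 0.1, 0.9, 0.9]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# passing sampler means shuffle stays at its default (False)
loader = DataLoader(training_data, batch_size=2, sampler=sampler)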

  • If the sampler argument is not given, a default sampler is created:
 if sampler is None:  # give default samplers
     if self._dataset_kind == _DatasetKind.Iterable:
         # See NOTE [ Custom Samplers and IterableDataset ]
         sampler = _InfiniteConstantSampler()
     else:  # map-style
         if shuffle:
             sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
         else:
             sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
RandomSampler returns indices in shuffled order; its __iter__ boils down to torch.randperm:

 yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

SequentialSampler returns indices in the original order:

 return iter(range(len(self.data_source)))
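Both samplers yield indices rather than samples; a quick check (the RandomSampler output is of course random):

from torch.utils.data import SequentialSampler, RandomSampler

data = list(range(5))
print(list(SequentialSampler(data)))   # [0, 1, 2, 3, 4] -- original order
print(list(RandomSampler(data)))       # e.g. [3, 0, 4, 1, 2] -- a permutation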
  • If batch_sampler is not given, a default BatchSampler is created; it draws indices from sampler, groups them into a batch, and yields the batch
if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
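BatchSampler just groups the indices coming out of the wrapped sampler into lists of batch_size; this mirrors the example in its docstring:

from torch.utils.data import BatchSampler, SequentialSampler

print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]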
  • If collate_fn is not given, the default depends on _auto_collation, i.e. on whether a batch_sampler exists (the common case, since the default BatchSampler above is created whenever batch_size is set): with auto-collation, default_collate merges the list of per-sample items into batched tensors; without it, default_convert merely converts each item (e.g. a NumPy array) to a tensor
@property
def _auto_collation(self):
    return self.batch_sampler is not None
    
if collate_fn is None:
    if self._auto_collation:
        collate_fn = _utils.collate.default_collate
    else:
        collate_fn = _utils.collate.default_convert
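To see the difference between the two defaults (in recent PyTorch they are re-exported as torch.utils.data.default_collate / default_convert; older versions only expose them under torch.utils.data._utils.collate):

import numpy as np
import torch
from torch.utils.data import default_collate, default_convert

batch = [(torch.zeros(3, 4), 0), (torch.ones(3, 4), 1)]   # two (sample, label) items
images, labels = default_collate(batch)
print(images.shape)   # torch.Size([2, 3, 4]) -- samples stacked along a new batch dim
print(labels)         # tensor([0, 1])

# default_convert only converts element-wise (e.g. numpy -> tensor); no batch dim is added
print(default_convert([np.arange(3), np.arange(3)]))
# [tensor([0, 1, 2]), tensor([0, 1, 2])]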
  • The video also covers _index_sampler and _get_iterator; a paraphrase of the former follows
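For reference, _index_sampler is essentially this property (paraphrased from the DataLoader source; the iterator walks over indices, in batches when auto-collation is enabled):

@property
def _index_sampler(self):
    # batches of indices when a batch_sampler exists, single indices otherwise
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler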