深入剖析PyTorch DataLoader源码

首先再开始之前,我想问一下在座的各位知道DataLoader的原理吗,不会的扣1,会的小伙伴们扣脚指头,嘿嘿,开玩笑的,我也不知道,那么接下来我们一起学习如何写吧!

推荐教学视频:4、PyTorch的Dataset与DataLoader详细使用教程_哔哩哔哩_bilibili

5、深入剖析PyTorch DataLoader源码_哔哩哔哩_bilibili(强烈建议,因为孩子也不知道博主讲到哪了,太快了(哭))

推荐学习链接:Pytorch的Dataset和DataLoader详细使用教程_后来后来啊的博客-CSDN博客


pytorch官方示例代码:Datasets & DataLoaders — PyTorch Tutorials 2.0.1+cu117 documentation


官方给出了代码

Iterate through the DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) #创建了一个`train_dataloader`对象,用于加载训练数据。`training_data`是训练数据集,`batch_size=64`表示每个批次的样本数量为64,`shuffle=True`表示在每个epoch开始时对数据进行随机洗牌
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)#创建了一个`test_dataloader`对象,用于加载测试数据。`test_data`是测试数据集,`batch_size=64`和`shuffle=True`的含义与上述相同。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))#使用`next(iter(train_dataloader))`从`train_dataloader`中获取一个批次的训练特征和标签。`train_features`是训练特征,`train_labels`是对应的训练标签。
print(f"Feature batch shape: {train_features.size()}")#打印出训练特征的形状,`train_features.size()`返回一个元组,表示特征的维度。例如,如果特征是一个3维张量,形状为(64, 1, 28, 28),则打印结果为"Feature batch shape: torch.Size([64, 1, 28, 28])"
print(f"Labels batch shape: {train_labels.size()}")#打印出训练标签的形状,`train_labels.size()`返回一个元组,表示标签的维度。例如,如果标签是一个1维张量,形状为(64,),则打印结果为"Labels batch shape: torch.Size([64])
img = train_features[0].squeeze()#将训练特征中的第一张图片展示出来。`train_features[0]`表示取第一张图片,
label = train_labels[0]# `squeeze()`函数用于去除维度为1的维度,将形状为(1, 28, 28)的张量转换为形状为(28, 28)的张量。`plt.imshow()`函数用于显示图片,`cmap="gray"`表示以灰度图的形式显示
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")#打印出训练标签的值,即该图片对应的标签。例如,如果标签是一个整数,值为5,则打印结果为"Label: 5"

在pycharm中若想查看源码,则只需要输入import torch.utils.data.dataloader,然后按住ctrl+dataloader即可以查看源码

Dataloader相对于dataset,相当于把数据组成一批次,然后在进行处理

以下是部分源码

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler] = None,
                 batch_sampler: Optional[Sampler[Sequence]] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")

#shuffle,在训练结束后随机打乱
#sample:可以自定义在里面取数据,例如想让样本已一种有序的方式取样
#设置了sample和batch_sample就不需要设置shuffle了
#num_workers:多进行读取,一般以0读取
#collate_fn:聚集函数

 

 sampler作用:以某种顺序去dataset里面取元素,样本

 iter实际走的就是上述函数,返回self._ger_iteration(),即下图 

 

 上函数继承于

 next官方讲解:

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data

对next的调用就返回了miniset的调用

总结:在pytorch中需要自定义,按照官网教程自定义即可,加载文件,预处理等;在dataset中至少需要实现__init__,__len__,__getitem__三个方法(详细用法可以 查看Pytorch的Dataset和DataLoader详细使用教程_后来后来啊的博客-CSDN博客)

做完后需要放入dataloader中,设计sampler:决定在datase中取样本的 顺序,若不设置设置suffer=True,则会内置实例化一个RandomSampler(把数据库索引中打乱,随机选取),collate_fn:对已经取得进行填充,厚处理(相当于什么都没干,有些文本分类需要使用),该函数以batch作为输入,以batch作为输出;iter方法源码走346行逻辑

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值