Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解

本文介绍了PyTorch中的DataLoader数据类型,详细解析了构造函数中的关键参数,如dataset、batch_size、shuffle、collate_fn等,并探讨了它们在训练模型时的作用。DataLoader用于定义数据加载方式,包括数据批量读取、洗牌、多线程导入等,有助于提高训练效率。
摘要由CSDN通过智能技术生成

DataLoader是PyTorch中的一种数据类型,它定义了如何读取数据方式。详情也可参考本博主的另一篇关于torch.utils.data.DataLoaderhttps://blog.csdn.net/qq_36653505/article/details/83351808)的讨论。

在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?

下面就研究一下:

先看看 dataloader.py源码是怎么写的(VS中按F12跳转到该脚本)

__init__(构造函数)中的几个重要的属性:

1、dataset:(数据类型 dataset)

输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。

2、batch_size:(数据类型 int)

每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。

3、shuffle:(数据类型 bool)

洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

4、collate_fn:(数据类型 callable,没见过的类型)

将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)

5、b

  • 15
    点赞
  • 100
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值