DataLoader是PyTorch中的一种数据类型。
在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?
下面就研究一下:
先看看 dataloader.py脚本是怎么写的(VS中按F12跳转到该脚本)
__init__(构造函数)
中的几个重要的属性:
1、dataset:(数据类型 dataset)
输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。
2、batch_size:(数据类型 int)
每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。
3、shuffle:(数据类型 bool)
洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。
4、co