DataLoader()函数的参数说明:
1.dataset (必需): 用于加载数据的数据集,通常是torch.utils.data.Dataset子类的实例化;
2.batch_size (可选): 每个批次的数据样本数,default: 1
3.shuffle (可选): 是否打乱数据集,default: False
4.num_workers (可选): 数据加载的子进程数量,default: 0,即数据只在主进程中加载;
5.drop_last (可选): 如果数据集样本总数不能被批次大小整除,是否丢弃最后一个不完整的批次,default: False
6.pin_memory (可选): 当pin_memory设置为True时,数据加载器会将数据加载到的固定页(锁页内存)中,而GPU可以直接访问固定页(锁页内存)中的数据,而不需要经过额外的数据拷贝操作,
因此cpu内存花销增大,但可以提高使用gpu训练时数据加载的效率,default: False
import torch
from torch.utils.data import DataLoader, Dataset
# 实例化
train_dataset = Dataset() # 这里简写了,一般都是自己定义Dataset子类的实例化
train_data_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=2,
drop_last=True,
pin_memory=True,
)