使用基于 PyTorch 构建的模型进行训练前,需要对数据进行加载操作
即使用 torch.utils.data.DataLoader()
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
对数据 dataset 进行加载,同时提供可分批加载功能,即设置 batch_size
dataset:需要加载的数据集
batch_size:默认1,batch的大小
shuffle:默认 False,是否在每个epoch重新打乱样本顺序
sampler:默认None,定义从数据集中获取样本的策略,设定此项则忽略 shuffle
num_workers:默认0,加载使用的进程的数量,0表示在主进程中加载数据
collate_fn:合并一组样本以形成张量的mini-batch(从map-style的数据集中分批加载数据时使用)
pin_memory

最低0.47元/天 解锁文章
543

被折叠的 条评论
为什么被折叠?



