转载自作者marsggbo
https://www.cnblogs.com/marsggbo/p/11308889.html
在 pytorch 的体系中,数据加载的最终目的使用 Dataloader 处理 dataset 对象,以方便的控制 Batch,Shuffle 等等操作。
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=default_collate,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
初始化参数里有两种sampler:sampler和batch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index
num_worker 负责数据加载的多进程数量。
num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮…迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)。
如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是先在RAM中找batch,找不到时再加载相应的batch。缺点当然是速度更慢。
collate_fn 如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
drop_last 告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留
pin_memory pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。
worker_init_fn 子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据