1.这里不解释源码(因为我也看不懂)
2. 简单解释
dataloader本质是一个可迭代(iterable)对象,使用iter()访问,不能使用next()访问;
( iter和next:迭代器和生成器)
官方对DataLoader的说明是:“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”关于iterator和iterable的区别和概念请自行查阅,在实现中的差别就是iterators有__iter__和__next__方法,而iterable只有__iter__方法。(添加链接描述)
下面的代码为DataLoader的源码:
def __iter__(self) -> '_BaseDataLoaderIter':
# When using a single worker the returned iterator should be
# created everytime to avoid reseting its state
# However, in the case of a multiple workers iterator
# the iterator is only created once in the lifetime of the
# DataLoader object so that workers can be reused
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
3.DataLoader()实例化对象的访问
既然叫做可迭代的对象,其实就可以当做 迭代器 处理
方法一:
使用iter(dataloader)返回迭代器,
再使用next()访问;
方法二:
for inputs, labels in dataloaders :
进行可迭代对象的访问;
一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;
3.构建DataLoader的pipeline
- 构建一个 Dataset 对象
- 构建一个 DataLoader 对象
- 循环这个 DataLoader 对象,将img, label加载到模型中进行训练
class Diabetes(Dataset):
def __init__(self, filepath):
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:, :-1])
self.y_data = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
# 生成一个Diabetes对象,自动调用构造函数,数据就写入了self.x_data,self.y_data中
dataset = Diabetes('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
...
...
# 下面的 i 的作用就是看一下每个epoch分成了多少个batch_size
for epoch in range(10):
for i, data in enumerate(train_loader, 0):
inputs, labels = data
y_pred = model(inputs)
loss = criterion(y_pred, labels)
print('epoch:', epoch, 'i:', i, 'loss:', loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
4.DataLoader作用总结
将自定义的Dataset封装成一个Batch Size大小的Tensor,用于后面的训练。
参考来源:
pytorch之dataloader深入剖析