【Pytorch学习】Dataset和DataLoader
1.训练过程
- 创建Dataset类
- Dataset传递给DataLoader
- DataLoader迭代产生训练数据给模型
2.Dataset
Dataset
是一个抽象类,自定义数据集类必须继承 Dataset,且需重新改写 getitem 和 len 方法。
__getitem__
:传入指定的索引 index ,根据索引,以元组形式返回对应的单个样本及其对应的标签。__len__
:返回整个数据集的大小。
2.1框架
class MyDataset(Dataset):
def __init__(self,annotation_lines):
super(MyDataset, self).__init__()
# 初始化
def __getitem__(self, index):
# 返回单个样本及其标签
pass
def __len__(self):
# 返回整个数据集的大小
pass
3.DataLoader
3.1 基本工作机制
在dataloader按照batch进行取数据,
- 取出大小等同于batch size的index列表,
- 将列表中的index输入到dataset的getitem()函数中,取出该index对应的数据,
- 对每个index对应的数据进行堆叠,就形成了一个batch的数据。
3.2 常用参数
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
- dataset :自定义数据集类的实例。
- batch_size :每个 batch 的大小。
- shuffle :用来控制是否在每个 epoch 开始时打乱数据集。
- num_workers :决定加载数据集时使用多少子进程。默认只使用主进程加载。
- collate_fn : 如何取样本的,可以定义自己的函数来准确地实现想要的功能
- drop_last :当 batch size 不能整除 data size 时,最后一个 batch 会变得不完整。该参数决定是否扔掉最后一个不完整的 batch。
3.2 collate_fn
通常可以将数据格式转换放在这一步定义,如Dataset__getitem__
输出是ndarray,可以在此函数转换为Tensor
def collate(batch):
images = []
bboxes = []
for img, box in batch:
images.append(img)
bboxes.append(box)
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
return images, bboxes
参考:https://raelum.blog.csdn.net/article/details/124578203
https://blog.csdn.net/yuanzhoulvpi/article/details/122284733