【Pytorch学习】Dataset和DataLoader

【Pytorch学习】Dataset和DataLoader

1.训练过程
  • 创建Dataset类
  • Dataset传递给DataLoader
  • DataLoader迭代产生训练数据给模型
2.Dataset

Dataset 是一个抽象类,自定义数据集类必须继承 Dataset,且需重新改写 getitemlen 方法。

  • __getitem__ :传入指定的索引 index ,根据索引,以元组形式返回对应的单个样本及其对应的标签。
  • __len__ :返回整个数据集的大小。
2.1框架
class MyDataset(Dataset):
    def __init__(self,annotation_lines):
    		super(MyDataset, self).__init__()
        # 初始化
        
    
    def __getitem__(self, index):
        # 返回单个样本及其标签
        pass
    
    def __len__(self):
        # 返回整个数据集的大小
        pass

3.DataLoader
3.1 基本工作机制

在dataloader按照batch进行取数据,

  1. 取出大小等同于batch size的index列表,
  2. 将列表中的index输入到dataset的getitem()函数中,取出该index对应的数据,
  3. 对每个index对应的数据进行堆叠,就形成了一个batch的数据。
3.2 常用参数
class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)
  • dataset :自定义数据集类的实例。
  • batch_size :每个 batch 的大小。
  • shuffle :用来控制是否在每个 epoch 开始时打乱数据集。
  • num_workers :决定加载数据集时使用多少子进程。默认只使用主进程加载。
  • collate_fn : 如何取样本的,可以定义自己的函数来准确地实现想要的功能
  • drop_last :当 batch size 不能整除 data size 时,最后一个 batch 会变得不完整。该参数决定是否扔掉最后一个不完整的 batch。
3.2 collate_fn

通常可以将数据格式转换放在这一步定义,如Dataset__getitem__ 输出是ndarray,可以在此函数转换为Tensor

def collate(batch):
    images = []
    bboxes = []
    for img, box in batch:
        images.append(img)
        bboxes.append(box)
    images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
    bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
    return images, bboxes

参考:https://raelum.blog.csdn.net/article/details/124578203

​ https://blog.csdn.net/yuanzhoulvpi/article/details/122284733

  • 6
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值