PyTorch学习(一):Dataset && DataLoader

Dataset && DataLoader

前言:为了易读以及适应python的模块化编程,PyTorch提供了两个加载数据的原型,分别为:torch.utils.data.Dataset以及torch.utils.data.DataLoader,其中Dataset存储了数据集的样本已经相应的标签,DataLoader将其进一步进行包装成为一个迭代器使得我们可以更容易的从中获取训练样本

PyTorch库中已经封装大量的数据集(例如FashionMNIST),可以从下面两个链接中找到使用方法,这里仅对你自己定制的数据集进行解释说明。链接一链接二

Dataset

自己定制的Dataset继承于torch.utils.data.Dataset,需要实现三个函数:__init____len____getitem__

  • init:此函数在实例化Dataset类时被调用,用来初始化:包含图片的文件夹,标注文件以及transforms
  • len:返回样本的总数量
  • getitem:从给定索引index中索引一个样本,并转化为张量,通过transform对其进行处理, 最后返回张量以及相应的标签等

具体的流程是,首先我们在一个文件夹A下有训练集,然后我们将A中的图片在init初始化,之后len就是训练的总数,getitem是在训练时对其进行索引,在索引是会进行一系列的操作,例如旋转等

给出Dataset的代码示例

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])

class CustomDataset(Dataset):
    def __init__(self, image_dir):
        super(CustomDataset, self).__init__()
        self.image_dir = image_dir

        inp = sorted(os.listdir(os.path.join(self.image_dir)))
        self.inp_filenames = [os.path.join(self.image_dir, x) for x in inp if is_image_file(x)]

    def __len__(self):
        return len(self.inp_filenames)

    def __getitem__(self, index):
        input_path = self.inp_filenames[index]
        input = Image.open(input_path)
        inp_img = TF.to_tensor(input)
        return inp_img

DataLoader

只有Dataset是不够的,因为在训练时我们常常需要将数据集划分为Batch进行训练,并且为了防止过拟合,我们希望数据集在每个epoch都是被打乱的

使用方法也比较简单,只需要将之前得到的Dataset包装进去即可

train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
for _, img in enumerate(DataLoader):
    input = img
    ...
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值