(深度学习)构造属于你自己的Pytorch数据集

(深度学习)构造属于你自己的Pytorch数据集

1.综述

2.实现原理

3.代码细节

4.详细代码







综述

Pytorch可以说是一个非常便利的深度学习库,它甚至在torchvision.datasets中拥有许多一步到位完成数据集下载、解析、读取的类——然鹅,这样也就养成了我们懒惰依赖的心理。当我们需要用到torchvision.datasets中不曾拥有的数据集时,我们可能就会不知所措。

这篇文章中,我将以CIFAR-10数据集为例(虽然有torchvision.datasets.CIFAR10了),摆脱对torchvision.datasets的依赖,构建一个自己的数据集。

在开始之前,首先你要有CIFAR-10数据集,直接去官网上下载可能较慢(再次感谢我国著名建筑师方斌新院士 ),可以在https://pan.baidu.com/s/1bGVGeeiw001qz-PUk7q1Uw(提取码:m35y)中下载python版本的数据集。

数据集解压后目录情况如下:
在这里插入图片描述





实现原理

首先,torch.utils.data.DataLoader不仅生成迭代数据非常方便,而且它也是经过优化的,效率十分之高(肯定比我们自己写一个要高多了),因此我们最好不要舍弃。

因此,我们的目标是根据CIFAR-10数据集构造一个Dataset的子类,使之能够作为torch.utils.data.DataLoader的参数,从而使数据集能被我们用于生成迭代数据进行训练:

cifar10 = MyCIFAR10.MyCIFAR10('./data/cifar-10-batches-py', train=True)
train_loader = torch.utils.data.DataLoader(dataset=cifar10, batch_size=batch_size, shuffle=True)

要构造Dataset的子类,就必须要实现两个方法:

  • _getitem_(self, index):根据index来返回数据集中标号为index的元素及其标签。
  • _len_(self):返回数据集的长度。

因此,实质上我们主要是要通过__init__初始化之时读取数据集,再实现这两个函数便轻而易举。





代码细节

  1. _init_:

    • root是存放解压后的数据集的根目录,根据上图我这里是'./data/cifar-10-batches-py'
    • X的类型是numpy数组,Y的类型是List;由于X作为数据要送入网络中,因此最后需要将其累加值从numpy数组转为Tensor。
    def __init__(self, root, train=True, transform=None, target_transform=None):
        super(MyCIFAR10, self).__init__()
        self.transform = transform
        self.target_transform = target_transform
        self.imgs = None
        self.labels = []
    
        # 根据CIFAR-10官网上下载的数据,训练集分为5个batch文件,每个里有10000张32*32的图片;测试集只有1个batch文件,里面有10000张32*32的图片
        train_lists = ['data_batch_1',
                       'data_batch_2',
                       'data_batch_3',
                       'data_batch_4',
                       'data_batch_5']
        test_lists = ['test_batch']
    
        # 根据train是否为True来选择测试集或训练集
        if train:
            lists = train_lists
        else:
            lists = test_lists
    
        # 读取数据集,构造类中的图像集和标签
        for list in lists:
            filename = os.path.join(root, list)
            with open(filename, 'rb') as f:  # 这里需要'rb' + 'latin1'才能读取
                datadict = pickle.load(f, encoding='latin1')
                X = datadict['data'].reshape(-1, 3, 32, 32)
                Y = datadict['labels']
                if self.imgs is None:
                    self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)
                else:
                    self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)
                self.labels = self.labels + Y
        self.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor)     # 最后需要将numpy数组转为Tensor
    
  2. _getitem_:

    较为简单,直接给出:

    def __getitem__(self, index):
        img, label = self.imgs[index], self.labels[index]
    
        if self.transform is not None:
            img = self.transform(img)
    
        if self.target_transform is not None:
            label = self.target_transform(label)
    
        return img, label
    
  3. _len_:

    极其简单,直接给出:

    def __len__(self):
        return len(self.imgs)
    





详细代码

class MyCIFAR10(Dataset):
    """
    根据CIFAR-10定义的个人数据集类
    继承自Dataset类,因此能够被torch.utils.data.DataLoader使用,从而更高效地在训练和测试中迭代
    """

    def __init__(self, root, train=True, transform=None, target_transform=None):
        super(MyCIFAR10, self).__init__()
        self.transform = transform
        self.target_transform = target_transform
        self.imgs = None
        self.labels = []

        # 根据CIFAR-10官网上下载的数据,训练集分为5个batch文件,每个里有10000张32*32的图片;测试集只有1个batch文件,里面有10000张32*32的图片
        train_lists = ['data_batch_1',
                       'data_batch_2',
                       'data_batch_3',
                       'data_batch_4',
                       'data_batch_5']
        test_lists = ['test_batch']

        # 根据train是否为True来选择测试集或训练集
        if train:
            lists = train_lists
        else:
            lists = test_lists

        # 读取数据集,构造类中的图像集和标签
        for list in lists:
            filename = os.path.join(root, list)
            with open(filename, 'rb') as f:  # 这里需要'rb' + 'latin1'才能读取
                datadict = pickle.load(f, encoding='latin1')
                X = datadict['data'].reshape(-1, 3, 32, 32)
                Y = datadict['labels']
                if self.imgs is None:
                    self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)
                else:
                    self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)
                self.labels = self.labels + Y
        self.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor)     # 最后需要将numpy数组转为Tensor

    # 继承的Dataset类需要实现两个方法之一:__getitem__(self, index)
    def __getitem__(self, index):
        img, label = self.imgs[index], self.labels[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label

    # 继承的Dataset类需要实现两个方法之一:__len__(self)
    def __len__(self):
        return len(self.imgs)

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值