Pytorch 快速详解如何构建自己的Dataset完成数据预处理(附详细过程)

Pytorch 快速详解如何构建自己的Dataset完成数据预处理(附详细过程)


Pytorch(五) 使用DataSet和DataLoader数据加载

在这篇文章中我已经简单的介绍了Dataset和DataLoader的简单用法,但是大多数实际情况中数据集的存储都没有那么简单,所以写了本文来记录一下如何自定义DataSet


介绍

在实际的案例当中,如图像分类等任务来说,我们需要训练的数据集往往是存储在一个文件夹中的,而数据集的存储格式都是类似的
以蚂蚁和蜜蜂图片数据集 hymenoptera_data 来举例
在这里插入图片描述
一般的数据集都会分为两个文件夹

  • train 训练集
  • val 测试集

打开训练集之后的数据存储又分为两种情况

情况1

对于图像分类来说, 肯定需要一个 label和一个img
有些数据集喜欢把它们分开成两个文件夹
在这里插入图片描述
img文件夹 中存放的是图片
在这里插入图片描述
label文件夹中存放的是标签,通常以txt文件来存储,文件名和图片名相同,而文件的内容代表了图片的标签
在这里插入图片描述

情况2

对于一些简单的数据集来说,可能不会把labelimg分开存放
比如情况1中提到的蚂蚁蜜蜂数据集
ants目录下的全是蚂蚁的图片
bees文件夹下全是蜜蜂的图片
这里的文件夹名就代表了图片的label
不过常用的情况 是把图片的label包含在了图片的命名当中
如下图
在这里插入图片描述

自定义Dataset

Dataset的主要作用就是提供一种方式来获取数据和其label
自定义的Dataset需要满足如下两个功能

  • 如何获取每一个数据和其label
  • 告诉我们数据集一共有多少个数据

导入库

from PIL import Image
from torch.utils.data import Dataset
import os

如果没有下载相应的库就用 pip下载一下

TensorDataset回顾

在我之前的文章中提到,对于简单的数据可以用TensorDataset来包装
而通过for循环也可以遍历取出Dataset的中的数据和 label
它的原理就是内置了一个方法可以通过 index 来获取到相应的数据
下面就写个小案例来回顾一下TensorDataset的使用

构建数据x和标签y

x = torch.tensor(
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3],
     [4, 5, 6], [7, 8, 9]])
y = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])`

在这里插入图片描述

from torch.utils.data import TensorDataset
train_dataset = TensorDataset(x, y)
print(train_dataset[0])

在这里插入图片描述
显然,如上图,train_dataset中数据的存放格式也是一个数据加一个标签的元组形式,并且可以通过 index来获取
当然在实际问题中也可以之间用for循环来遍历

Dataset自定义实现

那么对于我们的图像数据来说,要想达到遍历取数据和label的效果,需要我们自定义Dataset
在这里我们需要达到的目标是 通过 index 可以获取到一个 包含图像数据和标签的元组
并且要知道数据集的中长度
因此Dataset的初步模板就出来了, 如下

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        pass

    def __len__(self):
        pass

其中 getitem的作用是通过index来返回一个包含数据和标签的元组
len的作用是返回数据集的长度
init很明显是类构造器
下面就来一步一步的实现它

获取图片并显示

PILImage可以帮助我们解析并查看一张图片
本案例的猫狗数据集我放在了D:\Source\Datasets\cat_and_dog
简单的读取一个图片来看看

from PIL import Image
img_path = r'D:\Source\Datasets\cat_and_dog\train\cat.10.jpg'
img = Image.open(img_path)
img.show()

在这里插入图片描述
通过 Image.open(img_path)方法读取到的是一个PIL.JpegImagePlugin.JpegImageFile对象,它包含了很多东西,我们就把它当做图像的数据

在这里插入图片描述
显然,读取图像需要图像的完整路径,那么思路很明显,我们可以把图像的路径存成一个列表,然后就可以通过index来获取列表中的值,(这里取出的是图像的完整路径),然后再通过Image把它解析成数据,取出其标签,返回

完成getitem方法

这一步其实不难,重点在理解其思路
下图就达到了返回img图像列表的操作,最后只需要把图像和train_path拼接起来就可以获取到图像的完整路径
在这里插入图片描述

完整代码

代码如下

# -*- coding: utf-8 -*-
# @Time    : 2021/1/31 11:01
# @Author  : Tong Tianyu
# @File    : demo.py
from PIL import Image
from torch.utils.data import Dataset
import os


class MyDataset(Dataset):
    def __init__(self, data_path):
        self.data_path = data_path
        self.img_list = os.listdir(self.data_path)

    def __getitem__(self, index):
        img_title = self.img_list[index]
        img_label = img_title.split('.')[0]
        img_path = os.path.join(self.data_path, img_title)
        img = Image.open(img_path)
        return img, img_label

    def __len__(self):
        return len(self.img_list)


train_path = r'D:\Source\Datasets\cat_and_dog\train'
train_dataset = MyDataset(train_path)

效果如下, 完成目标
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Joker-Tong

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值