【Pytorch】17.使用自定义类继承Dataset创建数据集并通过代码将完整的数据集分割为测试集与训练集

源码

第一种情况MNIST_Training_MyDataset
第二种情况MNIST_Training_With_No_Split

简介

本文主要探讨获取一个数据集的两种情况,以手写数据集为例

  • 以文件格式划分好了训练集与测试集
    在这里插入图片描述
  • 文件没有划分测试集与训练集,需要通过代码进行划分
    在这里插入图片描述

1.文件格式已经划分好了训练集与数据集

这种数据集我们主要就是要通过自定义的类将文件格式的数据转化为可以进行训练的数据集,主要通过以下几步

  • 创建自定义的类,并继承Dataset
  • 重写Dataset的三个方法
    • __init__,用于根据数据集地址与数据转化类型来进行对数据初始化
    • __len__,用与获取数据集的长度
    • __getitem__,用于根据下标来获取数据集中的对应元素,并且返回图片与标签的二元组

基本代码结构为

class MyDataset(Dataset):
    def __init__(self, root_path, transform=None):
       pass

    def __len__(self):
        pass
        
    def __getitem__(self, index):
        pass


__init__定义

__init__主要实现一件事
将给出的数据集地址转化并保存为一个数据集列表
我们的文件结构为,下面的代码在必要处都给出了注释,读者可以自行阅读
在这里插入图片描述

    def __init__(self, root_path, train, transform=None):
        self.root_path = root_path
        # 判断变化规则
        self.transform = transform

        # 判断是否是训练集
        if train:
            self.data_path = os.path.join(self.root_path, 'training')
        else:
            self.data_path = os.path.join(self.root_path, 'testing')

        self.img_paths = []
        self.labels = []

        # 遍历每个子文件夹(标签)
        for label_dir in os.listdir(self.data_path):
            label_path = os.path.join(self.data_path, label_dir)
            if os.path.isdir(label_path):  # 只处理目录
                # 遍历子文件夹中的所有图像文件
                for img_name in os.listdir(label_path):
                    single_img_path = os.path.join(label_path, img_name)
                    # 将单个图片路径添加到img_paths中
                    self.img_paths.append(single_img_path)
                    # 将图片对应的标签添加到labels中
                    self.labels.append(int(label_dir))

__len__定义

有了__init__的定义,我们只需要返回len(img_paths)即可

    def __len__(self):
        return len(self.img_paths)

__getitem__定义

__getitem__主要实现两个功能

  • 从数据集中获取对应下标的图片,并且转化为给出的transform格式
  • 获取数据集中对应下标的标签
    def __getitem__(self, index):
        img_path = self.img_paths[index]
        img_PIL = Image.open(img_path).convert('L')
        # 如果变化规则不为空
        if self.transform is not None:
            img_tensor = self.transform(img_PIL)
        else:
            img_tensor = img_PIL
        # 确定对应下标的标签
        label = self.labels[index]
        return img_tensor, label

使用

from My_Dataset import *

root_dir = '../datasets/mnist_png'
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

my_train_dataset = MyDataset(root_dir, train=True, transform=transform)
my_test_dataset = MyDataset(root_dir, train=False, transform=transform)

2.文件没有划分测试集与训练集,需要通过代码进行划分

具体流程与第一种情况类似,区别是

  • 额外定义了一个spilt_dataset划分数据集
  • __init__中对完整数据集进行划分,划分为训练集与数据集
  • 额外定义类Subset用于对划分结束的训练集与数据集规范化

__init__定义

具体的数据集路径为
在这里插入图片描述

    def __init__(self, root_path, transform=None):
        # 记录成员变量
        self.root_path = root_path
        self.transform = transform

        # 将图片与标签列为list
        self.imgs_path = []
        self.labels = []

        # 获取root_path下的所有图片
        for label_path in os.listdir(root_path):
            img_path = os.path.join(root_path, label_path)
            if os.path.isdir(img_path):
                for img_name in os.listdir(img_path):
                    pre_img_path = os.path.join(img_path, img_name)
                    self.imgs_path.append(pre_img_path)
                    self.labels.append(label_path)
        # 获取随机的random参数
        random.seed(random.seed)
        # 创建完整的数据集,内容为将图片和labels一一对应并且列为list
        data = list(zip(self.imgs_path, self.labels))
        # print(data)
        # 将数据集打乱
        random.shuffle(data)

        # 设置训练集数据集划分比例
        spilt_size = 0.8
        split = int(len(data) * spilt_size)
        # 根据比例划分训练集与数据集
        self.train_data = data[:split]
        self.test_data = data[split:]

相较于第一种方法,新增了许多代码,因为第一种方法从文件名就能知道训练集与数据集,而我们这种只能通过将完整的数据集打乱并且按比例取出训练集与数据集

__len__定义

与第一种方法没有区别

    def __len__(self):
        return len(self.imgs_path)

__getitem__定义

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        label = self.labels[index]
        # 根据图片地址获取图片信息,并且转化为灰度图像
        img = Image.open(img_path).convert('L')
        if self.transform is not None:
            img = self.transform()
        return img, label

也无区别

spilt_dataset定义

这个方法主要是将划分完成的训练集与测试集进行返回,需要调用Subset

    def spilt_dataset(self):
        train_data = Subset(self, self.train_data, transform=self.transform)
        test_data = Subset(self, self.test_data, transform=self.transform)
        return train_data, test_data

这里不直接返回self.train_dataself.test_data是因为规范问题,自我感觉这样返回的两个数据集因为没有getitem方法,会导致在访问的时候出问题,但是自己也没有尝试,最好还是用这种格式吧

Subset类定义

class Subset(Dataset):
    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        img_path = self.indices[index][0]
        label = self.indices[index][1]
        img = Image.open(img_path).convert('L')
        if self.transform is not None:
            img = self.transform(img)
        # 需注意
        label = int(label)
        label = torch.tensor(label)
        return img, label

比较简单,唯一需要注意的是
我在用这种该方法获取数据集的标签label时,会出现获取的不是tensor数据类型而是tuple数据类型,就导致会出一些问题,网上搜到的解决方法是,先将标签转化为int类型,然后在用torch.tensor进行类型转化,这样训练就可以不出错了

使用方法

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),

])
root_dir = '../datasets/mnist_png_with_no_spilt'

my_dataset = MyDataset(root_dir, transform=transform)
train_dataset, test_dataset = my_dataset.spilt_dataset()
  • 10
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值