PyTorch系列 | 自定义dataset(torch.utils.data.Dataset与torchvision.datasets.ImageFolder)

0 引言

Pytorch使用方法一自定义 dataset 时,需要重写 __len____getitem__

  • __len__ 提供 dataset 的大小
  • __getitem__ 提供 dataset 的索引
  • 方法二则不需要重写,直接使用即可

在 Python 对象中,需要重写的双下划线开头和结尾的属性称为特殊属性,常见的有对象的名称:__name__

另外对象的方法也属于属性,因此以双下划线开头和结尾的方法称为特殊方法,例如上述需要重写的 __len____getitem__ 便是两个特殊方法

常见的特殊属性和特殊方法,可见我的另一篇博客:Python系列 | 常见的特殊属性与特殊方法

1 源数据介绍

示例数据集来源于:LFW人脸数据集

  • 累计 13233 张图片
  • 人脸图像采集自 5749 人
  • 其中 1680 人有两张及以上的图片

LFW人脸数据集以人名作为文件名,文件夹下为相对应的人脸图像:

在这里插入图片描述

图片均以 “人名_000x” 的形式命名,以Abdullah为例:

在这里插入图片描述

2 代码实现

若读者对 os 模块不太熟悉,可参考我另一篇博客:Python系列 | os模块常用命令

由于数据集的特殊性,每张图片都处于二级文件下,因此在正式定义 dataset 之前,有必要对数据集进行一定处理,将所有图像整合至一个文件夹中,具体代码如下:

import os
import shutil


def make_file(path):
    if os.path.exists(path):
        os.rmdir(path)
        os.mkdir(path)
    else:
        os.mkdir(path)


def main():
    root = os.path.join(os.getcwd(), 'lfw')
    image_file = os.listdir(root)
    image_set = list()
    for file in image_file:
        image_path = os.path.join(root, file)
        image_list = os.listdir(image_path)
        for image in image_list:
            image_set.append(os.path.join(image_path, image))

    new_path = os.path.join(os.getcwd(), 'lfw_dataset')
    make_file(new_path)
    for path in image_set:
        shutil.copy(path, new_path)
        
    print('Done !')


if __name__ == '__main__':
    main()

实现结果:

在这里插入图片描述

将所有图像汇总后,即可使用方法一或方法二定义 dataset 。

2.1 方法一

使用 torch.utils.data.Dataset

import torchvision.transforms as transforms
from torch.utils.data import Dataset
import os
from PIL import Image
import torch
import matplotlib.pyplot as plt


class LfwDataset(Dataset):  # 继承Dataset,复写__getitem__和__len__
    def __init__(self, root, transform):
        self.root = root
        self.images = [os.path.join(self.root, path) for path in os.listdir(self.root)]  # 图像路径集合
        self.transform = transform  # transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        image_path = self.images[item]  # 图像索引,获取单张图像路径
        image = Image.open(image_path)
        _, image_name = os.path.split(image_path)
        label, _ = image_name.split('.')
        label = label[:-5]

        if self.transform is not None:
            image = self.transform(image)

        return image, label


if __name__ == '__main__':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(0.5)
    ])
    lfw_dataset = LfwDataset(root=r'.\lfw_dataset', transform=transform)
    data_loader = torch.utils.data.DataLoader(lfw_dataset, batch_size=12, shuffle=True)

    for i, (inputs, target) in enumerate(data_loader):
        if i == 0:  # 输出一部分看看
            print(inputs.shape)
            print(target)
            plt.figure(figsize=(12, 16))
            for num in range(12):  # 确认图像是否可以正确读取
                plt.subplot(3, 4, num + 1)
                plt.imshow(inputs[num].permute([1, 2, 0]))
                plt.title(target[num], size=13)
                plt.axis('off')
            plt.tight_layout()
            plt.show()
        else:
            break

打印结果:

torch.Size([12, 3, 250, 250])
('Christopher_Conyers', 'Michael_Jackson', 'Edward_Johnson', 'Heizo_Takenaka', 'Ai_Sugiyama', 'Lawrence_MacAulay', 'Geno_Auriemma', 'Bustam_A_Zedan_Aljanabi', 'Colin_Powell', 'Hugh_Grant', 'Ellen_Martin', 'Billy_Sollie')

图像输出结果:

在这里插入图片描述
可见,已自定义完成 dataset 。

2.2 方法二

使用 torchvision.datasets.ImageFolder

方法二较方法一要更为方便,但 torchvision.datasets.ImageFolder要求图片文件以下图格式进行排列:

在这里插入图片描述

也就是说,每个类别的图像要各自为一个文件夹,这也正好符合本示例 LFW 人脸数据集的特点。

这里还有几个注意点:

  • 所定义的 dataset 数据集的类别标签储存于 dataset.classes
  • 使用 torch.utils.data.DataLoader 加载 dataset 时,其类别标签返回的是相应类别的索引,而非类别标签本身
  • 在训练模型时,直接使用类别标签的索引作为 target ,若有需要,可在训练结束后进行索引和类别的转换即可
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
import torch


def main():
    root = os.path.join(os.getcwd(), 'lfw')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(0.5)
    ])

    lfw_dataset = torchvision.datasets.ImageFolder(root=root, transform=transform)
    lfw_dataloader = DataLoader(lfw_dataset, batch_size=12, shuffle=True)

    for i, (inputs, target_index_set) in enumerate(lfw_dataloader):
        if i == 0:
            print(f'inputs.shape : {inputs.shape}')
            print(f'target_index_set: {target_index_set}')
            plt.figure(figsize=(12, 16))
            for num in range(12):
                plt.subplot(3, 4, num + 1)
                plt.imshow(inputs[num].permute([1, 2, 0]))
                plt.title(lfw_dataset.classes[target_index_set[num]], size=13)
                plt.axis('off')
            plt.tight_layout()
            plt.show()
        else:
            break


if __name__ == '__main__':
    main()

打印结果:

inputs.shape : torch.Size([12, 3, 250, 250])
target_index_set: tensor([  59, 3995, 3620, 3092, 1155, 4900, 5564, 5639,  809, 5685, 1995, 4257])

图片输出结果:

在这里插入图片描述

可见,方法二的自定义 dataset 要方便很多,只是对数据的存储方式有一定要求。

  • 7
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值