Pytorch学习笔记(II)——自定义数据集载入方式(二)

一、引言

  深度学习中主要分为两大任务,分类和回归。
  1、 分类即classification,就是将具有相同属性的样本划分为同一类,具有不同属性的样本划分为不同类。
  以往我们需要通过对样本打标签来划分类别,用0,1,2,3,…表示类别。而在Pytorch中只需要将同一类别的样本图片放在同一文件夹下,会自动将文件夹作为类别的区分。详细的操作与代码在之前的博客(Pytorch学习笔记(I)——预训练模型(一):加载与使用)中有介绍。即通过torchvision.datasets中已经封装好的ImageFolder载入分类任务的数据集。样例如下:

train_data=torchvision.datasets.ImageFolder('/disk2/lockonlxf/pin/trainData',transform=transforms.Compose(
                                                                        [
                                                                            transforms.Resize(256),
                                                                            transforms.CenterCrop(224),
                                                                            transforms.ToTensor()
                                                                        ]))
train_loader = DataLoader(train_data,batch_size=20,shuffle=True)

  2、 回归即regression,就是通过学习,将输入样本转变为另一种形式。那么标签就不再是0,1,2,3,…这样的分类了,而是每一个样本对应的GT(groundtruth)。以下简单举了几个例子。

任务类型输入(样本)输出(GT)
关键点检测图片所有关键点的坐标(x,y)
目标检测图片边框坐标或边框尺寸
显著性检测图片目标掩膜图片
人类重建图片图片

  因此,回归任务就不能ImageFolder载入数据集,大多情况下需要自定义数据集载入方式来满足自己的任务要求。
  半年前,我写了一篇博客Pytorch学习笔记(II)——自定义数据集载入方式(一),介绍了一种能够应用于大多数任务(多输入或多输出)的数据集载入方法。可以说,这一种方法是一种傻瓜式教学,因为要把所有文件的路径保存到txt文件中,还得读取txt,略有点麻烦。
  本文将介绍一种简易的方法,但是不能保证适用于所有的任务。

二、自定义数据集载入方式

  torch.utils.data.Dataset是一个表示数据集的抽象类,自定义数据集需要继承这个类,并且重写其以下内容:

__init__ :数据初始化
__len__ :返回数据库的大小
__getitem__ :支持使用下标的方式 如dataset[i] 来获取第i个样本

1、准备工作

  我这里还是以我自己的实验为例,一般我们做回归任务,都会有与输入样本配对的GT。然后,将对应的输入和GT放在一个文件夹里
在这里插入图片描述
  我一共有122450个训练样本,于是我就有122450个文件夹。当然如果你看完博客后有更好的存放方式,欢迎交流。
  接着,我们点开其中2个文件夹来看。可以看到,一共有4个文件,第一个是原图,第二个是经过裁剪的图片,第三个是图片的特征,第四个是由特征恢复的点云。
在这里插入图片描述
在这里插入图片描述
  这种方法对文件命名很讲究,同样为jpg文件,第二张图会有crop的前缀,同样为csv文件,后面两个用featurevertex两个前缀来区分。总的来说,属于同一类别需要用同样的关键词明明。如果文件夹里只有一个jpg图片,那么就直接用文件类型检索就好,具体可以看后续代码。
  接下来,将裁剪后的图片作为输入,特征文件作为GT。

2、  init  

  初始化不需要写太多,除了要载入的母文件夹路径之外,transform一定要加!!!transform一定要加!!!transform一定要加!!!

def __init__(self,path,transform=None):
        self.path = path
        self.transform = transform

  相比前一篇需要载入多个txt文件,这里只要载入我们的母文件夹路径,即只需一个输入。

3、  getitem  

  文件载入后,可以在这一步对文件进行处理,比如提取信息或数据转化等等

    def __getitem__(self, index):
        #image_path = os.path.join(self.face, str(index + 1), '*.jpg')
        image_path = os.path.join(self.face, str(index + 1), 'crop_*.jpg')
        image_name = glob.glob(image_path)[0]
        I_face = Image.open(image_name)
		##上面是载入图片,下面是载入csv,使用时根据个人情况修改
        feature_path = os.path.join(self.path, str(index + 1), 'feature_*.csv')
        feature_name = glob.glob(feature_path)[0]
        with open(feature_name) as feature_file:
            feat_reader = csv.reader(feature_file)  # Return an iterable reader object.
            label = []
            for element in feat_reader:
                label.append(float(element[0]))  # Each element is a list containing single string.
            mm228 = torch.tensor(label).reshape(-1, 1)

        if self.transform:
            I_face = self.transform(I_face)
        return I_face, mm228

4、  len  

这一步是计算样本的数量,其实只要计算母文件夹下有多少个文件夹就行了。

    def __len__(self):
        return len(os.listdir(self.face))

5、载入

在对应位置,写上母文件夹的绝对路径即可

train_data = MyDataset('/home/xxxxx/300w',
						transform=transforms.Compose(
							[
								transforms.Resize(256),
								transforms.CenterCrop(224),
								transforms.ToTensor() 
							]))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)

三、完整代码

from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import torch
import os
import glob
import csv
#######自定义dataset
class MyDataset(Dataset):
    def __init__(self,path,transform=None):
        self.path = path
        self.transform = transform

    def __getitem__(self, index):
        #image_path = os.path.join(self.face, str(index + 1), '*.jpg')
        image_path = os.path.join(self.face, str(index + 1), 'crop_*.jpg')
        image_name = glob.glob(image_path)[0]
        I_face = Image.open(image_name)
		##上面是载入图片,下面是载入csv,使用时根据个人情况修改
        feature_path = os.path.join(self.path, str(index + 1), 'feature_*.csv')
        feature_name = glob.glob(feature_path)[0]
        with open(feature_name) as feature_file:
            feat_reader = csv.reader(feature_file)  # Return an iterable reader object.
            label = []
            for element in feat_reader:
                label.append(float(element[0]))  # Each element is a list containing single string.
            mm228 = torch.tensor(label).reshape(-1, 1)

        if self.transform:
            I_face = self.transform(I_face)
        return I_face, mm228

    def __len__(self):
        return len(os.listdir(self.face))

train_data = MyDataset('/home/xxxxx/300w',
						transform=transforms.Compose(
							[
								transforms.Resize(256),
								transforms.CenterCrop(224),
								transforms.ToTensor() 
							]))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)
#注意看这里!!!如果自定义没有问题,下面的循环是可以跑通的,如果有问题,第一行for就会报错
#如果出错,可以将上面的shuffle设置为False,就是不打乱,然后在debug的时候看看是哪一个数据出了问题

for step, data in enumerate(train_loader):
    I_face, mm228 = data
    #还可以看看,载入的尺寸是否正确,一般是会比原来多一维,代表的是batch_size
	print(I_face.shape)
	print(mm228.shape)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值