Pytorch学习笔记(II)——自定义数据集载入方式(一)

一、引言

  深度学习中主要分为两大任务,分类和回归。
  1、 分类即classification,就是将具有相同属性的样本划分为同一类,具有不同属性的样本划分为不同类。
  以往我们需要通过对样本打标签来划分类别,用0,1,2,3,…表示类别。而在Pytorch中只需要将同一类别的样本图片放在同一文件夹下,会自动将文件夹作为类别的区分。详细的操作与代码在之前的博客(Pytorch学习笔记(I)——预训练模型(一):加载与使用)中有介绍。即通过torchvision.datasets中已经封装好的ImageFolder载入分类任务的数据集。样例如下:

train_data=torchvision.datasets.ImageFolder('/disk2/lockonlxf/pin/trainData',transform=transforms.Compose(
                                                                        [
                                                                            transforms.Resize(256),
                                                                            transforms.CenterCrop(224),
                                                                            transforms.ToTensor()
                                                                        ]))
train_loader = DataLoader(train_data,batch_size=20,shuffle=True)

  2、 回归即regression,就是通过学习,将输入样本转变为另一种形式。那么标签就不再是0,1,2,3,…这样的分类了,而是每一个样本对应的GT(groundtruth)。以下简单举了几个例子。

任务类型输入(样本)输出(GT)
关键点检测图片所有关键点的坐标(x,y)
目标检测图片边框坐标或边框尺寸
显著性检测图片目标掩膜图片
人类重建图片图片

  因此,回归任务就不能ImageFolder载入数据集,大多情况下需要自定义数据集载入方式来满足自己的任务要求。
  接下来,我将介绍一种能够应用于大多数任务(多输入或多输出)的数据集载入方法。当然还有一种简易但不一定范用的方法Pytorch学习笔记(II)——自定义数据集载入方式(二)可以根据自身情况选择。

二、自定义数据集载入方式

  torch.utils.data.Dataset是一个表示数据集的抽象类,自定义数据集需要继承这个类,并且重写其以下内容:

__init__ :数据初始化
__len__ :返回数据库的大小
__getitem__ :支持使用下标的方式 如dataset[i] 来获取第i个样本

1、准备工作

在重写之前,需要准备好对应的txt文件,为保证不出错,在txt中写上样本或GT的绝对路径,
在这里插入图片描述
如果是分类任务,还需要写上类别号。
在这里插入图片描述
总之,就是要将信息尽可能的写入txt中,记得用空格区分样本的多个信息,用换行区分样本。

注:我准备了3个txt文件,1)分类图片(含:绝对路径、类别号)2)回归图片(含:绝对路径)3)回归GT(含:绝对路径)
   接下来,我将一次性介绍分类和回归的载入方式,读者可根据自己的任务需求增减代码。

2、  init  

初始化不需要写太多,除了要载入的txt之外,transform一定要加!!!transform一定要加!!!transform一定要加!!!

def __init__(self,face,ex,mesh,transform=None):
        self.face = face
        self.ex = ex
        self.mesh = mesh
        self.transform = transform

3、  getitem  

这一步非常关键,是按照顺序读取txt中对应样本的信息。文件载入后,可以在这一步对文件进行处理,比如提取信息或数据转化等等

def __getitem__(self, index):
        faceline = linecache.getline(self.face,index+1)    #+1必须写,index的范围是0到num-1,而txt不存在第0行。
        faceline.strip('\n')
        face_list = faceline.split()                #通过空格对每一行的信息进行分割,所以[0]是索引样本的绝对路径,[1]是类别号
        face_name = face_list[0]                               
        I_face = Image.open(face_name)

        label = face_list[1]					#因为是从txt中读取类别号,所以类别号是字符型即“str”	
        label = torch.tensor(int(label))         #因此需要先将字符型转成整型,再转成tensor

        exline = linecache.getline(self.ex, index+1)
        exline.strip('\n')
        ex_list = exline.split()
        ex_name = ex_list[0]
        I_ex = Image.open(ex_name)

        meshline = linecache.getline(self.mesh, index+1)
        meshline.strip('\n')
        mesh_list = meshline.split()
        mesh_name = mesh_list[0]
        meshmatrix = np.load(mesh_name)
        meshmatrix = torch.FloatTensor(meshmatrix)
		
		# rgb转灰度,需求看个人
        # I_face = I_face.convert("L")
        # I_ex = I_ex.convert("L")
		
		#重点!!!
		#图片载入后,实际是PIL的形式,因此必须通过自带的transorm将图片转成tensor
        if self.transform:
            I_face = self.transform(I_face)
            I_ex = self.transform(I_ex)
		
		#最后返回读取到的数据,记住返回一定要是tensor的形式
        return I_face, I_ex, meshmatrix, label

4、  len  

这一步是计算样本的数量,其实只要计算txt有多少行就行了。一帮样本和GT的数量是一直的,所以读一个txt就可以了。

def __len__(self):
        fh = open(self.face, 'r')
        num = len(fh.readlines())
        fh.close()
        return num

5、载入

在对应位置,写上txt文件的绝对路径即可

train_data = MyDataset('/disk2/lockonlxf/experiments/reconstruction/face.txt',
                       '/disk2/lockonlxf/experiments/reconstruction/ex.txt',
                       '/disk2/lockonlxf/experiments/reconstruction/mesh.txt',transform=transforms.Compose(
                                                                        [
                                                                            # transforms.Resize(256),
                                                                            # transforms.CenterCrop(224),
                                                                            transforms.ToTensor()   #这一个必须要,其他根据任务自选。
                                                                            #更多transform指令可查看pytorch官网查询
                                                                        ]))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)

三、完整代码

import torchvision
import torch
from torchvision import transforms,utils
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import linecache

#######自定义dataset
class MyDataset(Dataset):
    def __init__(self,face,ex,mesh,transform=None):
        self.face = face
        self.ex = ex
        self.mesh = mesh
        self.transform = transform

    def __getitem__(self, index):
        faceline = linecache.getline(self.face,index+1)
        faceline.strip('\n')
        face_list = faceline.split()
        face_name = face_list[0]
        I_face = Image.open(face_name)

        label = face_list[1]
        label = torch.tensor(int(label))

        exline = linecache.getline(self.ex, index+1)
        exline.strip('\n')
        ex_list = exline.split()
        ex_name = ex_list[0]
        I_ex = Image.open(ex_name)

        meshline = linecache.getline(self.mesh, index+1)
        meshline.strip('\n')
        mesh_list = meshline.split()
        mesh_name = mesh_list[0]
        meshmatrix = np.load(mesh_name)
        meshmatrix = torch.tensor(meshmatrix)

        # I_face = I_face.convert("L")
        # I_ex = I_ex.convert("L")
        if self.transform:
            I_face = self.transform(I_face)
            I_ex = self.transform(I_ex)

        return I_face, I_ex, meshmatrix, label

    def __len__(self):
        fh = open(self.face, 'r')
        num = len(fh.readlines())
        fh.close()
        return num

train_data = MyDataset('/disk2/lockonlxf/experiments/reconstruction/face.txt',
                       '/disk2/lockonlxf/experiments/reconstruction/ex.txt',
                       '/disk2/lockonlxf/experiments/reconstruction/mesh.txt',transform=transforms.Compose(
                                                                        [
                                                                            # transforms.Resize(256),
                                                                            # transforms.CenterCrop(224),
                                                                            transforms.ToTensor()
                                                                        ]))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)

#注意看这里!!!如果自定义没有问题,下面的循环是可以跑通的,如果有问题,第一行for就会报错
#如果出错,可以将上面的shuffle设置为False,就是不打乱,然后在debug的时候看看是哪一个数据出了问题

for step, data in enumerate(train_loader):
    I_face, I_ex, meshmatrix, label = data
    #还可以看看,载入的尺寸是否正确,一般是会比原来多一维,代表的是batch_size
	print(I_face.shape)
	print(I_ex.shape)
	print(meshmatrix.shape)
	print(label.shape)

四、现有任务的载入方法

1、孪生网络(人脸识别/匹配)

Pytorch 入门之Siamese网络

2、人脸关键点检测

Pytorch自定义数据库

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值