一、引言
深度学习中主要分为两大任务,分类和回归。
1、 分类即classification,就是将具有相同属性的样本划分为同一类,具有不同属性的样本划分为不同类。
以往我们需要通过对样本打标签来划分类别,用0,1,2,3,…表示类别。而在Pytorch中只需要将同一类别的样本图片放在同一文件夹下,会自动将文件夹作为类别的区分。详细的操作与代码在之前的博客(Pytorch学习笔记(I)——预训练模型(一):加载与使用)中有介绍。即通过torchvision.datasets
中已经封装好的ImageFolder
载入分类任务的数据集。样例如下:
train_data=torchvision.datasets.ImageFolder('/disk2/lockonlxf/pin/trainData',transform=transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()
]))
train_loader = DataLoader(train_data,batch_size=20,shuffle=True)
2、 回归即regression,就是通过学习,将输入样本转变为另一种形式。那么标签就不再是0,1,2,3,…这样的分类了,而是每一个样本对应的GT(groundtruth)。以下简单举了几个例子。
任务类型 | 输入(样本) | 输出(GT) |
---|---|---|
关键点检测 | 图片 | 所有关键点的坐标(x,y) |
目标检测 | 图片 | 边框坐标或边框尺寸 |
显著性检测 | 图片 | 目标掩膜图片 |
人类重建 | 图片 | 图片 |
因此,回归任务就不能
用ImageFolder
载入数据集,大多情况下需要自定义数据集载入方式来满足自己的任务要求。
接下来,我将介绍一种能够应用于大多数任务(多输入或多输出)的数据集载入方法。当然还有一种简易但不一定范用的方法Pytorch学习笔记(II)——自定义数据集载入方式(二)可以根据自身情况选择。
二、自定义数据集载入方式
torch.utils.data.Dataset
是一个表示数据集的抽象类,自定义数据集需要继承这个类,并且重写其以下内容:
__init__ :数据初始化
__len__ :返回数据库的大小
__getitem__ :支持使用下标的方式 如dataset[i] 来获取第i个样本
1、准备工作
在重写之前,需要准备好对应的txt文件,为保证不出错,在txt中写上样本或GT的绝对路径,
如果是分类任务,还需要写上类别号。
总之,就是要将信息尽可能的写入txt中,记得用空格区分样本的多个信息,用换行区分样本。
注:我准备了3个txt文件,1)分类图片(含:绝对路径、类别号)2)回归图片(含:绝对路径)3)回归GT(含:绝对路径)
接下来,我将一次性介绍分类和回归的载入方式,读者可根据自己的任务需求增减代码。
2、 init
初始化不需要写太多,除了要载入的txt之外,transform一定要加!!!transform一定要加!!!transform一定要加!!!
def __init__(self,face,ex,mesh,transform=None):
self.face = face
self.ex = ex
self.mesh = mesh
self.transform = transform
3、 getitem
这一步非常关键,是按照顺序读取txt中对应样本的信息。文件载入后,可以在这一步对文件进行处理,比如提取信息或数据转化等等
def __getitem__(self, index):
faceline = linecache.getline(self.face,index+1) #+1必须写,index的范围是0到num-1,而txt不存在第0行。
faceline.strip('\n')
face_list = faceline.split() #通过空格对每一行的信息进行分割,所以[0]是索引样本的绝对路径,[1]是类别号
face_name = face_list[0]
I_face = Image.open(face_name)
label = face_list[1] #因为是从txt中读取类别号,所以类别号是字符型即“str”
label = torch.tensor(int(label)) #因此需要先将字符型转成整型,再转成tensor
exline = linecache.getline(self.ex, index+1)
exline.strip('\n')
ex_list = exline.split()
ex_name = ex_list[0]
I_ex = Image.open(ex_name)
meshline = linecache.getline(self.mesh, index+1)
meshline.strip('\n')
mesh_list = meshline.split()
mesh_name = mesh_list[0]
meshmatrix = np.load(mesh_name)
meshmatrix = torch.FloatTensor(meshmatrix)
# rgb转灰度,需求看个人
# I_face = I_face.convert("L")
# I_ex = I_ex.convert("L")
#重点!!!
#图片载入后,实际是PIL的形式,因此必须通过自带的transorm将图片转成tensor
if self.transform:
I_face = self.transform(I_face)
I_ex = self.transform(I_ex)
#最后返回读取到的数据,记住返回一定要是tensor的形式
return I_face, I_ex, meshmatrix, label
4、 len
这一步是计算样本的数量,其实只要计算txt有多少行就行了。一帮样本和GT的数量是一直的,所以读一个txt就可以了。
def __len__(self):
fh = open(self.face, 'r')
num = len(fh.readlines())
fh.close()
return num
5、载入
在对应位置,写上txt文件的绝对路径即可
train_data = MyDataset('/disk2/lockonlxf/experiments/reconstruction/face.txt',
'/disk2/lockonlxf/experiments/reconstruction/ex.txt',
'/disk2/lockonlxf/experiments/reconstruction/mesh.txt',transform=transforms.Compose(
[
# transforms.Resize(256),
# transforms.CenterCrop(224),
transforms.ToTensor() #这一个必须要,其他根据任务自选。
#更多transform指令可查看pytorch官网查询
]))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)
三、完整代码
import torchvision
import torch
from torchvision import transforms,utils
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import linecache
#######自定义dataset
class MyDataset(Dataset):
def __init__(self,face,ex,mesh,transform=None):
self.face = face
self.ex = ex
self.mesh = mesh
self.transform = transform
def __getitem__(self, index):
faceline = linecache.getline(self.face,index+1)
faceline.strip('\n')
face_list = faceline.split()
face_name = face_list[0]
I_face = Image.open(face_name)
label = face_list[1]
label = torch.tensor(int(label))
exline = linecache.getline(self.ex, index+1)
exline.strip('\n')
ex_list = exline.split()
ex_name = ex_list[0]
I_ex = Image.open(ex_name)
meshline = linecache.getline(self.mesh, index+1)
meshline.strip('\n')
mesh_list = meshline.split()
mesh_name = mesh_list[0]
meshmatrix = np.load(mesh_name)
meshmatrix = torch.tensor(meshmatrix)
# I_face = I_face.convert("L")
# I_ex = I_ex.convert("L")
if self.transform:
I_face = self.transform(I_face)
I_ex = self.transform(I_ex)
return I_face, I_ex, meshmatrix, label
def __len__(self):
fh = open(self.face, 'r')
num = len(fh.readlines())
fh.close()
return num
train_data = MyDataset('/disk2/lockonlxf/experiments/reconstruction/face.txt',
'/disk2/lockonlxf/experiments/reconstruction/ex.txt',
'/disk2/lockonlxf/experiments/reconstruction/mesh.txt',transform=transforms.Compose(
[
# transforms.Resize(256),
# transforms.CenterCrop(224),
transforms.ToTensor()
]))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)
#注意看这里!!!如果自定义没有问题,下面的循环是可以跑通的,如果有问题,第一行for就会报错
#如果出错,可以将上面的shuffle设置为False,就是不打乱,然后在debug的时候看看是哪一个数据出了问题
for step, data in enumerate(train_loader):
I_face, I_ex, meshmatrix, label = data
#还可以看看,载入的尺寸是否正确,一般是会比原来多一维,代表的是batch_size
print(I_face.shape)
print(I_ex.shape)
print(meshmatrix.shape)
print(label.shape)