构建数据集大概步骤为:用各种方法(cv2,PIL,skimage等等)读取成对图片→通过某方法,返回类似(输入,标签)形式→转为tensor形式→传入Dataloader形成映射。
现在我需要一个成对图片的数据集,即输入与标签都是图片,且输入与标签在命名上完全相同。Pytorch中则是使用TORCH.UTILS.DATA下的Dataloader方法来构建数据集,即:
myDataset = torch.utils.data.DataLoader(dataset)
此处dataset支持的数据集形式之一便是映射(Map-style datasets)形式,这也正是我的数据集所需要的样子。而为了构建这样形式的数据集,需要对torch.utils.data中的Dataset类进行继承,并覆写__getitem__()方法,该方法的作用是对于给定的键值(key)返回对应的数据。当然,也可以根据需要覆写__len__()方法,用于得到数据集的长度(数量)。
那么接下来的事情就很明确了,构造一个继承torch.utils.data.Data的类,其返回值的形式为(input,label):
import torch
import torchvision
from torch.utils.data import Dataset
import os
from PIL import Image
from matplotlib import pyplot as plt
#对读取的图片采取的处理方法,详情自行搜索transforms的用法
transforms_imag=torchvision.transforms.Compose([torchvision.transforms.Resize([64,64]),
torchvision.transforms.ToTensor()])
#输入与标签图片所在的目录
input_root='./pic/input/'
label_root='./pic/label/'
class MyDataset(Dataset):#继承了Dataset子类
def __init__(self,input_root,label_root,transform=None):
#分别读取输入/标签图片的路径信息
self.input_root=input_root
self.input_files=os.listdir(input_root)#列出指定路径下的所有文件
self.label_root=label_root
self.label_files=os.listdir(label_root)
self.transforms=transform
def __len__(self):
#获取数据集大小
return len(self.input_files)
def __getitem__(self, index):
#根据索引(id)读取对应的图片
input_img_path=os.path.join(self.input_root,self.input_files[index])
input_img=Image.open(input_img_path)
#视频教程使用skimage来读取的图片,但我在之后使用transforms处理图片时会报错
#所以在此我修改为使用PIL形式读取的图片
label_img_path=os.path.join(self.label_root,self.label_files[index])
label_img=Image.open(label_img_path)
if self.transforms:
#transforms方法如果有就先处理,然后再返回最后结果
input_img=self.transforms(input_img)
label_img=self.transforms(label_img)
return (input_img,label_img)#返回成对的数据
###把以下代码放在return前,反注释后运行
# # test only for PIL#
# input_img.show()
# label_img.show()
# # test only for PIL#
接下来便是运行了:
dataset_train=MyDataset(input_root, label_root, transform=transforms_imag)
trainloader=torch.utils.data.DataLoader(dataset_train)
至此,一个自己的数据集便构建完成了,根据我在transforms里的处理,该数据集的图片皆被转换为了pytorch可以处理的tensor格式,并把图片尺寸修改为64x64大小。
可以使用以下方法读取数据集中的数据:
for b_index, (data,label) in enumerate(trainloader):
x = data
y = label
以上就是构建一个自己的数据集的过程了,这个方法也比较通用,应该会有不错的泛用性。
关于ImageFolder
有这么个读取图片的方法,torchvision.datasets.
ImageFolder,因为看起来很美好,所以一直在纠结怎么用它来一步到位,但喂入Dataloader后根本不是我想要的结果。该方法大概的意思是会根据你的文件夹来为图片自动添加标签。
最后倒推一下,整体思路是这样的:
1.Pytorch中使用卷积等运算时,比如torch.nn.Conv2d,它要求的输入格式为(N,C,H,W),其中N代表batch size。
2.输入格式是一个四维的量,而平时读取图片的方法能够获得的只有C,H,W这三个量,所以得想办法在原有的(C,H,W)上再多加一维。
3.此时Pytorch的torch.utils.data.DataLoader就提供了该方法,透过对其中参数batch_size的设置(默认为1),这样图片的格式就可以转换为带有N这一维了。
4.为了使用Dataloader(dataset),要将输入的dataset改成符合标准的格式。接着便是根据对dataset的注解来自行写相对应的类方法。
5.同时也别忘记pytorch计算时是使用的tensor(张量)格式的数据,所以在读取图片后要记得转换格式。上面代码中transforms的方法就包含了这一步。
参考:
2.文档对torch.utils.data.Dataloader的解释
3.文档对torch.utils.data.Dataset的解释
4.文档对torchvision.datasets.ImageFolder的解释