目录
torch.utils.data.Dataset()和torch.utils.data.DataLoader()
数据集
本文使用 PASCAL VOC 数据集。
下载完成数据集之后进行解压,我们可以在 ImageSets/Segmentation/train.txt 和 ImageSets/Segmentation/val.txt 中找到我们的训练集和验证集的数据,图片存放在 /JPEGImages 中,后缀是 .jpg,而 label 存放在 /SegmentationClass 中,后缀是 .png
图片读入
首先我们定义一个readImage函数进行图片的读入,根据 train.txt 和 val.txt 中的文件名进行图片读入,我们不需要这一步就读入图片,只需要知道图片的路径,之后根据图片名称生成 batch 的时候再读入图片,并做一些数据预处理
# 读取图像和标签信息
prefix = "C:/Users/Administrator/PycharmProjects/FCN/VOCdevkit/VOC2012/"
"""
读取图片
图片的名称在/ImageSets/Segmentation/train.txt ans val.txt里
如果传入参数train为True,则读取train.txt的内容,否则读取val.txt的内容
图片都在./data/VOC2012/JPEGImages文件夹下面,需要在train.txt读取的每一行后面加上.jpg
标签都在./data/VOC2012/SegmentationClass文件夹下面,需要在读取的每一行后面加上.png
最后返回记录图片路径的集合data和记录标签路径集合的label
"""
def readImage(self):
img_root = prefix + "JPEGImages/"
label_root = prefix + "SegmentationClass/"
if (self.mode == "train"):
with open(prefix + "ImageSets/Segmentation/train_try.txt", "r") as f:
list_dir = f.readlines()
elif (self.mode == "val"):
with open(prefix + "ImageSets/Segmentation/val_try.txt", "r") as f:
list_dir = f.readlines()
for item in list_dir:
self.image_name.append(img_root + item.split("\n")[0] + ".jpg")
self.label_name.append(label_root + item.split("\n")[0] + ".png")
预处理
crop
图片的大小是不固定的,但要使用一个 batch 进行计算,我们需要图片的大小保持一致,我们使用 crop 的方式来解决这个问题,也就是从一张图片中 crop 出固定大小的区域,然后在 label 上也做同样方式的 crop。
# 进行随机裁剪
"""
切割函数,从(st,st)开始切割,左上角为(0,0)
切割后的图片宽为width,长为height
"""
width, height = img.size
st = random.randint(0, 20)
box = (st, st, width - 1, height - 1)
img = img.crop(box)
img_gt = img_gt.crop(box)
也可以使用 pytorch 中自带的 transforms进行 crop ,不仅输出 crop 出来的区域,同时还要输出对应的坐标便于我们在 label 上做相同的 crop。
#随即裁剪
def rand_crop(img, img_gt, height, width):
img, rect = tfs.RandomCrop((height, width))(img)
img_gt = tfs.FixedCrop(*rect)(img_gt)
return img, img_gt
标签和像素点颜色
数据有 21 中类别,同时给出每种类别对应的 RGB 值。
# 需要将标签和像素点颜色之间建立映射关系
# voc数据集对应类别标签,一共有20+1个类
self.classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
'dog', 'horse', 'motorbike', 'person', 'potted plant',
'sheep', 'sofa', 'train', 'tv/monitor']
# 颜色标签,分别对应21个类别
self.colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
[64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128],
[64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
[0, 192, 0], [128, 192, 0], [0, 64, 128]]
接着可以建立一个索引,也就是将一个类别的 RGB 值对应到一个整数上,通过这种一一对应的关系,能够将 label 图片变成一个矩阵,矩阵和原图片一样大,但是只有一个通道数,也就是 (h, w) 这种大小,里面的每个数值代表着像素的类别。因为图片是三通道的,并且每一个通道都有0-255一共256中取值,所以我们初始化一个256^3大小的数组就可以做映射了。
# 每个像素点有 0 ~ 255 的选择,RGB 三个通道
cm2lbl = np.zeros(256**3)
# 枚举的时候i是下标,cm是一个三元组,分别标记了RGB值
for i,cm in enumerate(colormap):
cm2lbl[(cm[0]*256+cm[1])*256+cm[2]] = i # 建立索引
# 将标签按照RGB值填入对应类别的下标信息
def image2label(im):
data = np.array(im, dtype='int32')
idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
return np.array(cm2lbl[idx], dtype='int64') # 根据索引得到 label 矩阵
#可以看一下
im = Image.open(label[20]).convert("RGB")
label_im = image2label(im)
plt.imshow(im)
plt.show()
label_im[100:110, 200:210]
随机翻转
# 以50%的概率左右翻转
a = random.random()
if (a > 0.5):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img_gt = img_gt.transpose(Image.FLIP_LEFT_RIGHT)
# 以50%的概率上下翻转
a = random.random()
if (a > 0.5):
img = img.transpose(Image.FLIP_TOP_BOTTOM)
img_gt = img_gt.transpose(Image.FLIP_TOP_BOTTOM)
# 以50%的概率像素矩阵转置
a = random.random()
if (a > 0.5):
img = img.transpose(Image.TRANSPOSE)
img_gt = img_gt.transpose(Image.TRANSPOSE)
噪声
def add_noise(self, img, gama=0.2):
noise = torch.randn(img.shape[0], img.shape[1], img.shape[2])
noise = noise * gama
img = img + noise
return img
标准化
# 将数据转换成tensor,并且做标准化处理
self.im_tfs = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 将mat格式的数据转换成png格式
if (change == True):
self.mat2png()
torch.utils.data.Dataset()和torch.utils.data.DataLoader()
在pytorch中,提供了一种十分方便的数据读取机制,即,使用torch.utils.data.Dataset与torch.utils.data.DataLoader组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个batch数据,并能在输出时对数据进行相应的预处理或数据增强等操作。
torch.utils.data.Dataset
torch.utils.data.Dataset
是代表自定义数据集方法的类,用户可以通过继承该类来自定义自己的数据集类,在继承时要求用户重载__len__()
和__getitem__()
这两个魔法方法。
- __len__():返回的是数据集的大小。我们构建的数据集是一个对象,而数据集不像序列类型(列表、元组、字符串)那样可以直接用len()来获取序列的长度,魔法方法__len__()的目的就是方便像序列那样直接获取对象的长度。如果A是一个类,a是类A的实例化对象,当A中定义了魔法方法__len__(),len(a)则返回对象的大小。
- __getitem__():实现索引数据集中的某一个数据。我们知道,序列可以通过索引的方法获取序列中的任意元素,__getitem__()则实现了能够通过索引的方法获取对象中的任意元素。此外,我们可以在__getitem__()中实现数据预处理。
import torch
from torch.utils.data import Dataset
class TensorDataset(Dataset):
"""
TensorDataset继承Dataset, 重载了__init__(), __getitem__(), __len__()
实现将一组Tensor数据对封装成Tensor数据集
能够通过index得到数据集的数据,能够通过len,得到数据集大小
"""
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
def __len__(self):
return self.data_tensor.size(0)
# 生成数据
data_tensor = torch.randn(4, 3)
target_tensor = torch.rand(4)
# 将数据封装成Dataset
tensor_dataset = TensorDataset(data_tensor, target_tensor)
# 可使用索引调用数据
print(tensor_dataset[1])
# 输出:(tensor([-1.0351, -0.1004, 0.9168]), tensor(0.4977))
# 获取数据集大小
print(len(tensor_dataset))
# 输出:4
torch.utils.data.DataLoader
DataLoader将Dataset对象或自定义数据类的对象封装成一个迭代器;这个迭代器可以迭代输出Dataset的内容;同时可以实现多进程、shuffle、不同采样策略,数据校对等等处理过程。
__init__()中的几个重要的输入:
- dataset:这个就是pytorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。
- batch_size:根据具体情况设置即可。
- shuffle:随机打乱顺序,一般在训练数据中会采用。
- collate_fn:是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。
- batch_sampler:从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。
- sampler:从代码可以看出,其和shuffle是互斥的,一般默认即可。
- num_workers:从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。
- pin_memory:注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。
- timeout:是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
tensor_dataloader = DataLoader(tensor_dataset, # 封装的对象
batch_size=2, # 输出的batch size
shuffle=True, # 随机输出
num_workers=0) # 只有1个进程
# 以for循环形式输出
for data, target in tensor_dataloader:
print(data, target)
#输出结果
tensor([[ 0.7745, 0.2186, 0.1231],
[-0.1307, 1.5778, -1.2906]]) tensor([0.3749, 0.4659])
tensor([[-0.1605, 0.9359, 0.1314],
[-1.1694, 1.0986, -0.9927]]) tensor([0.8071, 0.8997])
完整代码
#完整代码
import torch
import torchvision.transforms as tfs
import os
import scipy.io as scio
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random
# PASCAL VOC语义分割增强数据集
prefix = "C:/Users/Administrator/PycharmProjects/FCN/VOCdevkit/VOC2012/"
# 超参数,设置裁剪的尺寸
CROP = 256
class PASCAL_BSD(object):
def __init__(self, mode="train", change=False):
super(PASCAL_BSD, self).__init__()
# 需要将标签和像素点颜色之间建立映射关系
# 读取数据的模式
self.mode = mode
# voc数据集对应类别标签,一共有20+1个类
self.classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
'dog', 'horse', 'motorbike', 'person', 'potted plant',
'sheep', 'sofa', 'train', 'tv/monitor']
# 颜色标签,分别对应21个类别
self.colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
[64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128],
[64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
[0, 192, 0], [128, 192, 0], [0, 64, 128]]
# 将数据转换成tensor,并且做标准化处理
self.im_tfs = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 将mat格式的数据转换成png格式
if (change == True):
self.mat2png()
self.image_name = []
self.label_name = []
self.readImage()
print("%s->成功加载%d张图片" % (self.mode, len(self.image_name)))
"""
读取图片
图片的名称在/ImageSets/Segmentation/train.txt ans val.txt里
如果传入参数train为True,则读取train.txt的内容,否则读取val.txt的内容
图片都在./data/VOC2012/JPEGImages文件夹下面,需要在train.txt读取的每一行后面加上.jpg
标签都在./data/VOC2012/SegmentationClass文件夹下面,需要在读取的每一行后面加上.png
最后返回记录图片路径的集合data和记录标签路径集合的label
"""
# 读取图像和标签信息
def readImage(self):
img_root = prefix + "JPEGImages/"
label_root = prefix + "SegmentationClass/"
if (self.mode == "train"):
with open(prefix + "ImageSets/Segmentation/train.txt", "r") as f:
list_dir = f.readlines()
elif (self.mode == "val"):
with open(prefix + "ImageSets/Segmentation/val.txt", "r") as f:
list_dir = f.readlines()
for item in list_dir:
self.image_name.append(img_root + item.split("\n")[0] + ".jpg")
self.label_name.append(label_root + item.split("\n")[0] + ".png")
# 数据处理,输入Image对象,返回tensor对象
def data_process(self, img, img_gt):
if (self.mode == "train"):
# 以50%的概率左右翻转
a = random.random()
if (a > 0.5):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img_gt = img_gt.transpose(Image.FLIP_LEFT_RIGHT)
# 以50%的概率上下翻转
a = random.random()
if (a > 0.5):
img = img.transpose(Image.FLIP_TOP_BOTTOM)
img_gt = img_gt.transpose(Image.FLIP_TOP_BOTTOM)
# 以50%的概率像素矩阵转置
a = random.random()
if (a > 0.5):
img = img.transpose(Image.TRANSPOSE)
img_gt = img_gt.transpose(Image.TRANSPOSE)
a = random.random()
# 进行随机裁剪
width, height = img.size
st = random.randint(0, 20)
box = (st, st, width - 1, height - 1)
img = img.crop(box)
img_gt = img_gt.crop(box)
img = img.resize((CROP, CROP))
img_gt = img_gt.resize((CROP, CROP))
img = self.im_tfs(img)
img_gt = np.array(img_gt)
img_gt = torch.from_numpy(img_gt)
"""
plt.subplot(1,2,1), plt.imshow(img.permute(1,2,0))
plt.subplot(1,2,2), plt.imshow(img_gt)
plt.show()
"""
return img, img_gt
def add_noise(self, img, gama=0.2):
noise = torch.randn(img.shape[0], img.shape[1], img.shape[2])
noise = noise * gama
img = img + noise
return img
# 重载getitem函数,使类可以迭代
def __getitem__(self, idx):
# idx = 100
img = Image.open(self.image_name[idx])
img_gt = Image.open(self.label_name[idx])
img, img_gt = self.data_process(img, img_gt)
# img = self.add_noise(img)
return img, img_gt
def __len__(self):
return len(self.image_name)
# 将mat数据转换成png
def mat2png(self, dataDir=None, outputDir=None):
if (dataDir == None):
dataroot = prefix + "cls/"
else:
dataroot = dataDir
if (outputDir == None):
outroot = prefix + "SegmentationClass/"
else:
outroot = outputDir
list_dir = os.listdir(dataroot)
for item in list_dir:
matimg = scio.loadmat(dataroot + item)
mattmp = matimg["GTcls"]["Segmentation"]
# 将mat转换成png
# print(mattmp[0][0])
new_im = Image.fromarray(mattmp[0][0])
print(outroot + item[:-4] + ".png")
new_im.save(outroot + item[:-4] + ".png")
'''
#标签文件的使用方法,需要先转换成numpy再变成tensor
img = Image.open(outroot + item[:-4] + ".png")
img = np.array(img)
img = torch.from_numpy(img)
print(img.shape)
plt.imshow(img)
plt.colorbar()
plt.show()
'''
if __name__ == "__main__":
data_train = PASCAL_BSD("train")
data_val = PASCAL_BSD("val")
train_data = torch.utils.data.DataLoader(data_train, batch_size=16, shuffle=True)
val_data = torch.utils.data.DataLoader(data_val, batch_size=16, shuffle=False)
for item in val_data:
img, img_gt = item
print(img.shape)
print(img_gt.shape)
参考文章