本文参考李沐老师《动手学深度学习》中的相关片段,基于 VOC2012 数据集制作语义分割任务所使用的数据集。
本文是语义分割任务的前置工作,给出了数据集制作的一个简单实现。语义分割不同于图像分类,需要在像素(pixel)级别进行更精确的分类,因此数据集处理需要更多准备工作:例如标签是 RGB 彩色图,需要先映射为逐像素的类别索引;数据增强裁剪时,图像与标签必须使用同一区域。同时,可以采用面向对象的方式封装数据集,便于复现论文或训练自己的模型时使用。
读取 VOC2012 数据集(可先通过 torchvision.datasets.VOCSegmentation 下载到本地)
import torchvision
import torch
import os
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
%matplotlib inline
# Read the VOC2012 segmentation split from disk.
def read_voc_image(voc_dir, is_train=True):
    """Read every image/label pair listed in a VOC2012 segmentation split.

    Args:
        voc_dir: VOC2012 root directory; the split list is read from
            ``ImageSets/Segmentation``, images from ``JPEGImages`` (``.jpg``)
            and labels from ``SegmentationClass`` (``.png``).
        is_train: read ``train.txt`` if True, otherwise ``val.txt``.

    Returns:
        ``(features, labels)``: two index-aligned lists of image tensors in
        ``(C, H, W)`` layout. Labels are still RGB-coloured masks, not class
        indices.
    """
    split = 'train.txt' if is_train else 'val.txt'
    txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation', split)
    # Force 3-channel RGB for both images and labels. Downstream code
    # (3-channel normalization, RGB colour packing) assumes three channels.
    mode = torchvision.io.image.ImageReadMode.RGB
    with open(txt_fname, 'r', encoding='utf-8') as f:
        names = f.read().split()
    features, labels = [], []
    for fname in names:
        features.append(torchvision.io.image.read_image(
            os.path.join(voc_dir, 'JPEGImages', f'{fname}.jpg'), mode=mode))
        labels.append(torchvision.io.image.read_image(
            os.path.join(voc_dir, 'SegmentationClass', f'{fname}.png'), mode=mode))
    return features, labels
def show_image(f, l, index=2, n=5):
    """Plot images from ``f`` above images from ``l`` in an ``index`` x ``n`` grid.

    The last grid row shows ``l[0..n-1]``. Every other row shows
    ``f[0..n-1]``. With the default ``index=2`` this is one row of features
    above one row of labels.

    Args:
        f: sequence of image tensors in ``(C, H, W)`` layout.
        l: sequence of image tensors in ``(C, H, W)`` layout.
        index: number of grid rows.
        n: number of columns; must not exceed ``len(f)`` or ``len(l)``.
    """
    plt.figure(figsize=(16, 16))
    for row in range(index):
        imgs = l if row == index - 1 else f
        for j in range(n):
            # subplot(nrows, ncols, position); position counts row-major from 1.
            plt.subplot(index, n, row * n + j + 1)
            # matplotlib expects H, W, C while the tensors are C, H, W.
            plt.imshow(imgs[j].permute(1, 2, 0))
# Preview the first n training images (top row) and their label masks (bottom row).
n = 5
features, labels = read_voc_image(r'D:\python_test\voc_data\VOCdevkit\VOC2012')
feature, label = features[:n], labels[:n]
# Pass n explicitly: show_image defaults to 5 columns and would index past
# the slices if n were changed to something smaller.
show_image(feature, label, n=n)
随机裁剪 -> 替代 resize(若用插值方式缩放标签图,会产生不属于任何类别的混合颜色,因此改为对图像和标签裁剪同一区域)
# ========== Random crop instead of resize ==========
def voc_rand_crop(feature, label, height, width):
    """Crop ``feature`` and ``label`` with one shared random window.

    A ``(height, width)`` window is sampled against ``feature``'s size, then
    the very same window is applied to ``label`` so pixels stay aligned.
    Returns the cropped ``(feature, label)`` pair.
    """
    top, left, crop_h, crop_w = torchvision.transforms.RandomCrop.get_params(
        feature, (height, width))
    crop = torchvision.transforms.functional.crop
    return (crop(feature, top, left, crop_h, crop_w),
            crop(label, top, left, crop_h, crop_w))
# Demo: crop the first image/label pair to 200x200 with one shared window
# and draw the two results side by side.
feature, label = voc_rand_crop(features[0], labels[0], height=200, width=200)
for pos, img in enumerate((feature, label), start=1):
    plt.subplot(2, 2, pos)
    # C, H, W -> H, W, C for matplotlib
    plt.imshow(img.permute(1, 2, 0))
构建 colormap -> label 的映射函数(将标签图中每个像素的 RGB 颜色映射为类别索引)
# ========== VOC RGB colours and class names ==========
# RGB colour of each class in the label masks; the position in this list is
# the class index used by voc_colormap2label.
VOC_colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128],
                [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
                [0, 192, 0], [128, 192, 0], [0, 64, 128]]
# Class names, index-aligned with VOC_colormap.
VOC_classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']
def voc_colormap2label():
    """Build a lookup table from a packed RGB colour to its VOC class index.

    Each colour ``(r, g, b)`` is packed into ``(r * 256 + g) * 256 + b``.
    The returned 1-D ``torch.long`` tensor has ``256 ** 3`` entries. At each
    packed position it stores the colour's index in ``VOC_colormap``, and
    every other entry is 0.
    """
    table = torch.zeros(256 ** 3, dtype=torch.long)
    for cls_idx, (r, g, b) in enumerate(VOC_colormap):
        table[(r * 256 + g) * 256 + b] = cls_idx
    return table
def voc_label_indices(colormap, colormap2label):
    """Map an RGB label image to a per-pixel class-index map.

    Args:
        colormap: label tensor in ``(3, H, W)`` layout with integer RGB
            values in ``[0, 255]``.
        colormap2label: 1-D lookup tensor indexed by the packed colour
            ``(r * 256 + g) * 256 + b`` (as built by ``voc_colormap2label``).

    Returns:
        ``(H, W)`` tensor of class indices looked up in ``colormap2label``.
    """
    # Widen to int64 before packing: arithmetic on the raw uint8 pixels would
    # overflow. Staying in torch avoids the CPU-only .numpy() round trip.
    rgb = colormap.permute(1, 2, 0).long()
    idx = (rgb[:, :, 0] * 256 + rgb[:, :, 1]) * 256 + rgb[:, :, 2]
    return colormap2label[idx]
# Sanity check: convert one label mask (not the input photo) to class
# indices and print a small patch of the resulting map.
y = voc_label_indices(labels[20], voc_colormap2label())
print(y[105:115, 100:140])
封装成类
# ========== Custom semantic-segmentation dataset ==========
class voc_Seg_data(torch.utils.data.Dataset):
    """Map-style VOC2012 semantic-segmentation dataset.

    Each item is a randomly cropped ``(feature, label)`` pair. ``feature`` is
    the normalized float image. ``label`` is the per-pixel class-index map
    produced by ``voc_label_indices``.

    Args:
        voc_dir: VOC2012 root directory passed to ``read_voc_image``.
        crop_size: ``(height, width)`` of the random crop; images smaller
            than this are dropped.
        is_train: use the ``train`` split if True, otherwise ``val``.
    """

    def __init__(self, voc_dir, crop_size, is_train=True):
        super().__init__()
        # Per-channel mean/std normalization, applied after scaling to [0, 1].
        self.transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.crop_size = crop_size
        features, labels = read_voc_image(voc_dir=voc_dir, is_train=is_train)
        # Filter (image, label) pairs jointly so the two lists always stay
        # index-aligned, even if an image and its label differ in size.
        kept = [(feat, lab) for feat, lab in zip(features, labels)
                if self._fits_crop(feat) and self._fits_crop(lab)]
        self.features = [self.normalize_image(feat) for feat, _ in kept]
        self.labels = [lab for _, lab in kept]
        self.voc_colormap2label = voc_colormap2label()
        print(f'read {voc_dir} of {len(self.features)} examples')

    def normalize_image(self, img):
        """Scale an image with 0-255 pixel values to [0, 1] and normalize it."""
        return self.transform(img.float() / 255)

    def _fits_crop(self, img):
        """Return True if a ``(C, H, W)`` image is at least ``crop_size``."""
        return (img.shape[1] >= self.crop_size[0]
                and img.shape[2] >= self.crop_size[1])

    def filter(self, imgs):
        """Return the images in ``imgs`` that are large enough to crop."""
        return [img for img in imgs if self._fits_crop(img)]

    def __getitem__(self, idx):
        """Return a random crop of example ``idx`` as ``(feature, class-index map)``."""
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                       *self.crop_size)
        return feature, voc_label_indices(label, self.voc_colormap2label)

    def __len__(self):
        return len(self.features)
加载运行
# ========== Build the datasets and loaders ==========
voc_dir = r'D:\python_test\voc_data\VOCdevkit\VOC2012'  # local VOC2012 root, shared by both splits
crop_size = (320, 320)  # (height, width) of each random crop
voc_train = voc_Seg_data(crop_size=crop_size, voc_dir=voc_dir, is_train=True)
voc_val = voc_Seg_data(crop_size=crop_size, voc_dir=voc_dir, is_train=False)
batch_size = 64
train_loader = DataLoader(batch_size=batch_size, dataset=voc_train, shuffle=True)
# Evaluation order does not need to be random.
val_loader = DataLoader(batch_size=batch_size, dataset=voc_val, shuffle=False)
本代码在 Jupyter Notebook 中运行,其中加入了一些图片显示等操作,实际使用时可以删除。注意 `%matplotlib inline` 是 Jupyter 的魔法命令,若将代码保存为普通 .py 脚本运行,需要删除该行。