语义分割数据集制作(VOC2012为例)

本文介绍了如何基于李沐的《动手学深度学习》中的方法,利用PyTorch对VOC2012数据集进行处理,用于语义分割任务。文章详细展示了读取数据、随机裁剪、构建颜色映射到标签的映射以及封装数据集为类的过程。
摘要由CSDN通过智能技术生成

本文依据李沐大佬的《动手学深度学习》中的片段,以VOC2012数据集制作语义分割中使用的数据集。

 本文为语义分割任务中的前置工作,数据集制作中的简单实现,语义分割任务不同于图像分类任务,需要更精确的pixel级别的分类工作,因此数据集的处理时需要更多的准备。同时可以采用面向对象的方式去封装数据集,便于论文复现时,或是自己训练模型时使用。

读取VOC2012数据集(torchvision中下载)

import torchvision
import torch
import os
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
%matplotlib inline
# 首先使用数据集
def read_voc_image(voc_dir,is_train=True):
    # 读取所有数据并进行标注
    txt_fname = os.path.join(voc_dir,'ImageSets','Segmentation'
                             ,'train.txt' if is_train else 'val.txt')
    mode = torchvision.io.image.ImageReadMode.RGB
    with open(txt_fname,'r') as f:
        imgs = f.read().split()
    features,label = [],[]
    for i,fname in enumerate(imgs):
        features.append(torchvision.io.image.read_image(os.path.join(voc_dir
                            ,'JPEGImages',f'{fname}.jpg')))
        label.append(torchvision.io.image.read_image(os.path.join(voc_dir
                            ,'SegmentationClass',f'{fname}.png'),mode=mode))
    return features,label
def show_image(f,l,index=2,n=5):
    plt.figure(figsize=(16,16))
    for i in range(index,0,-1):
        for j in range(n):
            plt.subplot(index*n,i,j+1)
            plt.imshow(f[j].permute(1,2,0)) if not i==1 else plt.imshow(l[j].permute(1,2,0))
            # 展示的时候,维度需要由 C,H,W -> H,W,C
n = 5
features,labels = read_voc_image(r'D:\python_test\voc_data\VOCdevkit\VOC2012')
feature,label = features[:n],labels[:n]
show_image(feature,label)

随机裁剪 ->替代resize

# ========== 随机crop替代resize ==========
def voc_rand_crop(feature,label,height,width):
    rect = torchvision.transforms.RandomCrop.get_params(feature,(height,width))   # 得到裁剪参数
    # 以feature 为基准去裁剪
    feature = torchvision.transforms.functional.crop(feature,*rect)
    label = torchvision.transforms.functional.crop(label,*rect)
    return feature,label
feature,label = voc_rand_crop(features[0],labels[0],height=200,width=200)
plt.subplot(2,2,1)
plt.imshow(feature.permute(1,2,0))
plt.subplot(2,2,2)
plt.imshow(label.permute(1,2,0))

 构建colormap -> label的映射函数

# ========== 列举RGB颜色值以及类别名 ==========
VOC_colormap = [[0,0,0],[128,0,0],[0,128,0],[128,128,0],[0,0,128],[128,0,128]
                ,[0,128,128],[128,128,128],[64,0,0],[192,0,0],[64,128,0],[192,128,0]
                ,[64,0,128],[192,0,128],[64,128,128],[192,128,128],[0,64,0],[128,64,0]
                ,[0,192,0],[128,192,0],[0,64,128]]
VOC_classes = []
def voc_colormap2label():
    """构建RGB到VOC类别索引的映射"""
    colormap2label = torch.zeros(256 ** 3,dtype=torch.long)
    for i ,voc_map in enumerate(VOC_colormap):
        colormap2label[(voc_map[0] * 256 + voc_map[1]) * 256 +voc_map[2]] =  i
    return colormap2label
def voc_label_indices(colormap,colormap2label):
    """将VOC标签中的RGB值映射到它们的类别索引"""
    colormap = colormap.permute(1,2,0).numpy().astype('int32')    # 将tensor转化为numpy的int32格式
    idx = ((colormap[:,:,0]*256 + colormap[:,:,1]) * 256 + colormap[:,:,2])
    return colormap2label[idx]
y = voc_label_indices(features[20],voc_colormap2label())
print(y[105:115,100:140])

 封装成类

# ========== 自定义语义分割数据集 ==========
class voc_Seg_data:
    def __init__(self,voc_dir,crop_size,is_train=True):
        self.transform = torchvision.transforms.Normalize(mean=[0.485,0.456,0.406]
                    ,std=[0.229,0.224,0.225])
        self.crop_size = crop_size
        features,labels = read_voc_image(voc_dir=voc_dir,is_train=is_train)
        self.features = [self.normalize_image(img) for img in self.filter(imgs=features)]
        self.labels = self.filter(labels)
        self.voc_colormap2label = voc_colormap2label()
        print(f'read {voc_dir} of   ' + str(len(self.features)) + 'examples')
    def normalize_image(self,img):
        return self.transform(img.float() / 255)
    def filter(self,imgs):
        return [img for img in imgs if (img.shape[1] >= self.crop_size[0] and
                                        img.shape[2] >= self.crop_size[1])]
    def __getitem__(self, idx):
        feature,label = voc_rand_crop(self.features[idx],self.labels[idx],*self.crop_size)
        return feature,voc_label_indices(label, self.voc_colormap2label)

    def __len__(self):
        return len(self.features)

 加载运行

# ========== 读取数据集 ==========
crop_size = (320,320)
voc_train = voc_Seg_data(crop_size=crop_size,voc_dir=r'D:\python_test\voc_data\VOCdevkit\VOC2012'
                         ,is_train=True)
voc_val = voc_Seg_data(crop_size=crop_size,voc_dir=r'D:\python_test\voc_data\VOCdevkit\VOC2012'
                       ,is_train=False)
batch_size = 64
train_loader = DataLoader(batch_size=batch_size,dataset=voc_train,shuffle=True)

 代码文件使用jupyter notebook运行,加入了一些图片的显示等操作,实际使用中可以删除。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

藤宫博野

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值