从COCO2017数据集中提取语义分割任务的图片及Mask

我的电脑上COCO2017数据集解压后的目录结构如下:

├─coco2017

        ├─annotations
        ├─train2017
        └─val2017 

我使用annotations中的instances_train2017.json和instances_val2017.json来提取训练集和验证集的图片以及对应的Mask。

代码是基于博主叶舟的博文修改的,博文地址是http://t.csdnimg.cn/gpFeg

一开始使用参考文章中的代码提取出数据集后,在训练的时候报了如下错误:

从报错信息上看出来是反向传播计算loss的时候出现了问题,然后我又切换到CPU上跑了下,然后报错:


说明是mask上的类别索引超过了模型输出的num_classes,这时候我才回过头检查数据集提取的代码,发现参考博文中的第34行代码在为不同的分割区域赋上不同的类别索引值之后将mask区域进行了累加,这样重叠部分的mask进行累加会导致类别索引溢出。

 

然后简单进行了修改,修改后将所有mask区域依次放到一个空的底版上,而且每个mask都会和底版上的元素按位进行比较(代码:mask = np.maximum(mask, ann_mask)),只保留值大的元素,这样做的好处是避免了重叠部分索引值的累加,缺点是大物体上如果有小物体,而且小物体的索引值小于大物体的话,小物体的mask将不会被提取出来。这里如果有更好的处理方法的话,请告诉我,感谢!

修改后提取数据集的代码:

import os
import numpy as np
from pycocotools.coco import COCO
from PIL import Image
import imgviz
import tqdm
import argparse
import shutil


def main(args):
    # 初始化COCO api
    annFile = '{}/annotations/instances_{}.json'.format(args.dataDir, args.split)
    os.makedirs(os.path.join(args.dataDir, 'SegmentationClass', args.data_type), exist_ok=True)
    os.makedirs(os.path.join(args.dataDir, 'JPEGImages', args.data_type), exist_ok=True)
    coco = COCO(annFile)

    # 获取所有图像的ID
    imgIds = coco.getImgIds()

    # 对于每一张图像
    for imgId in tqdm.tqdm(imgIds, ncols=100):
        # 获取图像信息
        img_info = coco.loadImgs(imgId)[0]
        img_name = img_info['file_name']

        # 获取该图像的所有注释ID
        annIds = coco.getAnnIds(imgIds=img_info['id'], iscrowd=None)
        anns = coco.loadAnns(annIds)

        # 创建一个空的掩码
        mask = np.zeros((img_info['height'], img_info['width']))

        # 对于每一个注释
        if len(annIds) > 0:
            for ann in anns:
                # 获取类别ID
                catId = ann['category_id']
                # 获取该注释的掩码
                ann_mask = coco.annToMask(ann) * catId
                # 将注释的掩码添加到总掩码上
                # np.maximum(mask, ann_mask):mask和ann_mask按元素进行对比,保留较大的值
                # 这样做的好处是避免了重叠物体导致的类别索引值溢出
                # 缺点是导致了大物体上的小物体如果类别索引小于小物体,则不会在mask上标注出来
                mask = np.maximum(mask, ann_mask)

            # 将掩码转换为图像
            mask_img = Image.fromarray(mask.astype(np.uint8), mode="P")

            # 将图像转换为调色板模式
            colormap = imgviz.label_colormap()
            mask_img.putpalette(colormap.flatten())

            # 保存图像和对应的掩码
            img_origin_path = os.path.join(args.dataDir, args.split, img_name)
            img_output_path = os.path.join(args.dataDir, 'test', 'JPEGImages', args.data_type, img_name)
            seg_output_path = os.path.join(args.dataDir, 'test', 'SegmentationClass', args.data_type,
                                           img_name.replace('.jpg', '.png'))
            shutil.copy(img_origin_path, img_output_path)
            mask_img.save(seg_output_path)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataDir", default="D:/data/coco2017", type=str,
                        help="input dataset directory")
    parser.add_argument("--split", default="train2017", type=str,
                        help="train2017 or val2017")
    parser.add_argument("--data_type", default="train", type=str,
                        help="train or val")
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    main(args)

下面是封装数据集的代码(参考了博主太阳花的小绿豆的代码):

# 导入库
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import transforms as T

# torch.manual_seed(17)


# 自定义数据集CamVidDataset
class COCODataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. noralization, shape manipulation, etc.)
    """

    def __init__(self, images_dir, masks_dir,transforms=None):

        self.ids = os.listdir(images_dir)
        self.ids2 = os.listdir(masks_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids2]
        self.transforms = transforms

    def __getitem__(self, i):
        # read data
        image = Image.open(self.images_fps[i]).convert('RGB')
        # 标签本身是单通道的灰度图,这里用convert('RGB')转换成了三通道图片,只不过三个通道都是相同的
        mask = Image.open(self.masks_fps[i])
        # mask = np.asarray(Image.open(self.masks_fps[i]), dtype=np.int32)
        if self.transforms is not None:
            image, mask = self.transforms(image, mask)

        return image, mask  # 标签图像只返回一个通道就行了

    def __len__(self):
        return len(self.ids)

    @staticmethod
    # collate_fn是用于整理数据的函数,pytorch有默认的collate_fn,其中的参数batch指的是用DataLoader封装时的的batch_size
    def collate_fn(batch):
        # batch中包含了图像和对应的标签
        # 因为batch本身是一个元组,里面的每一个元素是由图像和其标签组成的元组,*解包操作将整个列表解开了,此时总共有batch_size个元组
        # 然后再进行zip操作,zip操作将将每个元素中的第一个数据即图像数据取出放在第一个元组中,将每个元素中的第二个数据取出放在另一个元组中
        # 然后再用list将其转换成列表,此时列表中有两个元素,每个元素都是一个元组,且两个元素的长度相同,第一个元素里是所有的图像,第二个元素
        # 是所有的标签
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fill_value=0)  # 所有图片贴到一个背景全为0的mask上
        batched_targets = cat_list(targets, fill_value=255)  # 所有标签贴到一个背景全为1的mask上
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    # images表示放在列表中的一个batch的所有图像或者标签,图像是张量的形式
    # 分别计算该batch数据中channel, h, w的最大值
    # zip()返回的是一个由元组构成的列表
    # [(channels),(heights),(weights)]
    # max_size是由最大的channel、height、weight构成的元组
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    # 如len(images)=8,max_size=(3,128,256)
    # (8,)+(3,128,256)=(8,3,128,256)
    batch_shape = (len(images),) + max_size  # c,h,w最大的那个batch
    # images[0].new():创建一个新的张量,该张量和images[0]的type和device一样,但是没有内容
    # images[0].new(*batch_shape):创建一个新的张量,该张量和images[0]的type和device一样,且形状和batch_shape一致
    # images[0].new(*batch_shape).fill_(fill_value):用指定的值填充新创建的张量.
    # 该行代码的意思是使用第一张图像创建一张底版,这张底版的尺寸是(b,maxChannel,maxH,maxW)
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        # 将一个batch中的每张图象和对应的底版放到一个元组中
        # ...表示对应的维度全部选取
        # pad_img按照指定值填充的形状为(b,maxChannel,MaxH,maxW)的张量,将它作为一张mask
        # 把每张图像copy到这张mask上,以实现所有的图片形状一致
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs


if __name__ == '__main__':
    # 设置数据集和标签路径
    x_train_dir = r"D:\data\coco2017\test\JPEGImages\train"
    y_train_dir = r"D:\data\coco2017\test\SegmentationClass\train"  # train label

    x_valid_dir = r"D:\data\coco2017\test\JPEGImages\val"
    y_valid_dir = r"D:\data\coco2017\test\SegmentationClass\val"  # val label
    transforms = T.Compose([
            T.RandomResize(768, 1024),
            T.CenterCrop((768,768)),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    train_dataset = COCODataset(
        x_train_dir,
        y_train_dir,
        transforms=transforms
    )
    val_dataset = COCODataset(
        x_valid_dir,
        y_valid_dir,
        transforms=transforms
    )

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,collate_fn=train_dataset.collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True,collate_fn=val_dataset.collate_fn)

    for index, (img, label) in enumerate(train_loader):
        print(img.shape)  # img和label都是tensor
        print(label.shape)

        plt.figure(figsize=(10, 10))
        plt.subplot(221)
        # moveaxis(0, 2):将0轴挪到2轴的位置,相当于将CHW调成HWC
        plt.imshow(img[0, :, :, :].moveaxis(0, 2))
        plt.subplot(222)
        plt.imshow(label[0, :, :])

        plt.subplot(223)
        plt.imshow((img[6, :, :, :].moveaxis(0, 2)))
        plt.subplot(224)
        plt.imshow(label[6, :, :])
        plt.show()
        if index == 0:
            break

需要注意的是:instance_train/val2017.json文件中标注的类别数虽然是80,但是类别ID仍然是按照stuff_train/val2017.json中的类别ID排列的,从1到90,考虑到背景,我将模型中的num_classes设为了91,损失函数使用的是pytorch官方的CrossEntropy,同时设置了ignore_index=255。

 这次为了使用COCO数据集进行语义分割,在数据集的提取部分折腾了好几天,今天程序总算是跑起来了,将数据集的处理过程记录下来,由于本人的代码能力比较烂,如果哪里存在问题了请一定告诉我,非常感谢!

最后再次感谢博主叶舟以及博主太阳花的小绿豆!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值