最终版之-适用于mmdetection2.3-2.6,关于VOC2012数据集转coco的格式(主要针对分割任务)升华版

适用于mmdetection2.3-2.6,关于VOC2012数据集转coco的格式(主要针对分割任务)最终版

改bug不易,关于博主调好的json能直接用于mmdection的,要的话底下评论联系哦

源码参考了该github但是需要按本文修改,因为无法直接运行得到能用的结果,转不了voc2012的数据集。

码字不易,给个赞哟


前言

本文是接着处理,VOC to COCO 适合于分割任务的转换,承接上一文,
(适用于mmdetection2.3-2.6,关于VOC2012数据集转coco的格式(主要针对分割任务))
的改进版,在上一篇博文中, 同样实现了这一步骤并且适用于分割任务,但是分割效果是非常的不行,因为提取的分割部分的像素位置信息,是针对目标所在处的整个矩形方框四个顶点的位置信息,所以分割出来的并非目标的轮廓。本文将展现如何精确提取mask的像素信息,用于分割任务。

一、先上代码

代码如下(示例):

import argparse
import json
import matplotlib.pyplot as plt
import skimage.io as io
import cv2
# from labelme import utils
import numpy as np
import glob

import PIL.Image
import sys
import os
import os.path as osp
import mmcv
from tqdm import tqdm
import PIL


def txt2list(txtfile):
    fsa = open(txtfile)
    ls = []
    for line in fsa:
        ls.append(line[:-1])
    return ls


lists = txt2list('data/VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt')
ClASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor']

def object_classes():#这里定义了自己的数据集的目标类别
    return ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor']


label_ids = {name: i + 1 for i, name in enumerate(object_classes())}

categoriess = []
for k,v in label_ids.items():
    categoriess.append({"name": k, "id": v})


class PascalVOC2coco(object):
    def __init__(self, xml=[], save_json_path='./new.json'):
        '''
        :param xml: 所有Pascal VOC的xml文件路径组成的列表
        :param save_json_path: json保存位置
        '''
        self.xml = xml
        self.save_json_path = save_json_path
        self.images = []
        self.categories = categoriess

        self.annotations = []
        # self.data_coco = {}
        self.label = []
        self.annID = 1
        self.height = 0
        self.width = 0
        self.ob = []

        self.save_json()

    def data_transfer(self):
        for num, json_file in enumerate(self.xml):

            # 进度输出
            sys.stdout.write('\r>> Converting image %d/%d' % (
                num + 1, len(self.xml)))
            sys.stdout.flush()

            self.json_file = json_file
            self.num = num
            path = os.path.dirname(self.json_file)
            path = os.path.dirname(path)
            # path=os.path.split(self.json_file)[0]
            # path=os.path.split(path)[0]
            obj_path = glob.glob(os.path.join(path, 'SegmentationObject', '*.png'))
            with open(json_file, 'r') as fp:
                flag = 0
                xxx=[]
                for p in fp:
                    xxx.append(p)
                for p in xxx:
                    f_name = 1
                    # if 'folder' in p:
                    #     folder =p.split('>')[1].split('<')[0]
                    if 'filename' in p:
                        self.filen_ame = p.split('>')[1].split('<')[0]
                        f_name = 0
                        self.path = os.path.join(path, 'SegmentationObject', self.filen_ame.split('.')[0] + '.png')
                        # if self.path not in obj_path:
                        w, h=PIL.Image.open(self.path).size
                        self.width = int(w)
                        self.height = int(h)
                        self.images.append(self.image())
                        print(self.filen_ame)
                        
                    # print(num)
                    # if 'width' in p:
                    #     self.width = int(p.split('>')[1].split('<')[0])

                    # if 'height' in p:
                    #     self.height = int(p.split('>')[1].split('<')[0])
                    
                    if flag == 1:
                        if self.ob[0]not in ClASSES:
                            self.ob = []
                            flag = 0
                            continue
                        else:
                            self.supercategory = self.ob[0]   ## 除去类别不再CLASSES中的部分!!!!
                            # if self.supercategory not in self.label:
                            #     self.categories.append(self.categorie())
                            #     self.label.append(self.supercategory)
                            # 边界框
                            if int(self.filen_ame.split('.')[0].split('_')[0]) < 2009:
                                x1 = int(self.ob[1]);
                                y1 = int(self.ob[2]);
                                x2 = int(self.ob[3]);
                                y2 = int(self.ob[4])
                                # print(self.ob)
                                self.rectangle = [x1, y1, x2, y2]
                        # print(self.rectangle)
                                self.bbox = [x1, y1, x2 - x1+1, y2 - y1+1]  # COCO 对应格式[x,y,w,h]
                            elif int(self.filen_ame.split('.')[0].split('_')[0]) >= 2009:
                                x1 = int(self.ob[2]);
                                y1 = int(self.ob[4]);
                                x2 = int(self.ob[1]);
                                y2 = int(self.ob[3])
                            # print(self.ob)
                                self.rectangle = [x1, y1, x2, y2]
                            # print(self.rectangle)
                                self.bbox = [x1, y1, x2 - x1, y2 - y1]  # COCO 对应格式[x,y,w,h]
                            self.annotations.append(self.annotation())
                            self.ob = []
                            self.annID += 1
                            flag = 0
                    elif f_name == 1:
                        if 'name' in p:
                            self.ob.append(p.split('>')[1].split('<')[0])
                        if 'xmin' in p:
                            self.ob.append(p.split('>')[1].split('<')[0])
                        if 'ymin' in p:
                            # print('ok')
                            self.ob.append(p.split('>')[1].split('<')[0])
                            # print(p.split('>')[1].split('<')[0])
                        # print(p)
                        if 'xmax' in p:
                            self.ob.append(p.split('>')[1].split('<')[0])
                        if 'ymax' in p:
                            self.ob.append(p.split('>')[1].split('<')[0])
                            # print(self.ob)
                        if len(self.ob) > 4:
                            flag = 1
                            
        sys.stdout.write('\n')
        sys.stdout.flush()

    def image(self):
        image = {}
        image['file_name'] = self.filen_ame
        image['height'] = self.height
        image['width'] = self.width
        image['id'] = self.num + 1
        return image

    def categorie(self):
        categorie = {}
        categorie['supercategory'] = self.supercategory
        categorie['name'] = self.supercategory
        categorie['id'] = len(self.label) + 1  # 0 默认为背景
        return categorie

    @staticmethod
    def change_format(contour):
        contour2 = []
        length = len(contour)
        # print(length)
        for i in range(0, length, 2):
            contour2.append([contour[i], contour[i + 1]])
        return np.asarray(contour2, np.int32)

    def annotation(self):
        annotation = {}
        # annotation['segmentation'] = [self.getsegmentation()]
        annotation['segmentation'] = [list(map(float, self.getsegmentation()))]
        # print(annotation['segmentation'])
        annotation['iscrowd'] = 0
        annotation['image_id'] = self.num + 1
        contour = PascalVOC2coco.change_format(annotation['segmentation'][0])
        # print(annotation['segmentation'][0])
        annotation['area'] = abs(float(cv2.contourArea(contour, True)))
        # annotation['bbox'] = list(map(float, self.bbox))
        annotation['bbox'] = self.bbox
        annotation['category_id'] = self.getcatid(self.supercategory)
        annotation['id'] = self.annID

        # 计算轮廓面积
        # print(len(annotation['segmentation'][0]))
        
        # print(annotation['area'])
        return annotation

    def getcatid(self, label):
        for categorie in self.categories:
            if label == categorie['name']:
                return categorie['id']
        return -1

    def getsegmentation(self):

        try:
            # print(self.path)
            mask_1 = cv2.imread(self.path, 0)
            
            mask = np.zeros_like(mask_1, np.uint8)
            rectangle = self.rectangle
            # print(mask_1)
            mask[rectangle[1]:rectangle[3], rectangle[0]:rectangle[2]] = mask_1[rectangle[1]:rectangle[3], rectangle[0]:rectangle[2]]

            # 计算矩形中点像素值
            mean_x = (rectangle[0] + rectangle[2]) // 2
            mean_y = (rectangle[1] + rectangle[3]) // 2

            end = min((mask.shape[1], int(rectangle[2]) + 1))
            start = max((0, int(rectangle[0]) - 1))
            # print(start, end, mean_x, mean_y)
            flag = True
            for i in range(mean_x, end):
                x_ = i
                y_ = mean_y
                pixels = mask_1[y_, x_]
                # print(pixels)
                if pixels != 0 and pixels != 220:  # 0 对应背景 220对应边界线
                    # print(pixels)
                    mask = (mask == pixels).astype(np.uint8)
                    # for i in mask:
                    #     print(i)
                    flag = False
                    break
            if flag:
                for i in range(mean_x, start+1, -1):
                    x_ = i
                    y_ = mean_y
                    pixels = mask_1[y_, x_]
                    if pixels != 0 and pixels != 220:
                        mask = (mask == pixels).astype(np.uint8)
                        # print(y_, x_)
                        break
            self.mask = mask
            return self.mask2polygons()

        except:
            print('no')
            return [0]

    def mask2polygons(self):
        '''从mask提取边界点'''
        contours = cv2.findContours(self.mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)  # 找到轮廓线
        # print(contours[0])
        bbox=[]
        for cont in contours[0]:
            # print(cont.flatten())
            [bbox.append(i) for i in list(cont.flatten())]
            # map(bbox.append,list(cont.flatten()))
        # print(len(bbox))
        return bbox  # list(contours[1][0].flatten())
        
    def data2coco(self):
        data_coco = {}
        data_coco['images'] = self.images
        data_coco['annotations'] = self.annotations
        # categories = []
        # for k,v in label_ids.items():
        #     categories.append({"name": k, "id": v})
        data_coco['categories'] = self.categories
        return data_coco

    def save_json(self):
        self.data_transfer()
        self.data_coco = self.data2coco()
        # 保存json文件
        json.dump(self.data_coco, open(self.save_json_path, 'w'),indent=4)  # indent=4 更加美观显示
        # mmcv.dump(self.data_coco, self.save_json_path,indent=4)


path = '/media/gt/新加卷/mmdetect/mmdetection/data/VOCdevkit/VOC2012/Annotations/'

files= os.listdir(path) #得到文件夹下的所有文件名称

xml_files=[]
for i in files:
    # print(i)
    if i.split('.')[0] in lists:
        xml_files.append(path+i)
# print(xml_files)
# xml_files=['data/VOCdevkit/VOC2012/Annotations/2009_000012.xml']
PascalVOC2coco(xml_files, './data/coco/annotations/trains.json')

二、使用步骤

将对应位置修改为自己的VOC路径即可

三、关键部分代码说明

代码如下(示例):

部分1

下面这部分代码主要实现了读取一个xml文件,然后实现对每一行信息进行读取。
具体解释看所写的注释

with open(json_file, 'r') as fp:
                flag = 0   # 标志位,用于分隔,每次读取完一个xml后会根据是否为1进入json内容的编写
                xxx=[]
                for p in fp:
                    xxx.append(p)
                for p in xxx:
                    f_name = 1
                    # if 'folder' in p:
                    #     folder =p.split('>')[1].split('<')[0]
                    if 'filename' in p:
                        self.filen_ame = p.split('>')[1].split('<')[0] # 获取xml文件中的filename,并且保存
                        f_name = 0  # 读取标记位
                        self.path = os.path.join(path, 'SegmentationObject', self.filen_ame.split('.')[0] + '.png') 
                        # 获取对应的mask的路径
                        # if self.path not in obj_path:
                        w, h=PIL.Image.open(self.path).size    #  获取mask的宽高信息!!在这里获取很重要
                        # 而不是通过后面的 if 'width' in p:部分获取
                        self.width = int(w)
                        self.height = int(h)
                        self.images.append(self.image()) ## 每次读取一个xml都能获取一个对应的字典self.image()
                        # 用于写入最终的json文件
                        print(self.filen_ame)
                        #     break
                    # print(num)
                    # if 'width' in p:
                    #     self.width = int(p.split('>')[1].split('<')[0])

                    # if 'height' in p:
                    #     self.height = int(p.split('>')[1].split('<')[0])
                    # 不用上述获取宽高信息,是因为博主尝试的过程中发现,虽然能提取一样的信息,
                    #但是通过这个个操作提取信息,mmdetection会报错,有一个cv2.copyMakeBorder 会报错,
                    #显示输入的img存在长宽高不符合输入要求。具体为什么没有细纠,
                    #应该是xml文件对应的宽高在读取时可能会出现偏差与实际的对应图片宽高。

                    if flag == 1:
                        if self.ob[0]not in ClASSES:   ## VOC2012不同于2007,包含着很多东西
                        # self.ob[0]保存的是类别,VOC2012中存在人体关键点信息如head等,这不属于分割检测的20类目标
                        #因此通过判断跳过 head hand等类别
                            self.ob = []
                            flag = 0
                            continue
                        else:
                            self.supercategory = self.ob[0]   ## 除去类别不再CLASSES中的部分!!!!
                            # if self.supercategory not in self.label:
                            #     self.categories.append(self.categorie())
                            #     self.label.append(self.supercategory)
                            # 边界框
                            if int(self.filen_ame.split('.')[0].split('_')[0]) < 2009:   
                            # 	VOC2012 中的xml根据年份不同,主要类别的xmin  ymin等四个信息所在行不一样,所以要分别对待
                            #	 不然会导致定位错误,目标框信息获取不对,2007-2008是一个格式,2009 以后是一个格式
                            #   注意:这里针对的是20类目标,不包括关键点head这些
                            #    因为这些部分的xy信息顺序在2009年及其以后的是有大差别的
                                x1 = int(self.ob[1]);
                                y1 = int(self.ob[2]);
                                x2 = int(self.ob[3]);
                                y2 = int(self.ob[4])
                                # print(self.ob)
                                self.rectangle = [x1, y1, x2, y2]
                        # print(self.rectangle)
                                self.bbox = [x1, y1, x2 - x1+1, y2 - y1+1]  # COCO 对应格式[x,y,w,h]
                            elif int(self.filen_ame.split('.')[0].split('_')[0]) >= 2009:
                                x1 = int(self.ob[2]);
                                y1 = int(self.ob[4]);
                                x2 = int(self.ob[1]);
                                y2 = int(self.ob[3])
                            # print(self.ob)
                                self.rectangle = [x1, y1, x2, y2]
                            # print(self.rectangle)
                                self.bbox = [x1, y1, x2 - x1, y2 - y1]  # COCO 对应格式[x,y,w,h]
                            self.annotations.append(self.annotation()) # 写入相关信息用于保存
                            self.ob = []   #每次成功执行后要重新清空
                            self.annID += 1 ##id+1计数
                            flag = 0
                    elif f_name == 1:
                        if 'name' in p:
                            self.ob.append(p.split('>')[1].split('<')[0])
                        if 'xmin' in p:
                            self.ob.append(p.split('>')[1].split('<')[0])
                        if 'ymin' in p:
                            # print('ok')
                            self.ob.append(p.split('>')[1].split('<')[0])
                            # print(p.split('>')[1].split('<')[0])
                        # print(p)
                        if 'xmax' in p:
                            self.ob.append(p.split('>')[1].split('<')[0])
                        if 'ymax' in p:
                            self.ob.append(p.split('>')[1].split('<')[0])
                            # print(self.ob)
                        if len(self.ob) > 4:
                            flag = 1

部分2

获取关键的mask信息,也就是为了获得segmentation 字段的信息,每一个它要包含mask目标对应区域的全部像素点位置信息

# 下面的这个函数用于获取mask中属于目标区域内部的信息,这是一种多边形的方法,
# 因此区别于上一文章中只会分割矩形区域的缺点
    def getsegmentation(self):
        try:
            # print(self.path)
            mask_1 = cv2.imread(self.path, 0)  #  读入对应xml文件的mask文件
            mask = np.zeros_like(mask_1, np.uint8)  # 创造一个0矩阵,按照mask的格式
            rectangle = self.rectangle # 获取每个xml中得到的目标框四个顶点位置信息
            # print(mask_1)
            mask[rectangle[1]:rectangle[3], rectangle[0]:rectangle[2]] = mask_1[rectangle[1]:rectangle[3], rectangle[0]:rectangle[2]]  # 将读取的mask_1中对应目标框区域信息赋予创造的0矩阵

            # 计算矩形中点像素值, 也就是当作目标对象的中心
            mean_x = (rectangle[0] + rectangle[2]) // 2
            mean_y = (rectangle[1] + rectangle[3]) // 2

            end = min((mask.shape[1], int(rectangle[2]) + 1))   # 获取目标区域的边界点,mask.shape[1]代表weight
            start = max((0, int(rectangle[0]) - 1)) # 对应边界起始点,和end相当于一条线左右两边的端点
            # print(start, end, mean_x, mean_y)
            flag = True #标位  
            for i in range(mean_x, end):
                x_ = i
                y_ = mean_y
                pixels = mask_1[y_, x_]  ##  获取整个mask_1每个点的像素值
                # print(pixels)
                if pixels != 0 and pixels != 220:  # 0 对应背景 220对应边界线,
                # 因为每个mask中目标趋于内部的像素具有一致性,如pixels = 38,
                    # print(pixels) #  那么此部分所有操作就会将所有值为38的点标记
                    mask = (mask == pixels).astype(np.uint8)  #因此在这通过判断就可以在创建号的mask中 将所有属于目标
                    #  的像素全部标记
                    # for i in mask:
                    #     print(i)
                    flag = False   
                    break
            if flag:    #  如果向end方向不存在,则从中心向start递减
                for i in range(mean_x, start+1, -1):
                    x_ = i
                    y_ = mean_y
                    pixels = mask_1[y_, x_]
                    if pixels != 0 and pixels != 220:
                        mask = (mask == pixels).astype(np.uint8)
                        # print(y_, x_)
                        break
            self.mask = mask  #标记好的mask赋予 self.mask
            return self.mask2polygons() # 通过self.mask 以及self.mask2polygons()函数提取所有边界点

        except:  # 判断前面的try是否正确执行,里面的内容是否出错
            print('no')
            return [0]

    def mask2polygons(self):
        '''从mask提取边界点'''
        contours = cv2.findContours(self.mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)  # 找到轮廓线
        # print(contours[0])  
        #注意 cv2.findContours只能返回两个值,一个是找到所有的边界点值,另一个是不同目标之间的关系程度
        # 这与网上很多版本的写法不一样,大多数都返回了3个值,是错误的
        #我们需要获取 contours[0]的信息也就是所有的目标边界点的x,y
        bbox=[]
        for cont in contours[0]:
            # print(cont.flatten())  # 每个cont 都是一个包含一系列xy信息的多个[[y x]]如[[[234 106]] [[232 108]]]
            [bbox.append(i) for i in list(cont.flatten())]  #  这一部就是先把一个cont拉值为[224,106,232,108]
            # 然后取出每一个拉直后的cont中的所有元素,重新添加到bbox列表中
            # map(bbox.append,list(cont.flatten()))
        # print(len(bbox))
        return bbox  # list(contours[1][0].flatten())

部分3

这一部分比较简单,具体讲解就不多了,就写了部分注释

 def change_format(contour):   # 用于对获取到的所有边界点作处理,用于计算整个目标区域这个多边形的面积
        contour2 = []
        length = len(contour)
        # print(length)
        for i in range(0, length, 2):
            contour2.append([contour[i], contour[i + 1]])  #重新构建列表,里面每个元素代表一个点的x,y
        return np.asarray(contour2, np.int32)

    def annotation(self):
        annotation = {}
        # annotation['segmentation'] = [self.getsegmentation()]
        annotation['segmentation'] = [list(map(float, self.getsegmentation()))]
        #  [list(map(float, self.getsegmentation()))] 以列表的形式存储我们前面得到的所有分割出来的点的位置信息
        # print(annotation['segmentation'])
        annotation['iscrowd'] = 0
        annotation['image_id'] = self.num + 1
        contour = PascalVOC2coco.change_format(annotation['segmentation'][0])  # 获取轮廓的点
        # print(annotation['segmentation'][0])
        annotation['area'] = abs(float(cv2.contourArea(contour, True)))  # 计算轮廓面积
        # annotation['bbox'] = list(map(float, self.bbox))
        annotation['bbox'] = self.bbox
        annotation['category_id'] = self.getcatid(self.supercategory)
        annotation['id'] = self.annID
        # print(len(annotation['segmentation'][0]))
        # print(annotation['area'])
        return annotation

通过下述代码可以检测以下我们生成的json文件是否分割出了正确的mask信息(检测后还是有些不准,转换后提取分割的结果,有机会在细改吧)

提示:关于这则代码的使用,请把他放在cocoapi中的pycocotools文件夹里面
cocoapi可从github自行下载

# -*- coding:utf-8 -*-
from __future__ import print_function
from pycocotools.coco import COCO
import os, sys, zipfile
import urllib.request
import shutil
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import pylab

pylab.rcParams['figure.figsize'] = (8.0, 10.0)
annFile='data/coco/annotations/newtrain.json'
coco=COCO(annFile)

# display COCO categories and supercategories
cats = coco.loadCats(coco.getCatIds())
nms=[cat['name'] for cat in cats]
print('COCO categories: \n{}\n'.format(' '.join(nms)))

# nms = set([cat['supercategory'] for cat in cats])
# print('COCO supercategories: \n{}'.format(' '.join(nms)))

# imgIds = coco.getImgIds(imgIds = [324158])
imgIds = coco.getImgIds()
imgId=np.random.randint(0, len(imgIds))
img = coco.loadImgs(imgIds[imgId])[0]
dataDir = '.'
dataType = 'data/coco/train2017'
I = io.imread('%s/%s/%s'%(dataDir, dataType, img['file_name']))
# I = io.imread('%s/%s'%(dataDir,img['file_name']))

plt.axis('off')
plt.imshow(I)
plt.show()


# load and display instance annotations
# 加载实例掩膜
# catIds = coco.getCatIds(catNms=['person','dog','skateboard']);
# catIds=coco.getCatIds()
catIds=[]
for ann in coco.dataset['annotations']:
    if ann['image_id']==imgIds[imgId]:
        catIds.append(ann['category_id'])

plt.imshow(I); plt.axis('off')
annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
anns = coco.loadAnns(annIds)
coco.showAnns(anns)
plt.show()

对应位置进行修改,即可。
成功后的图可见,比较可见,其实提取的效果在这里插入图片描述在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值