利用COCO数据集制作自己的分割数据集

比如我要制作卡车的分割数据集,COCO 里正好有 truck 这个类别标签,可以直接使用。首先需要安装 cocoapi(pycocotools)。我的思路是:对每个标注实例,把整张原图和整张 mask 图分别保存(ori/ 与 msk/ 目录),同时把包围该 mask 的最小 bbox(x,y,w,h)保存到 box/ 目录下的 .txt 文件中,以备之后裁剪使用。

代码:



import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pylab
import skimage.io as io
from matplotlib.patches import Polygon
from pycocotools.coco import COCO
from shapely.geometry import Point
# Export every non-crowd annotation of one COCO category as three files that
# share a running sequence number:
#   <target>/msk/<seq>.jpg  - white canvas with the annotation drawn on it
#   <target>/ori/<seq>.jpg  - the full original image
#   <target>/box/<seq>.txt  - the annotation bbox as "x,y,w,h"
# Relies on the modified COCO.showAnns that returns the annotation's bbox.
pylab.rcParams['figure.figsize'] = (8.0, 10.0)
# NOTE(review): this path has no leading '/', so it is resolved relative to
# the current working directory - confirm that is intended.
dataDir = 'home/public/coco/2014/coco'
dataType = 'val2014'
annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataType)
coco = COCO(annFile)
cats = coco.loadCats(coco.getCatIds())
nms = set([cat['supercategory'] for cat in cats])

aaa = []
target = 'truck'
im_seq = 0

# savefig()/open() do not create missing directories, so make them up front.
for _sub_dir in ('msk', 'ori', 'box'):
    os.makedirs(os.path.join(target, _sub_dir), exist_ok=True)


def _borderless_figure(image, width, height):
    """Show *image* on a new axis-free, margin-free figure and return it.

    The figure is sized to width/100 x height/100 inches so that saving it
    at 100 dpi yields one output pixel per image pixel.
    """
    figure = plt.figure()
    plt.imshow(image)
    plt.axis('off')
    figure.set_size_inches(width / 100.0, height / 100.0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)
    return figure


ss_Ids = coco.getCatIds(catNms=[target])
for cat_id in ss_Ids:
    for img_id in coco.getImgIds(catIds=cat_id):
        img = coco.loadImgs(img_id)[0]
        I = io.imread('%s/images/%s/%s' % (dataDir, dataType, img['file_name']))
        annIds = coco.getAnnIds(imgIds=img['id'], catIds=cat_id, iscrowd=False)
        heights, width = I.shape[0], I.shape[1]
        # White canvas with the same shape as the image; the mask is drawn on it.
        temp = np.ones(I.shape, np.uint8) * 255
        for each_ann_id in annIds:
            anns = coco.loadAnns(each_ann_id)
            if not anns:
                continue
            im_seq = im_seq + 1
            print('***************----{:09d}--->{}****************'.format(im_seq, target))

            # Mask image: showAnns draws onto the current axes (the white canvas).
            fig = _borderless_figure(temp, width, heights)
            x, y, w, h = coco.showAnns(anns)
            plt.savefig(target + '/msk/{:09d}.jpg'.format(im_seq))
            plt.close(fig)

            # Original image, saved under the same sequence number.
            fig1 = _borderless_figure(I, width, heights)
            plt.savefig(target + '/ori/{:09d}.jpg'.format(im_seq))
            plt.close(fig1)

            # The context manager closes the file even if the write fails.
            with open(target + '/box/{:09d}.txt'.format(im_seq), 'w') as ff_file:
                ff_file.write(','.join(str(v) for v in (x, y, w, h)) + '\n')

同时需要修改 pycocotools 中 coco.py 的 showAnns 函数:用黑色 (0,0,0) 绘制多边形掩码,并返回标注的 bbox(供上面的脚本写入 .txt 文件):

def showAnns(self, anns):
    """Draw the given annotations on the current axes and return a bbox.

    Variant of pycocotools' ``COCO.showAnns`` meant to replace that method in
    ``coco.py``; it uses the names ``plt``, ``np``, ``Polygon``,
    ``PatchCollection`` and ``maskUtils`` from the enclosing module.
    Polygon segmentations and keypoints are drawn in black.

    :param anns: list of annotation dicts (instance or caption annotations)
    :return: 0 when ``anns`` is empty; otherwise the ``'bbox'`` entry of the
        last annotation, or None if that annotation has no ``'bbox'`` key
        (e.g. caption annotations)
    :raises Exception: if the first annotation has none of the keys
        'segmentation', 'keypoints' or 'caption'
    """
    if len(anns) == 0:
        return 0
    if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
        datasetType = 'instances'
    elif 'caption' in anns[0]:
        datasetType = 'captions'
    else:
        raise Exception('datasetType not supported')
    if datasetType == 'instances':
        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            # Fixed black colour for polygons and keypoints.
            c = (0, 0, 0)
            if 'segmentation' in ann:
                if isinstance(ann['segmentation'], list):
                    # Polygon format: each entry is a flat [x0, y0, x1, y1, ...] list.
                    for seg in ann['segmentation']:
                        poly = np.array(seg).reshape((int(len(seg) / 2), 2))
                        polygons.append(Polygon(poly))
                        color.append(c)
                else:
                    # Mask (RLE) format.
                    t = self.imgs[ann['image_id']]
                    if isinstance(ann['segmentation']['counts'], list):
                        rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
                    else:
                        rle = [ann['segmentation']]
                    m = maskUtils.decode(rle)
                    img = np.ones((m.shape[0], m.shape[1], 3))
                    if ann['iscrowd'] == 1:
                        color_mask = np.array([2.0, 166.0, 101.0]) / 255
                    if ann['iscrowd'] == 0:
                        # Must be a 3-element RGB list because it is indexed
                        # per channel below; same expression as upstream.
                        color_mask = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
                    for i in range(3):
                        img[:, :, i] = color_mask[i]
                    ax.imshow(np.dstack((img, m * 0.5)))
            if 'keypoints' in ann and isinstance(ann['keypoints'], list):
                # Skeleton indices are stored 1-based; shift to 0-based.
                sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton']) - 1
                kp = np.array(ann['keypoints'])
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                plt.plot(x[v > 0], y[v > 0], 'o', markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
                plt.plot(x[v > 1], y[v > 1], 'o', markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
        # Translucent fill first, then an opaque fill with a 2pt outline on top.
        p = PatchCollection(polygons, facecolor=color, linewidth=0, alpha=0.4)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor=color, edgecolors=color, linewidth=2)
        ax.add_collection(p)
    elif datasetType == 'captions':
        for ann in anns:
            print(ann['caption'])
    # Caption annotations carry no 'bbox'; .get() avoids a KeyError for them.
    return anns[-1].get('bbox')

这样以后,如果需要剪裁,可以直接通过读取.txt文件和相应的图像进行剪裁。如下:

import cv2
import os
# Crop every mask image and its matching original image to the bounding box
# stored in the corresponding box/<name>.txt file ("x,y,w,h" on the first line).
target = 'truck'
root = '/home/46322/SomeDemos/coco/' + target  # root directory holding the images and box text files
msk_dir = os.path.join(root, 'msk')
ori_dir = os.path.join(root, 'ori')
box_dir = os.path.join(root, 'box')
msk_crop_dir = os.path.join(root, 'msk_crop')
ori_crop_dir = os.path.join(root, 'ori_crop')

# cv2.imwrite returns False instead of raising when the directory is missing,
# so create the output directories explicitly.
for out_dir in (msk_crop_dir, ori_crop_dir):
    os.makedirs(out_dir, exist_ok=True)

# Walk only the mask directory; walking `root` would also visit box/, ori/
# and the crop output directories.
for dir_path, _dir_names, file_names in os.walk(msk_dir):
    for each_msk in file_names:
        msk_path = os.path.join(dir_path, each_msk)
        box_path = os.path.join(box_dir, os.path.splitext(each_msk)[0] + '.txt')
        ori_path = os.path.join(ori_dir, each_msk)

        with open(box_path) as box_file:
            first_line = box_file.readline().strip()
        # Values may be written as floats; truncate each to an integer pixel index.
        x, y, w, h = (int(float(value)) for value in first_line.split(','))

        msk_im = cv2.imread(msk_path)
        ori_im = cv2.imread(ori_path)
        if msk_im is None or ori_im is None:
            # cv2.imread returns None when a file is missing or unreadable.
            print('skip {}: cannot read mask or original image'.format(each_msk))
            continue

        # Rows span y..y+h and columns span x..x+w.
        msk_cp = msk_im[y:y + h, x:x + w]
        ori_cp = ori_im[y:y + h, x:x + w]

        cv2.imwrite(os.path.join(msk_crop_dir, each_msk), msk_cp)
        cv2.imwrite(os.path.join(ori_crop_dir, each_msk), ori_cp)

 

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值