比如我要制作一个卡车分割数据集,COCO 中正好有 truck 这个类别标签,可以直接使用。首先需要安装 cocoapi(pycocotools)。我的思路是:将整张原图和整张 mask 图分别保存,同时把包围 mask 的最小外接框(bbox)也保存下来,以备日后裁剪时使用。
代码:
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import pylab
import cv2
from shapely.geometry import Point
from matplotlib.patches import Polygon
import os

pylab.rcParams['figure.figsize'] = (8.0, 10.0)
# NOTE(review): this is a relative path (no leading '/'); confirm it resolves
# to the COCO root from the directory the script is launched in.
dataDir = 'home/public/coco/2014/coco'
dataType = 'val2014'
annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataType)
coco = COCO(annFile)

target = 'truck'
im_seq = 0  # running index; one exported sample (mask, original, box) per annotation


def _new_borderless_figure(image, width, height):
    """Show *image* on a fresh figure with no axes, ticks or margins.

    The figure is sized to (width / 100) x (height / 100) inches, so the
    saved file matches the image's pixel size assuming matplotlib's default
    figure dpi of 100. The caller is responsible for saving and closing the
    returned figure.
    """
    fig = plt.figure()
    plt.imshow(image)
    plt.axis('off')
    fig.set_size_inches(width / 100.0, height / 100.0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)
    return fig


# savefig()/open() do not create missing directories.
for sub_dir in ('msk', 'ori', 'box'):
    os.makedirs(os.path.join(target, sub_dir), exist_ok=True)

for cat_id in coco.getCatIds(catNms=[target]):
    for img_id in coco.getImgIds(catIds=cat_id):
        img = coco.loadImgs(img_id)[0]
        I = io.imread('%s/images/%s/%s' % (dataDir, dataType, img['file_name']))
        if I.ndim == 2:
            # Single-channel images would be rendered through matplotlib's
            # colormap; replicate the channel so both the original and the
            # white canvas below are drawn as plain RGB.
            I = np.dstack((I, I, I))
        heights, width = I.shape[0], I.shape[1]
        # White canvas of the same size; the mask is painted onto it.
        temp = np.full(I.shape, 255, np.uint8)
        annIds = coco.getAnnIds(imgIds=img['id'], catIds=cat_id, iscrowd=False)
        for each_ann_id in annIds:
            anns = coco.loadAnns(each_ann_id)
            if not anns:
                continue
            im_seq += 1
            print('***************----{:09d}--->{}****************'.format(im_seq, target))

            # Mask image. This relies on the modified COCO.showAnns that
            # returns the annotation's bbox as (x, y, w, h).
            fig = _new_borderless_figure(temp, width, heights)
            x, y, w, h = coco.showAnns(anns)
            plt.savefig(target + '/msk/{:09d}.jpg'.format(im_seq))
            plt.close(fig)

            # Original image, saved under the same index as its mask.
            fig1 = _new_borderless_figure(I, width, heights)
            plt.savefig(target + '/ori/{:09d}.jpg'.format(im_seq))
            plt.close(fig1)

            # Bounding box as a single "x,y,w,h" line; the context manager
            # guarantees the file is flushed and closed.
            with open(target + '/box/{:09d}.txt'.format(im_seq), 'w') as ff_file:
                ff_file.write('{},{},{},{}\n'.format(x, y, w, h))
另外,需要修改 pycocotools 的 coco.py 中的 showAnns 函数,使其用纯黑色绘制 mask,并返回标注的 bbox(x, y, w, h):
def showAnns(self, anns):
    """Draw the given annotations on the current matplotlib axes.

    Polygon segmentations and keypoints are drawn in solid black; caption
    annotations are printed instead of drawn.

    :param anns (list of dict): annotations to display
    :return: 0 when ``anns`` is empty; otherwise the ``'bbox'`` entry of the
        last annotation in ``anns``, or None when that annotation has no
        ``'bbox'`` key (e.g. captions)
    :raises Exception: if the first annotation has none of the keys
        'segmentation', 'keypoints' or 'caption'
    """
    if len(anns) == 0:
        return 0
    if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
        datasetType = 'instances'
    elif 'caption' in anns[0]:
        datasetType = 'captions'
    else:
        raise Exception('datasetType not supported')
    if datasetType == 'instances':
        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            c = (0, 0, 0)  # fixed black so the rendering can be used as a mask
            if 'segmentation' in ann:
                if isinstance(ann['segmentation'], list):
                    # Polygon format: each entry is a flat [x0, y0, x1, y1, ...] list.
                    for seg in ann['segmentation']:
                        poly = np.array(seg).reshape((len(seg) // 2, 2))
                        polygons.append(Polygon(poly))
                        color.append(c)
                else:
                    # RLE format, either uncompressed (list counts) or compressed.
                    t = self.imgs[ann['image_id']]
                    if isinstance(ann['segmentation']['counts'], list):
                        rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
                    else:
                        rle = [ann['segmentation']]
                    m = maskUtils.decode(rle)
                    img = np.ones((m.shape[0], m.shape[1], 3))
                    if ann['iscrowd'] == 1:
                        color_mask = np.array([2.0, 166.0, 101.0]) / 255
                    else:
                        # Random RGB triple in [0, 1); randint((1, 3)) produced an
                        # int here and broke the indexing below.
                        color_mask = np.random.random((1, 3)).tolist()[0]
                    for i in range(3):
                        img[:, :, i] = color_mask[i]
                    # The decoded mask is used as a half-transparent alpha channel.
                    ax.imshow(np.dstack((img, m * 0.5)))
            if 'keypoints' in ann and isinstance(ann['keypoints'], list):
                # Skeleton indices are 1-based in the category definition.
                sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton']) - 1
                kp = np.array(ann['keypoints'])
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                plt.plot(x[v > 0], y[v > 0], 'o', markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
                plt.plot(x[v > 1], y[v > 1], 'o', markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
        p = PatchCollection(polygons, facecolor=color, linewidth=0, alpha=0.4)
        ax.add_collection(p)
        # Second pass fills and outlines the polygons fully opaque.
        p = PatchCollection(polygons, facecolor=color, edgecolors=color, linewidth=2)
        ax.add_collection(p)
    elif datasetType == 'captions':
        for ann in anns:
            print(ann['caption'])
    # Bounding box of the last annotation; .get avoids a KeyError for
    # annotation types that carry no 'bbox'.
    return anns[-1].get('bbox')
这样一来,以后如果需要裁剪,只需读取对应的 .txt 文件(内容为一行 x,y,w,h)和相应的图像,即可直接完成裁剪。代码如下:
import cv2
import os
target = 'truck'
root = '/home/46322/SomeDemos/coco/' + target  # root directory holding the exported images and box text files


def read_box(box_path):
    """Return (x, y, w, h) as ints parsed from the first line of a box file.

    The line is expected to look like ``x,y,w,h``; each value is parsed as a
    float and truncated to int. Raises ValueError if the line is malformed.
    """
    with open(box_path) as box_file:
        first_line = box_file.readline()
    x, y, w, h = (int(float(value)) for value in first_line.strip().split(','))
    return x, y, w, h


def crop_to_box(image, box):
    """Return the sub-array of *image* covered by box = (x, y, w, h).

    Rows are indexed by y and columns by x.
    """
    x, y, w, h = box
    return image[y:y + h, x:x + w]


msk_dir = os.path.join(root, 'msk')
if os.path.isdir(msk_dir):
    # cv2.imwrite does not create missing directories.
    for out_dir in ('msk_crop', 'ori_crop'):
        os.makedirs(os.path.join(root, out_dir), exist_ok=True)

    # Only the mask directory is listed; the original image and the box file
    # are located through the shared file name.
    for each_msk in sorted(os.listdir(msk_dir)):
        msk_path = os.path.join(msk_dir, each_msk)
        box_path = os.path.join(root, 'box', os.path.splitext(each_msk)[0] + '.txt')
        ori_path = os.path.join(root, 'ori', each_msk)

        msk_im = cv2.imread(msk_path)
        ori_im = cv2.imread(ori_path)
        if msk_im is None or ori_im is None:
            # cv2.imread returns None instead of raising on unreadable files.
            print('skip {}: mask or original image could not be read'.format(each_msk))
            continue

        box = read_box(box_path)
        cv2.imwrite(os.path.join(root, 'msk_crop', each_msk), crop_to_box(msk_im, box))
        cv2.imwrite(os.path.join(root, 'ori_crop', each_msk), crop_to_box(ori_im, box))