数据集格式转换代码1:coco2VOC

数据集格式转换代码1:coco2VOC

from pycocotools.coco import COCO
import os
import shutil
from tqdm import tqdm
import skimage.io as io
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw


#20220817 增加 category_id 目标类型字段。
 #提取特定的类到指定目录,包括图片文件和生成VOC标签文件xml
#1. 特定的类:修改 classes_names
#2.修改输出 目录 objfiles_name
#3.修改需要提取的数据量: 
    #  trainNumMax = 120000#训练集 取50000张
    #  valNumMax = 80000#验证集 取20000张
#coco有80类,这里写要提取类的名字,以person为例 
classes_names = ['person','bicycle','cat', 'dog'] 
# classes_names = ['cat', 'dog'] 

# 需要设置的路径
# objfiles_name = "catdog/"
objfiles_name = "class-mini/"

savepath="/media/tu/5DFD7B197A9B3CAB/dataset/coco2yolo/"+ objfiles_name 
# savepath="./dataset/coco2yolo/+catdog/" 
img_dir=savepath+'images/'
anno_dir=savepath+'annotations_VOC/'
datasets_list=['train2017', 'val2017']


# trainNumMax = 120000#训练集 取50000张
# valNumMax = 80000#验证集 取20000张

trainNumMax = 20000#训练集      取 x 张
valNumMax = 5000#验证集          取 x 张

#包含所有类别的原coco数据集路径
'''
目录格式如下:
$COCO_PATH
----|annotations
----|train2017
----|val2017
----|test2017
'''
dataDir= '/media/tu/5DFD7B197A9B3CAB/dataset/coco2017/'
# dataDir= '/path/to/coco_orgi/' 
 
headstr = """\
<annotation>
    <folder>VOC</folder>
    <filename>%s</filename>
    <source>
        <database>My Database</database>
        <annotation>COCO</annotation>
        <image>flickr</image>
        <flickrid>NULL</flickrid>
    </source>
    <owner>
        <flickrid>NULL</flickrid>
        <cp_name>company</cp_name>
    </owner>
    <size>
        <width>%d</width>
        <height>%d</height>
        <depth>%d</depth>
    </size>
    <segmented>0</segmented>
"""
#20220817增加字段: <category_id>%d</category_id> 标识目标类型
objstr = """\
    <object>
        <name>%s</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <category_id>%d</category_id>
        <bndbox>
            <xmin>%d</xmin>
            <ymin>%d</ymin>
            <xmax>%d</xmax>
            <ymax>%d</ymax>
        </bndbox>
    </object>
"""
 
tailstr = '''\
</annotation>
'''
 
# 检查目录是否存在,如果存在,先删除再创建,否则,直接创建
def mkr(path):
    if not os.path.exists(path):
        os.makedirs(path)  # 可以创建多级目录

def id2name(coco):
    classes=dict()
    for cls in coco.dataset['categories']:
        classes[cls['id']]=cls['name']
    return classes
 
def write_xml(anno_path,head, objs, tail):
    f = open(anno_path, "w")
    f.write(head)
    for obj in objs:
        f.write(objstr%(obj[0],obj[1],obj[2],obj[3],obj[4],obj[5]))
    f.write(tail)
 
 
def save_annotations_and_imgs(coco,dataset,filename,objs):
    #将图片转为xml,例:COCO_train2017_000000196610.jpg-->COCO_train2017_000000196610.xml
    dst_anno_dir = os.path.join(anno_dir, dataset)
    mkr(dst_anno_dir)
    anno_path=dst_anno_dir + '/' + filename[:-3]+'xml'
    img_path=dataDir+dataset+'/'+filename
    # print("img_path: ", img_path)
    dst_img_dir = os.path.join(img_dir, dataset)
    mkr(dst_img_dir)
    dst_imgpath=dst_img_dir+ '/' + filename
    # print("dst_imgpath: ", dst_imgpaobjsth)
    img=cv2.imread(img_path)
    #if (img.shape[2] == 1):
    #    print(filename + " not a RGB image")
     #   return
    if os.path.exists(dst_imgpath) ==False:
        shutil.copy(img_path, dst_imgpath) 
        head=headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
        tail = tailstr
        write_xml(anno_path,head, objs, tail)
 
 
def showimg(clsNum,coco,dataset,img,classes,cls_id,show=True):
    global dataDir
    dri = '%s%s/%s'%(dataDir,dataset,img['file_name'])
    I=Image.open(dri)
    #通过id,得到注释的信息
    annIds = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
    # print(annIds)
    anns = coco.loadAnns(annIds)
    # print(anns)
    # coco.showAnns(anns)
    objs = []
    for ann in anns:
        class_name=classes[ann['category_id']]
        if class_name in classes_names:
            # print(f"class_name:{class_name}")
            # category_id = int(ann['category_id'])
            category_id = int(clsNum)
            # print(f"category_id:{category_id}")
            if 'bbox' in ann:                
                bbox=ann['bbox']
                xmin = int(bbox[0])
                ymin = int(bbox[1])
                xmax = int(bbox[2] + bbox[0])
                ymax = int(bbox[3] + bbox[1])
                obj = [class_name,category_id, xmin, ymin, xmax, ymax]
                objs.append(obj)
    #             draw = ImageDraw.Draw(I)
    #             draw.rectangle([xmin, ymin, xmax, ymax])
    # if show:
    #     plt.figure()
    #     plt.axis('off')
    #     plt.imshow(I)
    #     plt.show()
 
    return objs

def main():

    for dataset in datasets_list:
        # if 'train2017'==dataset:#训练集 不取
        #     continue   
        # if 'val2017'==dataset:#验证集 不取
        #     continue   
        #./COCO/annotations/instances_train2017.json
        annFile='{}/annotations/instances_{}.json'.format(dataDir,dataset)
        print(annFile)
        print(dataset)
        #使用COCO API用来初始化注释数据
        coco = COCO(annFile)
    
        #获取COCO数据集中的所有类别
        classes = id2name(coco)
        print(classes)
        # return
        #[1, 2, 3, 4, 6, 8]
        classes_ids = coco.getCatIds(catNms=classes_names)
        print(classes_ids)
        clsNum = 0
        for cls in classes_names: #提取 每一个类型
            clsNum +=1
            #获取该类的id
            cls_id=coco.getCatIds(catNms=[cls])
            img_ids=coco.getImgIds(catIds=cls_id)
            print(cls,' img_ids: ',len(img_ids))
            print(' clsNum: ',clsNum)
            # imgIds=img_ids[0:10]
            numCnt = 0
            for imgId in tqdm(img_ids):
                img = coco.loadImgs(imgId)[0]
                filename = img['file_name']
                # print(filename)
                objs=showimg(clsNum,coco, dataset, img, classes,classes_ids,show=False)
                # print(objs) # [['cat', 3, 117, 86, 368, 239]]
                save_annotations_and_imgs(coco, dataset, filename, objs)
                numCnt+=1
                if 'train2017'==dataset:#训练集取 x 张
                    if(numCnt>trainNumMax):
                        print('train2017, out put sum:',numCnt)
                        break
                if 'val2017'==dataset:#验证集 取 x 张
                    if(numCnt>valNumMax):
                        print('val2017, out put sum:',numCnt)
                        break

if __name__ == '__main__':
     main()
     
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值