截取coco数据集方法

coco数据集十分庞大,最近打算从里面截取一个只包含person、bicycle、bus三个类别,图片总数约1000张的小型数据集,在此记录截取过程。

step1、将json瘦身,只包含这三个类别,segmentation不需要可以去掉,为了后续可以对应到图片,增加file_name字段存储图片路径。

import json
import hashlib
from tqdm import tqdm
import time
import os

def filter_annotations(input_path, output_path_1, output_path_2, category_ids, num_images_1, num_images_2):
    st = time.time()
    with open(input_path, 'r') as f:
        data = json.load(f)
    print('read cost: {}'.format(time.time() - st))
    st = time.time()
    
    id2name = {}
    for img in tqdm(data['images']):
        id2name[img['id']] = img['file_name']

    categories_sample = [
    {
      "id": 1,
      "name": "person"
    },
    {
      "id": 2,
      "name": "bicycle"
    }
    ,
    {
      "id": 6,
      "name": "bus"
    }
  ]

    newanns1 = {'annotations': [],   "categories": categories_sample}
    cate_counts = {1:0, 2:0, 6:0}

    
    imgpath = '/data/det_coco2017/train2017/'
    annnumber = 0
    for ann in tqdm(data['annotations']):
        if ann['category_id'] in category_ids:
            cate_counts[ann['category_id']] += 1
            annnumber += 1
            ann.pop('segmentation', None)
            ann['file_name'] = imgpath + id2name[ann['image_id']]
            ann['md5'] = get_md5(imgpath, ann['file_name'])
            newanns1['annotations'].append(ann) 
        
    print(cate_counts)
    
    with open(output_path_1, 'w') as f:
        json.dump(newanns1, f)

def get_md5(folder_path, image_file):
    with open(os.path.join(folder_path, image_file), 'rb') as f:
        image_data = f.read()
        md5 = hashlib.md5(image_data).hexdigest()
    return md5

if __name__ == '__main__':
    input_path = '/data/det_coco2017/annotations/instances_train2017.json'
    output_path_1 = 'train.json'
    output_path_2 = 'test.json'
    category_ids = [1, 2, 6]
    num_images_1 = 1000
    num_images_2 = 200

    filter_annotations(input_path, output_path_1, output_path_2, category_ids, num_images_1, num_images_2)

step2、通过以下指令查看各类数目

cat train2.json | jq '[.annotations[] | select(.category_id == 6)] | length'

step3、取固定图片数目,重新生成json

import json
import hashlib
from tqdm import tqdm
import time
import os

def filter_annotations(input_path, output_path_1, output_path_2, category_ids, num_images_1, num_images_2):
    st = time.time()
    with open(input_path, 'r') as f:
        data = json.load(f)
    print('read cost: {}'.format(time.time() - st))
    st = time.time()
    
    img_ids = []
    for ann in tqdm(data['annotations']):
      if ann['category_id'] in [2,6]:
        img_ids.append(ann['image_id'])

    unique_list = list(set(img_ids))
    unique_list = unique_list[:200]

    newann = []
    for ann in tqdm(data['annotations']):
      if ann['image_id'] in unique_list:
        newann.append(ann)

    data['annotations'] = newann
        
    print(unique_list, ' imgs')
    
    with open(output_path_1, 'w') as f:
        json.dump(data, f)

def get_md5(folder_path, image_file):
    with open(os.path.join(folder_path, image_file), 'rb') as f:
        image_data = f.read()
        md5 = hashlib.md5(image_data).hexdigest()
    return md5

if __name__ == '__main__':
    input_path = '/data/det_coco2017/train.json'
    output_path_1 = 'test.json'
    output_path_2 = 'test.json'
    category_ids = [1, 2, 6]
    num_images_1 = 1000
    num_images_2 = 200

    filter_annotations(input_path, output_path_1, output_path_2, category_ids, num_images_1, num_images_2)

step4、将图片单独放到文件夹

cat test.json | jq -r '.annotations[].file_name' | xargs -I {} cp {} test/

step5、打包

zip -r test.zip test/ test.json 
  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值