Reading Keypoint Coordinates from the COCO Dataset

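COCO is a large-scale computer vision dataset that provides the data needed for several tasks, including object detection, keypoint estimation, semantic segmentation, and image captioning. This post shows how to read human keypoint coordinates with the API that COCO provides. For the evaluation metrics used for COCO keypoints, you can refer to this blog post.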

Installing the COCO API

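For the installation steps, you can follow this blog post:
https://www.jianshu.com/p/de455d653301
If you are on Linux, you should run into few problems: just run make for the COCO API. If you are on Windows, as I am, things go wrong very easily, but the blog post above resolves most of those issues.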
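
After installing, a quick check can confirm that Python can find the package before you try to load any annotations. The sketch below is only a suggestion and is not part of the original script; the helper name `has_module` is hypothetical, and it only checks that the `pycocotools` package can be located, not that its compiled extension was built correctly.

```python
import importlib.util


def has_module(name):
    """Return True if a top-level module called `name` can be found."""
    # find_spec returns None when the module is not installed
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    if has_module("pycocotools"):
        print("pycocotools found")
    else:
        print("pycocotools is missing; revisit the installation steps")
```

If the check reports the package as missing, the build step most likely did not complete or was installed into a different Python environment.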

Extracting the Keypoint Annotations for People in the COCO Dataset

Without further ado, here is the Python code:

# @author: zhangboshen
# @Email: zhangbs@whu.edu.cn 
# 
# Extract COCO keypoints and save them to a CSV file. Date: 2018.3.22

# only the imports this script actually uses
from pycocotools.coco import COCO
import csv

# initialize COCO api for person keypoints annotations
dataDir = '..'
dataType = 'train2017'
annFile = '{}/annotations/person_keypoints_{}.json'.format(dataDir, dataType)
coco_kps = COCO(annFile)

# display COCO categories and supercategories
cats = coco_kps.loadCats(coco_kps.getCatIds())
nms=[cat['name'] for cat in cats]
print('COCO categories: \n{}\n'.format(' '.join(nms)))

nms = set([cat['supercategory'] for cat in cats])
print('COCO supercategories: \n{}'.format(' '.join(nms)))

# get the ids of all images that contain the 'person' category
catIds = coco_kps.getCatIds(catNms=['person'])
imgIds = coco_kps.getImgIds(catIds=catIds)
print('there are %d images containing human' % len(imgIds))

def getBndboxKeypointsGT():
    # on Python 3 the csv module needs a text-mode file opened with newline=''
    csvFile = open('....../KeypointBndboxGT.csv', 'w', newline='')
    keypointsWriter = csv.writer(csvFile)
    firstRow = ['imageName','personNumber','bndbox','nose',
            'left_eye','right_eye','left_ear','right_ear','left_shoulder','right_shoulder',
            'left_elbow','right_elbow','left_wrist','right_wrist','left_hip','right_hip',
            'left_knee','right_knee','left_ankle','right_ankle']
    keypointsWriter.writerow(firstRow)
    for i in range(len(imgIds)):
        # load the image record once and reuse it; keep the file name as a str
        img = coco_kps.loadImgs(imgIds[i])[0]
        imageName = img['file_name']
        annIds = coco_kps.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
        anns = coco_kps.loadAnns(annIds)
        personNumber = len(anns)
        for ann in anns:
            bndbox = ann['bbox']          # [x, y, width, height]
            keyPoints = ann['keypoints']  # flat list: 17 * (x, y, v)
            keypointsRow = [imageName, str(personNumber),
                            '_'.join(str(v) for v in bndbox)]
            # one 'x_y_v' cell per keypoint, in the same order as firstRow
            for k in range(0, 17 * 3, 3):
                keypointsRow.append('_'.join(str(v) for v in keyPoints[k:k + 3]))
            keypointsWriter.writerow(keypointsRow)

    csvFile.close()

if __name__ == "__main__":
    print('Writing bndbox and keypoints to csv files...')
    getBndboxKeypointsGT()
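
The resulting CSV file has one row per annotated person and contains: the image name; the number of people annotated in that image; the corresponding bounding box; and the 17 keypoints, each written as x_y_v (the 2D coordinates plus the COCO visibility flag).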
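
To use the file later, each cell has to be split back into numbers. The following is a minimal sketch, assuming the column layout written by the script above (image name, person count, bounding box, then 17 keypoint cells joined with underscores). The helper names `parse_cell` and `parse_row` and the sample row are made up for illustration and are not real COCO data. In COCO, the third value of each keypoint is a visibility flag: 0 means not labeled, 1 means labeled but not visible, and 2 means labeled and visible.

```python
import csv
import io

# Same order as the header row written by the extraction script
KEYPOINT_NAMES = [
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle',
]


def parse_cell(cell):
    """Split an underscore-joined cell such as '15_25_2' into floats."""
    return [float(v) for v in cell.split('_')]


def parse_row(row):
    """Turn one CSV data row into (image name, person count, bbox, keypoints)."""
    image_name = row[0]
    person_number = int(row[1])
    bndbox = parse_cell(row[2])
    keypoints = {name: tuple(parse_cell(cell))
                 for name, cell in zip(KEYPOINT_NAMES, row[3:])}
    return image_name, person_number, bndbox, keypoints


if __name__ == "__main__":
    # Made-up sample row (hypothetical values, not taken from the dataset)
    sample = ['example.jpg', '1', '10.0_20.0_30.0_40.0'] + ['0_0_0'] * 17
    sample[3] = '15_25_2'  # pretend the nose is labeled and visible

    # Round-trip through the csv module using an in-memory buffer
    buffer = io.StringIO()
    csv.writer(buffer).writerow(sample)
    buffer.seek(0)

    for row in csv.reader(buffer):
        name, count, bbox, kps = parse_row(row)
        visible = [k for k, (x, y, v) in kps.items() if v == 2]
        print(name, count, bbox, visible)
```

When reading the real file, skip the header row first (for example with `next(reader)`) before passing rows to `parse_row`.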