3.keypoint关键点检测-COCO格式转为YOLO格式
网上没找到有关 关键点检测-COCO格式转为YOLO格式的相关代码,就自己参考其他人基于目标检测的进行了改写,话不多说,直接上代码
如果对yolo和mscoco格式还不了解的可以参考我的这篇博文【2.labelme转yolo格式和MS COCO格式】
import os
import json
# 将目标检测框的信息转为yolo格式
def cc2yolo_bbox(img_width, img_height, bbox):
dw = 1. / img_width
dh = 1. / img_height
# yolo中框的坐标信息为 框中心点的横纵坐标及其宽高
x = bbox[0] + bbox[2] / 2.0
y = bbox[1] + bbox[3] / 2.0
w = bbox[2]
h = bbox[3]
# 归一化坐标信息
x = format(x * dw ,'.5f')
w = format(w * dw ,'.5f')
y = format(y * dh ,'.5f')
h = format(h * dh ,'.5f')
return (x, y, w, h)
# 将关键点keypoints的信息转为yolo格式
def cc2yolo_keypoints(img_width, img_height, keypoints):
list=[]
dw = 1. / img_width
dh = 1. / img_height
keypoint_num = len(keypoints)
for i in range(keypoint_num):
# 每个关键点的横坐标数据
if i % 3 == 0:
list.append(format(keypoints[i]*dw,'.5f'))
# 每个关键点的纵坐标数据
if i % 3 == 1:
list.append(format(keypoints[i]*dh,'.5f'))
# 每个关键点的可见性(0表示没出现在图中,1表示出现在图中但被遮挡,2表示出现在图中且未被遮挡)
if i % 3 == 2:
list.append(keypoints[i])
result = tuple(list)
return result
# 指定COCO格式数据地址
json_file_path = r'E:\person_keypoints_val2017.json'
data = json.load(open(json_file_path, 'r'))
# 指定YOLO格式数据存放地址
yolo_anno_path = r'./yolo_anno/'
if not os.path.exists(yolo_anno_path):
os.makedirs(yolo_anno_path)
# 由于coco id是不连续的,会导致后面报错,所以这里生成一个map映射
cate_id_map = {}
num = 0
for cate in data['categories']:
cate_id_map[cate['id']] = num
num+=1 # cate_id_map -> {87: 0, 1034: 1, 131: 2, 318: 3, 588: 4}
cate_id_map
for img in data['images']:
# 获取图片文件名
filename = img['file_name']
# 图片宽度
img_width = img['width']
# 图片高度
img_height = img['height']
# 图片id
img_id = img['id']
# 生成的yolo格式标注的文件名
yolo_txt_name = filename.split('.')[0] + '.txt'
with open(yolo_anno_path+yolo_txt_name, 'w') as f:
# 遍历所有标注信息
for anno in data['annotations']:
# 若此标注中图片id等于所需的图片id
if anno['image_id'] == img_id:
f.write(str(cate_id_map[anno['category_id']]) + ' ')
bbox_info = cc2yolo_bbox(img_width, img_height, anno['bbox'])
keypoints_info = cc2yolo_keypoints(img_width, img_height, anno['keypoints'])
for item in bbox_info:
f.write(item + ' ')
for item in keypoints_info:
f.write(str(item) + ' ')
f.write('\n')
f.close()