import json
import os
import argparse
from shutil import copy
def is_300(data, points, num_frames=300):
    """Force one person's frame dict to exactly ``num_frames`` entries (default 300).

    Frames with index >= ``num_frames`` are removed. If fewer frames are
    present, zero rows of length ``3 * points + 4`` are added under the string
    keys ``str(len(data))`` ... ``str(num_frames - 1)``. The dict is modified
    in place and also returned.

    Frame keys may be ints or strings; truncation removes whichever form is
    present. Padding assumes the existing keys are the contiguous indices
    ``0 .. len(data) - 1``.

    Args:
        data: mapping of frame index -> row for a single person.
        points: number of keypoints per skeleton.
        num_frames: target number of frames.

    Returns:
        The same ``data`` dict.
    """
    # range() is evaluated once, so popping inside the loop is safe.
    for i in range(num_frames, len(data)):
        # Keys can be int (built in memory) or str (read from JSON); drop either.
        data.pop(i, None)
        data.pop(str(i), None)
    for i in range(len(data), num_frames):
        data[str(i)] = [0.0] * (3 * points + 4)
    return data
def data_normalization(total, points, width=1080, height=652):
    """Normalize the 2D keypoint coordinates of every frame in place.

    For each frame row, every x value (index ``3*j``) is divided by ``width``
    and every y value (index ``3*j + 1``) by ``height``, for ``j`` in
    ``range(points)``. Other elements, such as the score at ``3*j + 2`` or any
    trailing box entry, are left unchanged. The defaults 1080 / 652 are the
    divisors the script has always used. They are presumably the source
    video's frame width/height; TODO confirm.

    Frames are looked up as ``0 .. len(total) - 1`` under both the int key and
    the str key, since a dict may mix the two.

    Args:
        total: mapping of frame index -> row for one person; mutated in place.
        points: number of keypoints per row.
        width: divisor applied to x coordinates.
        height: divisor applied to y coordinates.

    Returns:
        The same ``total`` dict.
    """
    x_scale = float(width)
    y_scale = float(height)
    for i in range(len(total)):
        for key in (i, str(i)):
            if key not in total:
                continue
            row = total[key]
            for j in range(points):
                row[3 * j] = float(row[3 * j]) / x_scale
                row[3 * j + 1] = float(row[3 * j + 1]) / y_scale
    return total
def data_generate(path, result, dir_avi, points):
    """Regroup an AlphaPose result file into one track file per person id.

    The input is a JSON list of detections. From each entry the function reads
    ``image_id`` (e.g. ``"12.jpg"``, whose numeric stem is the frame index),
    ``idx`` (person id), ``keypoints`` and ``box``.

    For every person id it builds a ``{frame: row}`` dict covering frames
    ``0 .. last frame``. ``row`` is the keypoint list with the box appended as
    its final element. If that person has no detection in the frame, ``row``
    is ``[0.0] * (3 * points + 4)``. Each track is then passed through
    ``is_300`` and ``data_normalization`` and written to
    ``<result>/<dir_avi>_<idx>.json``.

    Args:
        path: AlphaPose result JSON file to read.
        result: output directory; created (with parents) if missing.
        dir_avi: prefix for the output file names (the video name).
        points: number of keypoints per skeleton.
    """
    os.makedirs(result, exist_ok=True)
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Person ids in order of first appearance (dict keys keep insertion order).
    person_ids = list(dict.fromkeys(det['idx'] for det in data))

    # Map (frame, person id) -> keypoints + [box]. Indexing by the frame number
    # taken from image_id keeps each frame's detections together even when some
    # frames have no detections at all. A repeated (frame, id) pair keeps the
    # last detection.
    rows = {}
    last_frame = -1
    for det in data:
        frame = int(det['image_id'].split('.')[0])
        rows[(frame, det['idx'])] = det['keypoints'] + [det['box']]
        last_frame = max(last_frame, frame)

    empty_len = points * 3 + 4
    for pid in person_ids:
        # String keys match the JSON output format (json.dump turns int keys
        # into strings anyway) and what data_process reads back.
        track = {}
        for frame in range(last_frame + 1):
            row = rows.get((frame, pid))
            track[str(frame)] = row if row is not None else [0.0] * empty_len
        track = is_300(track, points)
        track = data_normalization(track, points)
        out_path = os.path.join(result, dir_avi + '_' + str(pid) + '.json')
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(track, f, indent=4, ensure_ascii=False)
def compare_num(scores, threshold):
    """Count the confidence scores that are at or below ``threshold``.

    Args:
        scores: iterable of keypoint confidence scores; not modified.
        threshold: inclusive upper bound for a score to count as low.

    Returns:
        int: the number of scores ``<= threshold`` (0 for an empty input).
    """
    # A direct count: no sorting needed, so the caller's list is left untouched
    # and the result is an int even when every score is below the threshold.
    return sum(1 for score in scores if score <= threshold)
def data_process(result, points, score_threshold, rate_error, occlusion_ratio=0.3):
    """Delete per-person track files that contain too many empty or occluded frames.

    Every ``.json`` file under ``result`` (recursively) is loaded as a
    ``{frame: row}`` mapping. Only the first ``3 * points`` values of each row
    are examined:

    * the frame is *empty* when all of them are 0.0;
    * otherwise it is *occluded* when at least ``points * occlusion_ratio`` of
      its scores (every third value, starting at index 2) are
      ``<= score_threshold``.

    A file is deleted when the empty count or the occluded count reaches
    ``len(frames) * rate_error``. A file with no frames is therefore deleted.
    Per-file counts and the list of deleted files are printed.

    Args:
        result: directory to scan.
        points: number of keypoints per skeleton.
        score_threshold: inclusive confidence threshold for a "low" score.
        rate_error: tolerated fraction of bad frames per file.
        occlusion_ratio: fraction of low-score keypoints that marks a frame as
            occluded.
    """
    remove_file = []
    coord_len = 3 * points
    for root, _dirs, files in os.walk(result):
        for name in files:
            if not name.endswith('.json'):
                continue
            file_path = os.path.join(root, name)
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Bound per file, so empty files don't reuse a previous file's count.
            frames = len(data)
            num_null = 0
            num_no_scores_frames = 0
            for row in data.values():
                keypoints = row[:coord_len]
                if keypoints.count(0.0) == coord_len:
                    num_null += 1
                    continue
                num_no_scores = sum(1 for s in keypoints[2::3] if s <= score_threshold)
                if num_no_scores >= points * occlusion_ratio:
                    num_no_scores_frames += 1

            print("空帧和被遮挡帧个数:%d, %d" % (num_null, num_no_scores_frames))
            limit = float(frames) * rate_error
            if num_null >= limit or num_no_scores_frames >= limit:
                remove_file.append(file_path)

    print('%s下被删除的json文件' % result)
    print(remove_file)
    # Delete only after the walk so the directory isn't modified mid-iteration.
    for path in remove_file:
        os.remove(path)
def main(argv=None):
    """Command-line entry point: convert every AlphaPose result under ``--avi_dir``.

    For each ``<name>.json`` found (recursively) under ``--avi_dir``,
    ``data_generate`` writes per-person track files into
    ``<result_dir>/<name>/``. Each ``.avi`` file is copied into
    ``<result_dir>``. Afterwards, every immediate sub-directory of
    ``<result_dir>`` is filtered with ``data_process``.

    Args:
        argv: argument list for argparse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='manual to this script')
    parser.add_argument('--result_dir', type=str, default='result_0716')  # output directory
    parser.add_argument('--avi_dir', type=str, default='avi')  # AlphaPose output directory
    parser.add_argument('--points', type=int, default=18)  # number of keypoints
    parser.add_argument('--score_threshold', type=float, default=0.3)  # confidence threshold
    parser.add_argument('--rate_error', type=float, default=0.2)  # tolerated fraction of bad frames
    args = parser.parse_args(argv)

    result = args.result_dir
    os.makedirs(result, exist_ok=True)

    for root, _dirs, files in os.walk(args.avi_dir):
        for name in files:
            # splitext copes with names that have no dot or several dots.
            stem, ext = os.path.splitext(name)
            src = os.path.join(root, name)
            if ext == '.json':
                data_generate(src, os.path.join(result, stem), stem, args.points)
            elif ext == '.avi':
                copy(src, os.path.join(result, name))

    # Only the per-video directories directly under result_dir;
    # data_process walks each of them itself.
    for entry in sorted(os.listdir(result)):
        video_dir = os.path.join(result, entry)
        if os.path.isdir(video_dir):
            data_process(video_dir, args.points, args.score_threshold, args.rate_error)


if __name__ == '__main__':
    main()
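# alphapose输出2d骨架数据转化为每一个id在每一帧的2d骨架数据及检测框位置
# (Convert AlphaPose 2D skeleton output into per-id, per-frame 2D skeleton data and detection-box positions.)
# 最新推荐文章于 2023-06-08 17:33:25 发布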