该函数用于计算关键点预测的准确率
def accuracy(output, target, hm_type='gaussian', thr=0.5):
idx = list(range(output.shape[1]))
norm = 1.0
if hm_type == 'gaussian':
pred, _ = get_max_preds(output)
target, _ = get_max_preds(target)
h = output.shape[2]
w = output.shape[3]
norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10
dists = calc_dists(pred, target, norm)
acc = np.zeros((len(idx) + 1))
avg_acc = 0
cnt = 0
for i in range(len(idx)):
acc[i + 1] = dist_acc(dists[idx[i]])
if acc[i + 1] >= 0:
avg_acc = avg_acc + acc[i + 1]
cnt += 1
avg_acc = avg_acc / cnt if cnt != 0 else 0
if cnt != 0:
acc[0] = avg_acc
return acc, avg_acc, cnt, pred
下面依次输出各变量的值,方便对代码进行理解
idx = list(range(output.shape[1]))
norm = 1.0
idx=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
if hm_type == 'gaussian':
pred, _ = get_max_preds(output)
target, _ = get_max_preds(target)
h = output.shape[2]
w = output.shape[3]
norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10
pred=[ [ [39. 48.], [27. 51.], [35. 35.], [30. 32.], [37. 38.], [32. 50.], [37. 45.], [34. 42.], [38. 43.], [29. 32.], [31. 47.], [33. 52.], [40. 48.], [32. 47.], [32. 57.], [42. 42.], [30. 60.] ] ]
target=[ [ [ 0. 0.], [ 0. 0.], [ 0. 0.], [34. 11.], [39. 14.], [25. 13.], [38. 18.], [19. 11.], [40. 27.], [ 8. 7.], [43. 38.], [22. 35.], [31. 31.], [24. 44.], [38. 31.], [27. 54.], [38. 41.] ] ]
_=[ [ [0.], [0.], [0.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.] ] ]
h=64
w=48
norm=[ [6.4 4.8] ]
dists = calc_dists(pred, target, norm)
调用calc_dists()函数,进入calc_dists函数
calc_dists()函数
此函数用于计算预测关键点到真实关键点之间的距离(使用的是欧氏距离)
def calc_dists(preds, target, normalize):
preds = preds.astype(np.float32)
target = target.astype(np.float32)
dists = np.zeros((preds.shape[1], preds.shape[0]))
for n in range(preds.shape[0]):
for c in range(preds.shape[1]):
if target[n, c, 0] > 1 and target[n, c, 1] > 1:
normed_preds = preds[n, c, :] / normalize[n]
normed_targets = target[n, c, :] / normalize[n]
dists[c, n] = np.linalg.norm(normed_preds - normed_targets)
else:
dists[c, n] = -1
return dists
preds = preds.astype(np.float32)
target = target.astype(np.float32)
dists = np.zeros((preds.shape[1], preds.shape[0]))
preds=[ [ [39. 48.], [27. 51.], [35. 35.], [30. 32.], [37. 38.], [32. 50.], [37. 45.], [34. 42.], [38. 43.], [29. 32.], [31. 47.], [33. 52.], [40. 48.], [32. 47.], [32. 57.], [42. 42.], [30. 60.] ] ]
target=[ [ [ 0. 0.], [ 0. 0.], [ 0. 0.], [34. 11.], [39. 14.], [25. 13.], [38. 18.], [19. 11.], [40. 27.], [ 8. 7.], [43. 38.], [22. 35.], [31. 31.], [24. 44.], [38. 31.], [27. 54.], [38. 41.] ] ]
dists =[ [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] ]
preds.shape[1]=17
preds.shape[0]=1
for n in range(preds.shape[0]):
for c in range(preds.shape[1]):
if target[n, c, 0] > 1 and target[n, c, 1] > 1:
normed_preds = preds[n, c, :] / normalize[n]
normed_targets = target[n, c, :] / normalize[n]
dists[c, n] = np.linalg.norm(normed_preds - normed_targets)
else:
dists[c, n] = -1
遍历target中的坐标,如果该坐标的横坐标值与纵坐标值同时大于1,说明该关键点存在,进入if条件语句;反之表示该关键点不存在,进入else语句,返回-1;这里取n=0、c=3时的一次循环
normed_preds=preds[0, 3, :] / normalize[0]
则normed_preds=[30. 32.]/[6.4 4.8]=[4.6875 6.66666667]
normed_targets = target[0, 3, :] / normalize[0]
则normed_preds=[34., 11.]/[6.4 4.8]=[5.3125 2.29166667]
normed_preds - normed_targets=[-0.625 4.375]
dists[3,0]=np.linalg.norm(normed_preds - normed_targets)= [(-0.625)^2+4.375^2]^0.5 = 4.419417382415922
return dists
最后将dists返回,dists中存放的是模型预测的关键点到实际关键点的距离
dists=[ [-1. ], [-1. ], [-1. ], [ 4.41941738], [ 5.00975611], [ 7.78554377], [ 5.62716972], [ 6.87046094], [ 3.34794972], [ 6.15575647], [ 2.65165043], [ 3.93668698], [ 3.81063536], [ 1.39754249], [ 5.49719783], [ 3.42683003], [ 4.15101226] ]
acc = np.zeros((len(idx) + 1))
avg_acc = 0
cnt = 0
acc=[ 0. -1. -1. -1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
for i in range(len(idx)):
acc[i + 1] = dist_acc(dists[idx[i]])
if acc[i + 1] >= 0:
avg_acc = avg_acc + acc[i + 1]
cnt += 1
len(idx)=17
此处调用了dist_acc()函数
cnt统计真实存在的关键点个数
avg_acc统计预测正确的关键点个数
dist_acc()函数
此函数用于判断预测关键点到真实关键点之间的距离是否小于阈值0.5;小于0.5说明预测正确,返回1;反之预测不正确,返回0;如果关键点不存在,则返回-1
def dist_acc(dists, thr=0.5):
''' Return percentage below threshold while ignoring values with a -1 '''
dist_cal = np.not_equal(dists, -1)
num_dist_cal = dist_cal.sum()
if num_dist_cal > 0:
return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal
else:
return -1
以dists[0]=[-1. ],调用dist_acc函数,则
dist_cal=[False]
num_dist_cal=0,不进入if条件语句
以dists[3]=[ 4.41941738],调用dist_acc函数,则进入函数后
dists=[ 4.41941738]
dist_cal=[ True]
num_dist_cal=1
此时进入if判断语句,np.less(a, b)用于判断a是否小于b,是的话返回True,反之返回False
dists[dist_cal]=dists[ [True] ]=[ 4.41941738]
np.less([4.41941738], 0.5)=False
np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal=0.0
以dists[0]=[-1. ]调用dist_acc函数,返回-1,不进入if判断语句
以dists[3]=[ 4.41941738]调用dist_acc函数,返回0.0,进入if判断语句后
avg_acc=0.0
cnt=1
avg_acc = avg_acc / cnt if cnt != 0 else 0
if cnt != 0:
acc[0] = avg_acc
return acc, avg_acc, cnt, pred
平均准确率=预测正确的关键点个数/真实存在的关键点个数
此时cnt=14
avg_acc=0.0
acc=[ 0. -1. -1. -1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
最后将准确率、平均准确率、真实的关键点个数、预测的关键点返回