Simple Baselines for Human Pose Estimation and Tracking source code study: the accuracy() function

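This function computes the accuracy of the keypoint predictions. It is a PCK-style metric computed on heatmaps. A keypoint counts as correct when the normalized distance between the peak of the predicted heatmap and the peak of the ground-truth heatmap is below a threshold.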

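import numpy as np

def accuracy(output, target, hm_type='gaussian', thr=0.5):
    # output, target: heatmaps of shape (batch, num_joints, h, w)
    idx = list(range(output.shape[1]))  # indices of all joints
    norm = 1.0
    if hm_type == 'gaussian':
        # convert each heatmap to the (x, y) location of its maximum
        pred, _ = get_max_preds(output)
        target, _ = get_max_preds(target)
        h = output.shape[2]
        w = output.shape[3]
        # per-image normalization factor: one tenth of the heatmap size
        norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10
    dists = calc_dists(pred, target, norm)  # shape (num_joints, batch)

    acc = np.zeros((len(idx) + 1))  # acc[0]: average, acc[1:]: per joint
    avg_acc = 0
    cnt = 0

    for i in range(len(idx)):
        # note: thr is not passed on, so dist_acc's default thr=0.5 is always used
        acc[i + 1] = dist_acc(dists[idx[i]])
        if acc[i + 1] >= 0:  # skip joints without ground truth (-1)
            avg_acc = avg_acc + acc[i + 1]
            cnt += 1

    avg_acc = avg_acc / cnt if cnt != 0 else 0
    if cnt != 0:
        acc[0] = avg_acc
    return acc, avg_acc, cnt, pred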

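The walkthrough below prints each variable in turn to make the code easier to follow. The example uses a batch size of 1, 17 joints and 64x48 heatmaps.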

idx = list(range(output.shape[1]))
norm = 1.0

idx=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]

if hm_type == 'gaussian':
    pred, _ = get_max_preds(output)
    target, _ = get_max_preds(target)
    h = output.shape[2]
    w = output.shape[3]
    norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10

pred=[ [ [39. 48.],  [27. 51.],  [35. 35.],  [30. 32.],  [37. 38.],  [32. 50.],  [37. 45.],  [34. 42.],  [38. 43.],  [29. 32.],  [31. 47.],  [33. 52.],  [40. 48.],  [32. 47.],  [32. 57.],  [42. 42.],  [30. 60.] ] ]

target=[ [ [ 0.  0.],  [ 0.  0.],  [ 0.  0.],  [34. 11.],  [39. 14.],  [25. 13.],  [38. 18.],  [19. 11.],  [40. 27.],  [ 8.  7.],  [43. 38.],  [22. 35.],  [31. 31.],  [24. 44.],  [38. 31.],  [27. 54.],  [38. 41.] ] ]

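_ (after the second call, the peak value of each target heatmap)=[ [ [0.],  [0.],  [0.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.],  [1.] ] ]

The first three target heatmaps are all zeros because those joints are not annotated. Their peak value is therefore 0, and get_max_preds returns [0, 0] for them.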

h=64

w=48

norm=[ [6.4 4.8] ]
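get_max_preds() itself is not shown in this post. The sketch below is a hypothetical stand-in (`get_max_preds_sketch` is an illustrative name), under these assumptions: the function takes the argmax of each flattened heatmap, converts the flat index to (x = column, y = row), and zeroes the coordinates when the peak is not positive. It reproduces two things from the printout above: an all-zero target heatmap yields [0, 0] with peak 0, and a 64x48 heatmap gives norm = [[6.4 4.8]].

```python
import numpy as np


def get_max_preds_sketch(heatmaps):
    """Hypothetical stand-in for get_max_preds (not shown in this post).

    heatmaps: array of shape (N, K, H, W).
    Returns preds (N, K, 2) with the (x, y) of each heatmap's maximum,
    and maxvals (N, K, 1) with the maximum value itself.
    """
    n, k, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, k, -1)
    idx = np.argmax(flat, axis=2)                  # flat index of each peak
    maxvals = np.amax(flat, axis=2)[:, :, None]    # (N, K, 1)
    # x = column index, y = row index of the peak
    preds = np.stack([idx % w, idx // w], axis=2).astype(np.float32)
    preds *= (maxvals > 0)                         # all-zero heatmap -> [0, 0]
    return preds, maxvals


# Hypothetical 64x48 target heatmaps for 2 joints: joint 0 peaks at x=34, y=11;
# joint 1 is not annotated (all zeros), like the first three joints above.
target_hm = np.zeros((1, 2, 64, 48), dtype=np.float32)
target_hm[0, 0, 11, 34] = 1.0
target, maxvals = get_max_preds_sketch(target_hm)
print(target[0])   # rows [34, 11] and [0, 0]
print(maxvals[0])  # peaks 1.0 and 0.0

h, w = target_hm.shape[2], target_hm.shape[3]
norm = np.ones((target.shape[0], 2)) * np.array([h, w]) / 10
print(norm)        # [[6.4 4.8]]
```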

dists = calc_dists(pred, target, norm)

Next, calc_dists() is called, so we step into it.


The calc_dists() function

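This function computes the distance between each predicted keypoint and its ground-truth keypoint. The distance is Euclidean, taken after both points are divided by the normalization factor.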

def calc_dists(preds, target, normalize):
    # preds, target: (batch, num_joints, 2) arrays of (x, y); normalize: (batch, 2)
    preds = preds.astype(np.float32)
    target = target.astype(np.float32)
    dists = np.zeros((preds.shape[1], preds.shape[0]))  # (num_joints, batch)
    for n in range(preds.shape[0]):
        for c in range(preds.shape[1]):
            # a joint counts as annotated only if both target coordinates are > 1
            if target[n, c, 0] > 1 and target[n, c, 1] > 1:
                normed_preds = preds[n, c, :] / normalize[n]
                normed_targets = target[n, c, :] / normalize[n]
                dists[c, n] = np.linalg.norm(normed_preds - normed_targets)
            else:
                dists[c, n] = -1  # missing ground truth
    return dists
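The double loop can also be written with NumPy broadcasting. The sketch below is a vectorized equivalent for illustration; `calc_dists_vectorized` is a hypothetical name and is not part of the original code. It is fed the pred/target/norm values printed in this post, so its output can be checked against the dists listed later in the walkthrough.

```python
import numpy as np


def calc_dists_vectorized(preds, target, normalize):
    """Vectorized sketch of calc_dists: same inputs, same (K, N) output.

    preds, target: (N, K, 2) arrays of (x, y); normalize: (N, 2).
    Joints whose target x or y is <= 1 are marked with -1, as in the loop.
    """
    preds = preds.astype(np.float32)
    target = target.astype(np.float32)
    norm = normalize[:, None, :]                               # (N, 1, 2), broadcasts over K
    d = np.linalg.norm(preds / norm - target / norm, axis=2)   # (N, K) Euclidean distances
    valid = (target[:, :, 0] > 1) & (target[:, :, 1] > 1)      # (N, K) annotated joints
    return np.where(valid, d, -1.0).T                          # transpose to (K, N)


# The pred / target / norm values printed in this post (batch of 1, 17 joints)
pred = np.array([[[39, 48], [27, 51], [35, 35], [30, 32], [37, 38], [32, 50],
                  [37, 45], [34, 42], [38, 43], [29, 32], [31, 47], [33, 52],
                  [40, 48], [32, 47], [32, 57], [42, 42], [30, 60]]], dtype=np.float32)
target = np.array([[[0, 0], [0, 0], [0, 0], [34, 11], [39, 14], [25, 13],
                    [38, 18], [19, 11], [40, 27], [8, 7], [43, 38], [22, 35],
                    [31, 31], [24, 44], [38, 31], [27, 54], [38, 41]]], dtype=np.float32)
norm = np.array([[6.4, 4.8]])

dists = calc_dists_vectorized(pred, target, norm)
print(dists.ravel())
```

The loop version is easier to step through in a debugger. The vectorized one avoids Python-level iteration when the batch is large.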
preds = preds.astype(np.float32)
target = target.astype(np.float32)
dists = np.zeros((preds.shape[1], preds.shape[0]))

preds=[ [ [39. 48.],  [27. 51.],  [35. 35.],  [30. 32.],  [37. 38.],  [32. 50.],  [37. 45.],  [34. 42.],  [38. 43.],  [29. 32.],  [31. 47.],  [33. 52.],  [40. 48.],  [32. 47.],  [32. 57.],  [42. 42.],  [30. 60.] ] ]

target=[ [ [ 0.  0.],  [ 0.  0.],  [ 0.  0.],  [34. 11.],  [39. 14.],  [25. 13.],  [38. 18.],  [19. 11.],  [40. 27.],  [ 8.  7.],  [43. 38.],  [22. 35.],  [31. 31.],  [24. 44.],  [38. 31.],  [27. 54.],  [38. 41.] ] ]

dists =[ [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] [0.] ]

preds.shape[1]=17

preds.shape[0]=1

for n in range(preds.shape[0]):
    for c in range(preds.shape[1]):
        if target[n, c, 0] > 1 and target[n, c, 1] > 1:
            normed_preds = preds[n, c, :] / normalize[n]
            normed_targets = target[n, c, :] / normalize[n]
            dists[c, n] = np.linalg.norm(normed_preds - normed_targets)
        else:
            dists[c, n] = -1

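The loop goes through the target coordinates one joint at a time. If both the x and y coordinates of a joint are greater than 1, the joint is treated as annotated, and the if branch computes its distance. Otherwise the joint is treated as missing, and the else branch sets its distance to -1. Take the iteration n=0, c=3 as an example: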

normed_preds=preds[0, 3, :] / normalize[0]

so normed_preds = [30. 32.] / [6.4 4.8] = [4.6875 6.66666667]

normed_targets = target[0, 3, :] / normalize[0]

so normed_targets = [34. 11.] / [6.4 4.8] = [5.3125 2.29166667]

normed_preds - normed_targets=[-0.625  4.375]

dists[3, 0] = np.linalg.norm(normed_preds - normed_targets) = sqrt((-0.625)^2 + 4.375^2) = sqrt(19.53125) = 4.419417382415922
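A quick NumPy check of this arithmetic, using the values for n=0, c=3 copied from above. It also shows that np.linalg.norm of a 1-D vector is just the ordinary Euclidean length.

```python
import numpy as np

# Joint c=3 of image n=0, values copied from the walkthrough
pred_c3 = np.array([30.0, 32.0])
target_c3 = np.array([34.0, 11.0])
norm_0 = np.array([6.4, 4.8])

diff = pred_c3 / norm_0 - target_c3 / norm_0  # approx. [-0.625, 4.375]
dist = np.linalg.norm(diff)                   # Euclidean length of diff
manual = np.sqrt(np.sum(diff ** 2))           # the same formula written out
print(dist, manual)                           # both approx. 4.4194
```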

return dists

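Finally, dists is returned. It holds the normalized distance from each predicted keypoint to its ground-truth keypoint, with -1 for joints that have no ground truth.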


dists=[ [-1.        ], [-1.        ], [-1.        ], [ 4.41941738], [ 5.00975611], [ 7.78554377], [ 5.62716972], [ 6.87046094], [ 3.34794972], [ 6.15575647], [ 2.65165043], [ 3.93668698], [ 3.81063536], [ 1.39754249], [ 5.49719783], [ 3.42683003], [ 4.15101226] ]

acc = np.zeros((len(idx) + 1))
avg_acc = 0
cnt = 0

acc=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] (18 zeros: acc[0] will hold the average, and acc[1:] the 17 per-joint accuracies)

for i in range(len(idx)):
    acc[i + 1] = dist_acc(dists[idx[i]])
    if acc[i + 1] >= 0:
        avg_acc = avg_acc + acc[i + 1]
        cnt += 1

len(idx)=17

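dist_acc() is called here and is explained next.

cnt counts the joints that have a ground truth, i.e. those for which dist_acc does not return -1.

avg_acc accumulates the per-joint accuracies. With a batch size of 1, each per-joint accuracy is either 0 or 1, so avg_acc is the number of correctly predicted joints.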


The dist_acc() function

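This function receives the distances of one joint across the batch. It ignores entries equal to -1 (missing ground truth) and returns the fraction of the remaining distances that are below the threshold thr (0.5 by default). If no entry is valid, it returns -1. With a batch size of 1, as in this example, it returns 1.0 when the prediction is correct, 0.0 when it is not, and -1 when the joint has no ground truth.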

def dist_acc(dists, thr=0.5):
    ''' Return percentage below threshold while ignoring values with a -1 '''
    dist_cal = np.not_equal(dists, -1)
    num_dist_cal = dist_cal.sum()
    if num_dist_cal > 0:
        return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal
    else:
        return -1
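In this post the batch size is 1, so each row of dists has a single entry. With a larger batch, the return value is a genuine fraction. The snippet below copies dist_acc unchanged and feeds it a hypothetical row for one joint over a batch of 4 images. In that row, image 1 has no annotation for the joint.

```python
import numpy as np


def dist_acc(dists, thr=0.5):
    ''' Return percentage below threshold while ignoring values with a -1 '''
    dist_cal = np.not_equal(dists, -1)
    num_dist_cal = dist_cal.sum()
    if num_dist_cal > 0:
        return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal
    else:
        return -1


# Hypothetical distances for one joint across 4 images (-1 = not annotated)
row = np.array([0.3, -1.0, 0.7, 0.2])
print(dist_acc(row))                      # 2 of the 3 valid distances are < 0.5
print(dist_acc(np.array([-1.0])))         # no valid entry -> -1
print(dist_acc(np.array([4.41941738])))   # joint 3 from this post -> 0.0
```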

Calling dist_acc with dists[0] = [-1.]:

dist_cal=[False]

num_dist_cal=0, so the if branch is skipped and the function returns -1

Calling dist_acc with dists[3] = [4.41941738], inside the function:

dists=[ 4.41941738]

dist_cal=[ True]

num_dist_cal=1

This time the if branch is entered. np.less(a, b) checks element-wise whether a is less than b. It returns True where that holds and False otherwise.

dists[dist_cal]=dists[ [True] ]=[ 4.41941738]

np.less([4.41941738], 0.5)=[False]

np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal=0.0


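Back in accuracy(): calling dist_acc with dists[0] = [-1.] returns -1, so the if branch in the loop is not entered. The same holds for dists[1] and dists[2].

Calling dist_acc with dists[3] = [4.41941738] returns 0.0, so the if branch is entered. This is the first joint with a ground truth, so after this iteration:

avg_acc=0.0

cnt=1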

avg_acc = avg_acc / cnt if cnt != 0 else 0
if cnt != 0:
    acc[0] = avg_acc
return acc, avg_acc, cnt, pred

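Average accuracy = (sum of the per-joint accuracies) / cnt. With a batch size of 1, this is the number of correctly predicted joints divided by the number of joints that have a ground truth.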

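After all 17 iterations, cnt=14 (17 joints minus the 3 without ground truth)

avg_acc=0.0 (even the smallest valid distance, 1.39754249, exceeds the 0.5 threshold, so no joint is predicted correctly)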

acc=[ 0. -1. -1. -1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
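The averaging step can be isolated as follows. The sketch below mirrors the tail of accuracy(). `aggregate` is a hypothetical helper that takes the per-joint values dist_acc would return, rather than calling dist_acc itself. Fed the values from this post (3 joints with -1 and 14 with 0.0), it reproduces cnt=14, avg_acc=0.0 and the final acc above.

```python
import numpy as np


def aggregate(per_joint_acc):
    """Hypothetical helper mirroring the tail of accuracy()."""
    acc = np.zeros(len(per_joint_acc) + 1)
    avg_acc, cnt = 0.0, 0
    for i, a in enumerate(per_joint_acc):
        acc[i + 1] = a
        if a >= 0:          # -1 marks a joint with no ground truth
            avg_acc += a
            cnt += 1
    avg_acc = avg_acc / cnt if cnt != 0 else 0
    if cnt != 0:
        acc[0] = avg_acc    # slot 0 stores the overall average
    return acc, avg_acc, cnt


# Per-joint accuracies from this post: 3 missing joints, 14 wrong predictions
acc, avg_acc, cnt = aggregate([-1.0] * 3 + [0.0] * 14)
print(cnt, avg_acc)  # 14 0.0

# Hypothetical case: 2 correct, 1 wrong, 1 missing -> 2 / 3
acc2, avg2, cnt2 = aggregate([1.0, 0.0, -1.0, 1.0])
```

Because missing joints are excluded from both the sum and cnt, an image with few annotated joints is not penalized for the joints it lacks.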

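Finally, the function returns four values: the per-joint accuracies (acc, with the average in acc[0]), the average accuracy (avg_acc), the number of joints with a ground truth (cnt), and the predicted keypoint coordinates (pred).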
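To see a case where some predictions are correct, here is an end-to-end sketch of the Gaussian branch under stated assumptions. `get_max_preds` is a hypothetical re-implementation, since the real one is not shown in this post, and `one_hot_heatmaps` is a hypothetical helper. The helper places a single 1.0 at each point instead of a Gaussian blob, which leaves the argmax location unchanged. calc_dists, dist_acc and the accuracy logic follow the code above.

```python
import numpy as np


def get_max_preds(heatmaps):
    # Hypothetical sketch: (x, y) of each heatmap's maximum, [0, 0] if the peak is <= 0
    n, k, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, k, -1)
    idx = np.argmax(flat, axis=2)
    maxvals = np.amax(flat, axis=2)[:, :, None]
    preds = np.stack([idx % w, idx // w], axis=2).astype(np.float32)
    preds *= (maxvals > 0)
    return preds, maxvals


def calc_dists(preds, target, normalize):
    # Same logic as calc_dists in this post
    preds = preds.astype(np.float32)
    target = target.astype(np.float32)
    dists = np.zeros((preds.shape[1], preds.shape[0]))
    for n in range(preds.shape[0]):
        for c in range(preds.shape[1]):
            if target[n, c, 0] > 1 and target[n, c, 1] > 1:
                dists[c, n] = np.linalg.norm(preds[n, c, :] / normalize[n]
                                             - target[n, c, :] / normalize[n])
            else:
                dists[c, n] = -1
    return dists


def dist_acc(dists, thr=0.5):
    # Same logic as dist_acc in this post
    dist_cal = np.not_equal(dists, -1)
    num_dist_cal = dist_cal.sum()
    if num_dist_cal > 0:
        return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal
    return -1


def accuracy(output, target):
    # Gaussian-heatmap branch of accuracy() from this post
    pred, _ = get_max_preds(output)
    target, _ = get_max_preds(target)
    h, w = output.shape[2], output.shape[3]
    norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10
    dists = calc_dists(pred, target, norm)
    acc = np.zeros(output.shape[1] + 1)
    avg_acc, cnt = 0.0, 0
    for i in range(output.shape[1]):
        acc[i + 1] = dist_acc(dists[i])
        if acc[i + 1] >= 0:
            avg_acc += acc[i + 1]
            cnt += 1
    avg_acc = avg_acc / cnt if cnt != 0 else 0
    if cnt != 0:
        acc[0] = avg_acc
    return acc, avg_acc, cnt, pred


def one_hot_heatmaps(points, h=64, w=48):
    # Hypothetical helper: a single 1.0 at each (x, y); None -> all-zero heatmap
    hm = np.zeros((1, len(points), h, w), dtype=np.float32)
    for k, p in enumerate(points):
        if p is not None:
            hm[0, k, p[1], p[0]] = 1.0
    return hm


gt = one_hot_heatmaps([None, (20, 30), (10, 40)])     # joint 0 not annotated
out = one_hot_heatmaps([(5, 5), (21, 30), (30, 10)])  # joint 1 close, joint 2 far
acc, avg_acc, cnt, pred = accuracy(out, gt)
print(acc, avg_acc, cnt)
```

Joint 1 is off by one pixel in x, a normalized distance of 1 / 6.4 = 0.15625, so it counts as correct. Joint 2 is far away, and joint 0 is ignored. The result is therefore cnt = 2 and avg_acc = 0.5.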
