Simple Baselines for Human Pose Estimation and Tracking源代码学习之get_max_preds()函数

get_max_preds()函数主要从热力图中获取关键点的坐标,完整函数代码如下

def get_max_preds(batch_heatmaps):
    '''
    get predictions from score maps
    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    '''
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(heatmaps_reshaped, 2)#argmax返回最大值的索引
    maxvals = np.amax(heatmaps_reshaped, 2)#返回数组的最大值,在第2个维度

    maxvals = maxvals.reshape((batch_size, num_joints, 1))
    idx = idx.reshape((batch_size, num_joints, 1))

    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)#np.tile沿着指定维度复制

    preds[:, :, 0] = (preds[:, :, 0]) % width
    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)

    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
    pred_mask = pred_mask.astype(np.float32)

    preds *= pred_mask
    return preds, maxvals

该函数在evaluate.py中被调用

 而该accuracy()在function.py中被调用;output = model(input),为模型的输出

 train函数则在train.py中被调用


为了直观方便的看到训练过程中的数据传递,将batchsize改为了1,并输出了代码中各变量的值

assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))

此时batch_heatmaps=output,其维度为torch.Size([1, 17, 64, 48]),1表示batch_size,17表示17个关键点,64和48是每个关键点对应热力图的大小

batch_heatmaps.shape=torch.Size([1, 17, 64, 48])

batch_size=1

num_joints=17

width=48

heatmaps_reshaped.shape=torch.Size([1, 17, 3072])

idx = np.argmax(heatmaps_reshaped, 2)#argmax返回最大值的索引
maxvals = np.amax(heatmaps_reshaped, 2)#返回数组的最大值,在第2个维度

idx为热力图中最大值的索引,maxvals为热力图中的最大值;因为17个关键点对应17个热力图,每个热力图求一个最大值(即获取一个关键点),所以总共为17个值

idx=[ [2343 2475 1715 1566 1861 2432 2197 2050 2102 1565 2287 2529 2344 2288,  2768 2058 2910] ]

maxvals=[ [2.039913   1.7724268  1.2334661  0.97012115 1.1990129  1.5435572,  0.8566811  1.2814575  1.0773002  0.91459084 1.4435731  1.5522699,  1.7481376  1.8977993  1.3567188  1.2680566  1.0704352 ] ]

maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))

maxvals=[ [ [2.039913  ],  [1.7724268 ],  [1.2334661 ],  [0.97012115],  [1.1990129 ],  [1.5435572 ],  [0.8566811 ],  [1.2814575 ],  [1.0773002 ],  [0.91459084],  [1.4435731 ],  [1.5522699 ],  [1.7481376 ],  [1.8977993 ],  [1.3567188 ],  [1.2680566 ],  [1.0704352 ] ] ]

idx=[ [ [2343],  [2475],  [1715],  [1566],  [1861],  [2432],  [2197],  [2050],  [2102],  [1565],  [2287],  [2529],  [2344],  [2288],  [2768],  [2058],  [2910] ] ]

此时maxvals.shape=idx.shape=([1, 17, 1])

preds = np.tile(idx, (1, 1, 2)).astype(np.float32)#np.tile沿着指定维度复制

preds=[ [ [2343. 2343.],  [2475. 2475.],  [1715. 1715.],  [1566. 1566.],  [1861. 1861.],  [2432. 2432.],  [2197. 2197.],  [2050. 2050.],  [2102. 2102.],  [1565. 1565.],  [2287. 2287.],  [2529. 2529.],  [2344. 2344.],  [2288. 2288.],  [2768. 2768.],  [2058. 2058.],  [2910. 2910.] ] ]

np.tile()沿着第一维和第二维复制1倍(即不变),将第三维复制使其变为原来的2倍

此时preds.shape=([1, 17, 2])

preds[:, :, 0] = (preds[:, :, 0]) % width

这行代码是求关键点在热力图(热力图大小为48*64)中的横坐标,用热力图的索引除以热力图的宽度,再取余;索引、坐标点是从0开始的,除法取余数也是从0开始,所以直接取余数刚刚好

preds=[ [ [  39. 2343.],  [  27. 2475.],  [  35. 1715.],  [  30. 1566.],  [  37. 1861.],  [  32. 2432.],  [  37. 2197.],  [  34. 2050.],  [  38. 2102.],  [  29. 1565.],  [  31. 2287.],  [  33. 2529.],  [  40. 2344.],  [  32. 2288.],  [  32. 2768.],  [  42. 2058.],  [  30. 2910.] ] ]

preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)

这行代码是求关键点在热力图(热力图大小为48*64)中的纵坐标,用热力图的索引除以热力图的宽度,再取商;np.floor()函数的作用是向下取整,因为坐标点是从0开始计的,所以要向下取整

preds=[ [ [39. 48.],  [27. 51.],  [35. 35.],  [30. 32.],  [37. 38.],  [32. 50.],  [37. 45.],  [34. 42.],  [38. 43.],  [29. 32.],  [31. 47.],  [33. 52.],  [40. 48.],  [32. 47.],  [32. 57.],  [42. 42.],  [30. 60.] ] ]

此处以第一个关键点[2343. 2343.]为例,通过取余与取商得到关键点在热力图中的位置为[39. 48.]

 

pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))

np.greater(maxvals, 0.0)=tensor([ [ [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1] ] ], dtype=torch.uint8)

np.greater()用于判断maxvals中的值是否大于0.0,返回一个与maxvals同维度的tensor,如果maxvals某一位置的值大于0.0,则返回的tensor相应位置取1,反之取0

np.tile()将第一维和第二维复制使其变为原来的1倍(即不变),将第三维复制使其变为原来的2倍

pre_mask=[ [ [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1] ] ]

pred_mask = pred_mask.astype(np.float32)

pred_mask=[ [ [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.],  [1. 1.] ] ]

preds *= pred_mask

preds=[ [ [39. 48.],  [27. 51.],  [35. 35.],  [30. 32.],  [37. 38.],  [32. 50.],  [37. 45.],  [34. 42.],  [38. 43.],  [29. 32.],  [31. 47.],  [33. 52.],  [40. 48.],  [32. 47.],  [32. 57.],  [42. 42.],  [30. 60.] ] ]

return preds, maxvals

最后返回关键点的坐标与热力图的最大值

  • 9
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值