AlphaPose源代码学习之getPrediction()函数

该函数的作用是从预测的热力图中获取关键点,并将热力图中的关键点缩放为原图上的关键点,该函数位于eval.py中,完整代码如下 

def getPrediction(hms, pt1, pt2, inpH, inpW, resH, resW):
    '''
    从热力图中获得关键点
    '''
    assert hms.dim() == 4, 'Score maps should be 4-dim'
    maxval, idx = torch.max(hms.view(hms.size(0), hms.size(1), -1), 2)

    maxval = maxval.view(hms.size(0), hms.size(1), 1)
    idx = idx.view(hms.size(0), hms.size(1), 1) + 1

    preds = idx.repeat(1, 1, 2).float()

    preds[:, :, 0] = (preds[:, :, 0] - 1) % hms.size(3)#计算关键点的x坐标
    preds[:, :, 1] = torch.floor((preds[:, :, 1] - 1) / hms.size(3))#计算关键点的y坐标

    pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
    preds *= pred_mask

    # Very simple post-processing step to improve performance at tight PCK thresholds
    for i in range(preds.size(0)):
        for j in range(preds.size(1)):
            hm = hms[i][j]
            pX, pY = int(round(float(preds[i][j][0]))), \
                     int(round(float(preds[i][j][1])))
            if 0 < pX < opt.outputResW - 1 and 0 < pY < opt.outputResH - 1:
                diff = torch.Tensor(
                    (hm[pY][pX + 1] - hm[pY][pX - 1], 
                     hm[pY + 1][pX] - hm[pY - 1][pX]))
                preds[i][j] += diff.sign() * 0.25
    preds += 0.2

    preds_tf = torch.zeros(preds.size())

    preds_tf = transformBoxInvert_batch(preds, pt1, pt2, inpH, inpW, resH, resW)

    return preds, preds_tf, maxval

该函数在dataloader.py中调用

为了方便代码的理解,依次输出了调试过程中各变量的值 

调用该函数时,传入的各变量值为:

hms.shape=torch.Size([2, 17, 80, 64])

hms=tensor([ [ [ [ 7.1406e-06, -8.3787e-06, -4.2975e-06,  ...,  6.0789e-06,
                            -1.1158e-05,  1.0070e-05],
                            ...,
                           [-1.2371e-05, -9.3881e-06,  3.6857e-06,  ..., -9.8551e-06,
                            -3.5456e-05,  9.0087e-06]],

                           ...,

                         [ [ 2.2507e-04,  1.6495e-04,  1.8526e-04,  ...,  2.1645e-04,
                           1.6137e-04,  2.3066e-04],
                            ...,
                           [ 1.6770e-04,  1.9034e-04,  2.2977e-04,  ...,  1.8730e-04,
                             9.3242e-05,  2.4771e-04]]],
                       [ [ [-1.4953e-04, -5.0633e-07,  9.2041e-09,  ...,  1.0909e-05,
                            7.6110e-06, -6.9278e-05],
                            ...,
                           [-1.4777e-05, -6.0838e-06,  6.0697e-06,  ..., -7.9370e-06,
                            -3.8374e-05,  7.5062e-06]],

                            ...,

                          [ [ 2.3423e-04,  1.9299e-04,  2.0025e-04,  ...,  2.3365e-04,
                              2.3001e-04,  5.2406e-04],
                               ...,
                             [ 1.5949e-04,  2.0223e-04,  4.0982e-04,  ...,  1.9344e-04,
                               8.3786e-05,  2.4526e-04]]]])

pt1=tensor([ [344.1527, 172.4882], [577.4262, 198.1156] ])
pt2=tensor([ [566.1245, 650.4148], [719.0000, 654.6444] ])
inpH=320
inpW=256
resH=80
resW=64

assert hms.dim() == 4, 'Score maps should be 4-dim'
maxval, idx = torch.max(hms.view(hms.size(0), hms.size(1), -1), 2)

hms.size(0)=2
hms.size(1)=17
maxval为每张热力图中的最大值
maxval=tensor([ [0.6315, 0.6346, 0.4302, 0.7716, 0.8475, 0.7838, 0.8432, 0.7915, 0.7231,
                            0.8305, 0.8499, 0.6293, 0.5305, 0.4960, 0.6881, 0.7562, 0.5072], 
                           [0.8412, 0.7632, 0.7683, 0.6953, 0.3785, 0.5764, 0.6929, 0.4772, 0.4548,
                            0.6735, 0.5426, 0.3364, 0.2682, 0.2762, 0.2131, 0.5713, 0.5388] ])
idx为每张热力图中最大值对应的索引
idx=tensor([ [ 603,  604,  540,  604,  610,  987,  996, 1881, 1704, 2325, 2219, 2396,
                       2402, 3356, 3294, 4255, 4191],
                      [ 615,  553,  551,  555,  617, 1129, 1128, 1959, 1893, 2143, 2207, 2793,
                       2727, 4077, 4066, 4445, 4445] ])

maxval = maxval.view(hms.size(0), hms.size(1), 1)
idx = idx.view(hms.size(0), hms.size(1), 1) + 1

maxval=tensor([ [ [0.6315], [0.6346], [0.4302], [0.7716], [0.8475], [0.7838], [0.8432], [0.7915], 
                              [0.7231],  [0.8305], [0.8499], [0.6293], [0.5305], [0.4960], [0.6881],  [0.7562], 
                              [0.5072] ],
                            [ [0.8412], [0.7632], [0.7683], [0.6953], [0.3785], [0.5764],  [0.6929], [0.4772], 
                              [0.4548], [0.6735], [0.5426], [0.3364], [0.2682], [0.2762], [0.2131], [0.5713], 
                              [0.5388] ] ])

这里的idx为什么要+1呢
idx=tensor([ [ [ 604], [ 605], [ 541], [ 605], [ 611], [ 988], [ 997], [1882], [1705], [2326], [2220], 
                       [2397], [2403], [3357], [3295], [4256], [4192] ],
                     [ [ 616], [ 554], [ 552], [ 556], [ 618], [1130], [1129], [1960], [1894], [2144], [2208],
                       [2794],  [2728],  [4078],  [4067],  [4446],  [4446] ] ])

preds = idx.repeat(1, 1, 2).float()

preds = tensor([ [ [ 604.,  604.], [ 605.,  605.], [ 541.,  541.], [ 605.,  605.], [ 611.,  611.],
                             [988.,988.],  [ 997.,  997.], [1882., 1882.], [1705., 1705.], [2326., 2326.], 
                             [2220.,2220.], [2397., 2397.], [2403., 2403.], [3357., 3357.], [3295., 3295.], 
                             [4256., 4256.], [4192., 4192.]],

                           [ [ 616.,  616.], [ 554.,  554.],[ 552.,  552.], [556.,556.], [618.,618.], [1130.,1130.],
                             [1129., 1129.], [1960., 1960.],  [1894., 1894.], [2144., 2144.], [2208.,2208.],
                             [2794., 2794.], [2728., 2728.], [4078., 4078.], [4067., 4067.], [4446.,4446.], 
                             [4446., 4446.]]])

preds[:, :, 0] = (preds[:, :, 0] - 1) % hms.size(3)

hms.size(3)=64
获取热力图中最大值的x值坐标

preds=tensor([ [ [  27.,  604.], [  28.,  605.], [  28.,  541.], [  28.,  605.], [  34.,  611.], [  27.,  988.], 
                           [ 36.,  997.], [  25., 1882.], [  40., 1705.], [ 21., 2326.], [  43., 2220.], [  28.,2397.],
                           [ 34., 2403.], [  28., 3357.], [  30., 3295.], [  31., 4256.], [  31., 4192.]],

                         [ [  39.,  616.], [  41.,  554.], [  39.,  552.], [  43.,  556.], [  41.,  618.], [  41., 1130.], 
                           [40., 1129.], [ 39., 1960.], [  37., 1894.], [ 31., 2144.], [  31., 2208.], [  41.,2794.],
                           [ 39., 2728.], [  45., 4078.], [  34., 4067.], [  29., 4446.], [  29., 4446.]]])

preds[:, :, 1] = torch.floor((preds[:, :, 1] - 1) / hms.size(3))#计算关键点的y坐标

获取热力图中最大值的y值坐标

preds=tensor([ [ [27.,  9.], [28.,  9.], [28.,  8.], [28.,  9.], [34.,  9.], [27., 15.], [36., 15.], [25.,29.],
                           [40., 26.], [21., 36.], [43., 34.], [28., 37.], [34., 37.], [28., 52.], [30., 51.],
                           [31.,66.], [31., 65.]],

                         [ [39.,  9.], [41.,  8.], [39.,  8.], [43.,  8.], [41.,  9.], [41., 17.], [40., 17.], [39.,30.],
                           [37., 29.], [31., 33.], [31., 34.], [41., 43.], [39., 42.], [45., 63.], [34., 63.],
                           [29.,69.], [29., 69.]]])

pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
preds *= pred_mask

pred_mask =tensor([ [ [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], 
                                      [1.,1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.] ],

                                   [ [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], 
                                     [1.,1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.] ] ])

preds = tensor([ [ [27.,  9.], [28.,  9.], [28.,  8.], [28.,  9.], [34.,  9.], [27., 15.], [36., 15.], [25.,29.],
                             [40., 26.], [21., 36.], [43., 34.], [28., 37.], [34., 37.], [28., 52.], [30., 51.],
                             [31.,66.], [31., 65.]],

                           [ [39.,  9.], [41.,  8.], [39.,  8.], [43.,  8.], [41.,  9.], [41., 17.], [40., 17.], [39., 30.], 
                             [37., 29.], [31., 33.], [31., 34.], [41., 43.], [39., 42.], [45., 63.], [34., 63.],
                             [29.,69.], [29., 69.]]])

for i in range(preds.size(0)):
    for j in range(preds.size(1)):
        hm = hms[i][j]
        pX, pY = int(round(float(preds[i][j][0]))), int(round(float(preds[i][j][1])))
        if 0 < pX < opt.outputResW - 1 and 0 < pY < opt.outputResH - 1:
            diff = torch.Tensor(
                (hm[pY][pX + 1] - hm[pY][pX - 1],hm[pY + 1][pX] - hm[pY - 1][pX]))
            preds[i][j] += diff.sign() * 0.25

进入for循环后,i=j=0时
hm = hms[i][j]=tensor([ [ 7.1406e-06, -8.3787e-06, -4.2975e-06,  ...,  6.0789e-06,
                                      -1.1158e-05,  1.0070e-05]
                                     ...,
                                     [-1.2371e-05, -9.3881e-06,  3.6857e-06,  ..., -9.8551e-06,
                                      -3.5456e-05,  9.0087e-06]])

hm为一张热力图
round() 方法返回浮点数x的四舍五入值
preds[i][j][0]取关键点的x坐标,preds[i][j][1]取关键点的y坐标
pX=27, pY=9
diff=tensor([ 0.0744, -0.0159])

preds=tensor([ [ [27.2500,  8.7500], [28.0000,  9.0000], [28.0000,8.0000],  [28.0000,9.0000],
                           [34.0000,  9.0000],  [27.0000, 15.0000],  [36.0000,15.0000], [25.0000,29.0000],
                           [40.0000,26.0000],  [21.0000, 36.0000], [43.0000,34.0000], [28.0000,37.0000],
                           [34.0000, 37.0000], [28.0000, 52.0000], [30.0000, 51.0000], [31.0000, 66.0000],
                           [31.0000, 65.0000] ],

                         [ [39.0000,  9.0000], [41.0000,  8.0000], [39.0000,  8.0000], [43.0000,8.0000],
                           [41.0000,9.0000], [41.0000,17.0000], [40.0000,17.0000], [39.0000,30.0000],
                           [37.0000,29.0000],  [31.0000, 33.0000], [31.0000, 34.0000], [41.0000,43.0000],
                           [39.0000, 42.0000], [45.0000, 63.0000], [34.0000, 63.0000], [29.0000, 69.0000],
                           [29.0000, 69.0000] ] ])

preds += 0.2
preds_tf = torch.zeros(preds.size())
preds_tf = transformBoxInvert_batch(preds, pt1, pt2, inpH, inpW, resH, resW)

preds加2之后

preds=tensor([ [ [27.4500,  8.9500],[27.9500,  8.9500],[27.9500,  7.9500],[28.4500,  9.4500],
                           [33.9500,  8.9500],[27.4500, 15.4500],[35.9500, 14.9500],[25.4500, 28.9500],
                           [39.9500, 25.9500],[21.4500, 35.9500],[42.9500, 33.9500],[27.9500, 36.9500],
                           [33.9500, 37.4500],[27.9500, 51.9500],[29.9500, 51.4500],[31.4500, 66.4500],
                           [30.9500, 65.4500]],

                         [ [39.4500,  9.4500],[40.9500,  7.9500],[39.4500,  7.9500],[42.9500,  7.9500],
                           [40.9500,  8.9500],[41.4500, 16.9500],[39.9500, 17.4500],[39.4500, 29.9500],
                           [37.4500, 28.9500],[31.4500, 33.4500],[31.4500, 33.9500], [41.4500, 42.9500],
                           [39.4500, 41.9500],[44.9500, 63.4500],[33.9500, 63.4500],[29.4500, 69.4500], 
                           [29.4500, 69.4500]]])

这里调用了transformBoxInvert_batch()函数


进入transformBoxInvert_batch函数

def transformBoxInvert_batch(pt, ul, br, inpH, inpW, resH, resW):

pt的值为上面的preds
ul=tensor([ [344.1527, 172.4882],
                  [577.4262, 198.1156] ]) 为当前人员目标框在原图像中的左上角坐标
br=tensor([ [566.1245, 650.4148],
                   [719.0000, 654.6444] ]) 为当前人员目标框在原图像中的右下角坐标
inpH=320
inpW=256
resH=80
resW=64

center = (br - 1 - ul) / 2
size = br - ul #计算出原图像中人员目标框的宽高
size[:, 0] *= (inpH / inpW)

center = (br - 1 - ul) / 2=tensor([ [110.4859, 238.4633],
                                                    [ 70.2869, 227.7644] ])
size = br - ul=tensor([ [221.9717, 477.9266],
                                   [141.5738, 456.5288]])
size[:, 0] *= (inpH / inpW)=tensor([ [277.4646, 477.9266],
                                                        [176.9672, 456.5288] ])

lenH, _ = torch.max(size, dim=1)   # [n,]
lenW = lenH * (inpW / inpH)
_pt = (pt * lenH[:, np.newaxis, np.newaxis]) / resH #映射到原图中人员目标框图像上

将热力图上的坐标点映射到原图中的人员目标框图像上,注意该人员目标框是以目标框左上角为坐标原点,并不是将关键点坐标映射到原图上
lenH=tensor([477.9266, 456.5288]),   _ = tensor([1, 1])
lenW = tensor([382.3413, 365.2231])
np.newaxis=None
_pt = tensor([ [ [163.9886,  53.4680],
                         ...,
                         [184.8979, 391.0037]],

                      [  [225.1258,  53.9275],
                         ...,
                         [168.0597, 396.3240]]])

_pt[:, :, 0] = _pt[:, :, 0] - ((lenW[:, np.newaxis].repeat(1, 17) - 1) /
              2 - center[:, 0].unsqueeze(-1).repeat(1, 17)).clamp(min=0)
_pt[:, :, 1] = _pt[:, :, 1] - ((lenH[:, np.newaxis].repeat(1, 17) - 1) /
              2 - center[:, 1].unsqueeze(-1).repeat(1, 17)).clamp(min=0)

对上面的lenW和这一步的代码作用不理解!
此时_pt=tensor([ [ [ 83.8038,  53.4680],
                              ...,
                              [104.7131, 391.0037]],

                            [ [113.3011,  53.9275],
                              ...,
                              [ 56.2350, 396.3240]]])

new_point = torch.zeros(pt.size())
new_point[:, :, 0] = _pt[:, :, 0] + ul[:, 0].unsqueeze(-1).repeat(1, 17)
new_point[:, :, 1] = _pt[:, :, 1] + ul[:, 1].unsqueeze(-1).repeat(1, 17)

将坐标点加上人员目标框左上角在原图中的x,y坐标,得到关键点在原图中的坐标
new_point.shape = torch.Size([2, 17, 2])
new_point = tensor([ [ [427.9565, 225.9562],
                                    ...,
                                    [448.8658, 563.4919]],

                                  [ [690.7274, 252.0431],
                                    ...,
                                    [633.6613, 594.4396]]])

return new_point

最后将关键点返回


调用完transformBoxInvert_batch()函数之后

preds_tf=tensor([ [ [427.9565, 225.9562],[430.9435, 225.9562],[430.9435, 219.9821],
                                [433.9306,228.9433],[466.7881, 225.9562],[427.9565,264.7878], 
                                [478.7362,261.8007],  [416.0084, 345.4379],[502.6326, 327.5156],
                                [392.1120,387.2565], [520.5548, 375.3083], [430.9435, 393.2305],
                                [466.7881, 396.2176], [430.9435, 482.8418], [442.8917,479.8547],
                                [451.8528, 569.4659], [448.8658, 563.4919]],

                              [ [690.7274, 252.0431],[699.2873, 243.4832],[690.7274, 243.4832],
                                [710.7005,243.4832], [699.2873, 249.1898],[702.1406, 294.8427],
                                [693.5807,297.6960],[690.7274, 369.0286],[679.3141, 363.3220],
                                [645.0745,389.0017],[645.0745, 391.8550],[702.1406, 443.2145],
                                [690.7274,437.5079], [722.1137, 560.2000],[659.3410, 560.2000],
                                [633.6613, 594.4396], [633.6613, 594.4396] ] ])

return preds, preds_tf, maxval

最后将(预测骨架点在热力图中的位置,预测骨架点在原图中的位置,每张热力图的最大值)返回

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值