soft-argmax踩坑

最近在2D human pose estimation时需要用到soft-argmax,找了几个版本的函数,都有一个问题

RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Long'

一、代码如下

def softargmax2d(input, beta=100):
    *_, h, w = input.shape

    input = beta*input.reshape(*_, h * w)
    input = F.softmax( input, dim=-1)

    indices_c, indices_r = np.meshgrid(
        np.linspace(0, 1, w),
        np.linspace(0, 1, h),
        indexing='xy'
    )

    indices_r = torch.tensor(np.reshape(indices_r, (-1, h * w)))
    indices_c = torch.tensor(np.reshape(indices_c, (-1, h * w)))

    result_r = torch.sum((h - 1) * input * indices_r, dim=-1)
    result_c = torch.sum((w - 1) * input * indices_c, dim=-1)

    result = torch.stack([result_r, result_c], dim=-1)

    return result

二、测试如下

c=[[[[1,2,3],[4,5,16],[7,8,9]],[[1,2,3],[4,5,6],[7,8,9]]]]
c=torch.tensor(c)
print(c.size())
b=softargmax2d(c)
print(b)

三、结果如下

Traceback (most recent call last):
torch.Size([1, 2, 3, 3])
  File "F:/pythonprogram/mon_repnet/wrm_model.py", line 202, in <module>
    b=softargmax2d(c)
  File "F:/pythonprogram/mon_repnet/wrm_model.py", line 172, in softargmax2d
    input = F.softmax( input, dim=-1)
  File "E:\software\python36\lib\site-packages\torch\nn\functional.py", line 1231, in softmax
    ret = input.softmax(dim)
RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Long'

找了很久你会发现很难搜到解决办法,其实……只要

四、修正如下

c=[[[[1,2,3],[4,5,16],[7,8,9]],[[1,2,3],[4,5,6],[7,8,9]]]]
c=torch.tensor(c).float()
print(c.size())
b=softargmax2d(c)
print(b)

转换输入的类型为float即可,额~~~~~~~

五、最终结果如下

torch.Size([1, 2, 3, 3])
tensor([[[1., 2.],
         [2., 2.]]], dtype=torch.float64)

答案正确,最大值坐标分别为(1,2),(2,2)

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值