解决RuntimeError: _thnn_mse_loss_forward is not implemented for type torch.cuda.LongTensor和scatter_方法

在PyTorch中遇到了如标题的问题,我使用的MSE损失函数,网上大多数给的是类型不匹配问题,在stackoverflow找到了问题的答案,这里出现的问题是因为loss需要one-hot类型的数据,而我们使用的是类别标签。

什么是one-hot?

一个例子解释什么是one-hot,对于5分类问题,我们使用[0,0,1,0,0]来表示这个实例是属于第三个类别的,等价于类别标签[2](从0对类别编码)。关于one-hot的好处,自行百度或google。

解决办法和scatter_函数介绍

我们需要将神经网络的预测out和实例本身的标签label变为one-hot形式,因为out和label里面存储的是最大值索引,所以变换依赖于Tensor对象的scatter_方法,在索引位置设为1,其他为0,关于方法介绍查看下面的链接
Pytorch 学习(5):Pytorch中的 torch.gather/scatter_ 聚集/分散操作
官方文档
关于如何使用该方法实现one-hot,见如下链接
转换为one-hot格式

在这里插入图片描述
scatter方法就是将src中的值按照index张量中的索引赋给当前张量相应位置的值,这里需要注意一件事情是scatter_方法要求
can be either empty or the same size of src. When empty, the operation returns identity,即index数组必须与src数组的维度一致,经过实验它还必须与源数组一样,原数组为2维,那么下标数组必须为2维(你可以这样想,如果源为二维,index数组为一维,假设dim为1,那么我们就不能确定那些行需要变换。其中每一维的数量可以不相等,但有相应约束,可看官方文档)。其中src可以为某一个浮点数,我估计它内部是进行了广播机制的,将这个浮点数扩展为输出数组的维度。综上,我们需要将上面的out和label数组变为二维(因为one-hot是二维的),可以调用out=out.reshape(种类数,1)。(PS:在pytorch中,一维是只输入一个int型就行了,只要输入了两个int数,那么一定是二维的,所以“(种类数,1)”得到二维张量)
在torch中,一个函数后面加上_符号,表示对自己作用,不加表示返回值为作用结果,而自身不改变。
这样就解决这个问题了。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值