pytorch loss = loss_func(output, label) 报错

代码报错问题 专栏收录该内容
3 篇文章 0 订阅

pytorch 实现手写识别运行交叉熵loss时loss_fn = nn.CrossEntropyLoss(outputs, y_label)出现错误:
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 ‘other’
原因:CrossEntropyLoss()使用one-hot自动编码,但是label为int型,要改为torch.int64才能被编码为long型。
解决方案:
y_label =torch.from_numpy(y_label).long()

  • 0
    点赞
  • 0
    评论
  • 3
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

©️2021 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值