Assertion `target_val >= zero && target_val <= one` failed

文章描述了在使用PyTorch训练模型时遇到的错误,源于不同版本对`F.binary_cross_entropy`函数的target张量值限制不同。1.0.0版本接受任意值,而1.1.0及后续版本要求target在[0,1]范围内,导致labelf中的2值触发了断言失败。
摘要由CSDN通过智能技术生成

跑一个down下来的模型,根据提示训练代码时报错:

../aten/src/ATen/native/cuda/Loss.cu:95: operator(): block: [6,0,0], thread: [29,0,0] Assertion `target_val >= zero && target_val <= one` failed.

具体的函数定位到了这里:

  File "main.py", line 285, in train
    loss += cross_entropy_loss_RCF(o, label, args.lmbda)

在loss计算时出现了问题,debug到了该函数的这一行:

    cost = F.binary_cross_entropy(
            prediction, labelf, weight=mask, reduction='sum')

结论:函数  F.binary_cross_entropy()  由于pytorch版本不兼容造成了上述问题。

根据GPT的说法:

在 PyTorch 1.0.0 版本中,F.binary_cross_entropy 函数的参数 labelf 张量代表真实的二进制标签。对于该版本,target 张量的元素值可以是任意的,不受特定取值要求的限制。

但是在1.1.0版本中,对于输入的张量元素值限制到了区间 [0,1] 之间。我跑的这份代码 labelf 中存在元素值为2,故此处多了断言的报错。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值