跑一个down下来的模型,根据提示训练代码时报错:
../aten/src/ATen/native/cuda/Loss.cu:95: operator(): block: [6,0,0], thread: [29,0,0] Assertion `target_val >= zero && target_val <= one` failed.
具体的函数定位到了这里:
File "main.py", line 285, in train
loss += cross_entropy_loss_RCF(o, label, args.lmbda)
在loss计算时出现了问题,debug到了该函数的这一行:
cost = F.binary_cross_entropy(
prediction, labelf, weight=mask, reduction='sum')
结论:函数 F.binary_cross_entropy() 由于pytorch版本不兼容造成了上述问题。
根据GPT的说法:
在 PyTorch 1.0.0 版本中,
F.binary_cross_entropy
函数的参数labelf
张量代表真实的二进制标签。对于该版本,target
张量的元素值可以是任意的,不受特定取值要求的限制。
但是在1.1.0版本中,对于输入的张量元素值限制到了区间 [0,1] 之间。我跑的这份代码 labelf 中存在元素值为2,故此处多了断言的报错。