这里必须吐槽下中文环境的bug搜寻情况,直接输入上面的报错很难找到中文描述对应的解决方案。
这是在使用pytorch 自带的BCELoss所报的错误,在GPU,多GPU与CPU运行loss时都会报这样的错误。
我的pytorch环境是0.4.1 ,低于此版本的pytorch都会报这样的问题,1.0版本以上的没有测试过,不知道。
这个bug的来源是BCELoss中输入的张量value的范围必须在[0.0,1.0]之类,而有时候模型的输出是超出这个范畴的,因此BCELoss会报错。解决方案很简单:
assert (inputs > 0.0 & inputs <1.0).all()
使用这个方案有时候会因为cpu,gpu的问题还会报错,有时候还会因为detach的问题报错;但是都没关系,简单而粗暴的解决上述bug的方案:
inputs[inputs < 0.0] = 0.0
inputs[inputs > 1.0] = 1.0
这样上面所有的问题可以解决