在训练模型过程中有如下语句,结果出现报错
output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
RuntimeError: gather_out_cuda(): Expected dtype int64 for index
原因应该是和pytorch版本有关系,根据报错,找到原因为target.unsqueeze(2)的dtype是int32类型的,把它转为int64了,可以正常运行。
在训练模型过程中有如下语句,结果出现报错
output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
RuntimeError: gather_out_cuda(): Expected dtype int64 for index
原因应该是和pytorch版本有关系,根据报错,找到原因为target.unsqueeze(2)的dtype是int32类型的,把它转为int64了,可以正常运行。