省流版:优先排查模型最后一层,是不是输出维度小于数据集中的标签数
比如模型最后一层是(512, 10)即⑩分类,而数据集中最大标签大于等于10了(理论上最大为9)
这里可以采用targets.max()的方法来快速检查
报错信息如下所示:
错误出在loss.backward()这一句上,已经验证过predict和output的维度正确,loss也能正常算出来,因为这句平常用的也比较多不应该是实现问题,优先从自己的模型结构和代码上找问题。
最后发现是加载模型的时候忘记修改输出层维度了,对于数据集中20类目标,模型输出层只有11维,修改后问题解决。