如谷歌搜索结果所述,这种情况多半是loss计算出了点问题
但这个报错烦人的点是它不告诉你哪有问题,而且报错的行往往不是真正有问题的行,很多时候是有问题行的下一行。报错的变量也不是真正有问题的变量,而是显存里之前的变量
debug方法就是逐行插入print看在哪tensor的值输出不出来,问题就定位到了。笔者的问题最后发现是多分类问题最大类别数设成了8,但logit tensor是20维的,它输出一个大于8的argmax(比如14)类就不对了。
如谷歌搜索结果所述,这种情况多半是loss计算出了点问题
但这个报错烦人的点是它不告诉你哪有问题,而且报错的行往往不是真正有问题的行,很多时候是有问题行的下一行。报错的变量也不是真正有问题的变量,而是显存里之前的变量
debug方法就是逐行插入print看在哪tensor的值输出不出来,问题就定位到了。笔者的问题最后发现是多分类问题最大类别数设成了8,但logit tensor是20维的,它输出一个大于8的argmax(比如14)类就不对了。